/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/JVPCloner.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- JVPCloner.cpp - JVP function generation --------------*- C++ -*---===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // This file defines a helper class for generating JVP functions for automatic |
14 | | // differentiation. |
15 | | // |
16 | | //===----------------------------------------------------------------------===// |
17 | | |
18 | | #define DEBUG_TYPE "differentiation" |
19 | | |
20 | | #include "swift/SILOptimizer/Differentiation/JVPCloner.h" |
21 | | #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" |
22 | | #include "swift/SILOptimizer/Differentiation/ADContext.h" |
23 | | #include "swift/SILOptimizer/Differentiation/AdjointValue.h" |
24 | | #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" |
25 | | #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" |
26 | | #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" |
27 | | #include "swift/SILOptimizer/Differentiation/Thunk.h" |
28 | | |
29 | | #include "swift/SIL/LoopInfo.h" |
30 | | #include "swift/SIL/TypeSubstCloner.h" |
31 | | #include "swift/SILOptimizer/Analysis/LoopAnalysis.h" |
32 | | #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" |
33 | | #include "swift/SILOptimizer/Utils/DifferentiationMangler.h" |
34 | | #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" |
35 | | #include "llvm/ADT/DenseMap.h" |
36 | | |
37 | | using namespace swift; |
38 | | using namespace autodiff; |
39 | | |
40 | | namespace swift { |
41 | | namespace autodiff { |
42 | | |
43 | | class JVPCloner::Implementation final |
44 | | : public TypeSubstCloner<JVPCloner::Implementation, SILOptFunctionBuilder> { |
45 | | private: |
46 | | /// The global context. |
47 | | ADContext &context; |
48 | | |
49 | | /// The original function. |
50 | | SILFunction *const original; |
51 | | |
52 | | /// The witness. |
53 | | SILDifferentiabilityWitness *const witness; |
54 | | |
55 | | /// The JVP function. |
56 | | SILFunction *const jvp; |
57 | | |
58 | | llvm::BumpPtrAllocator allocator; |
59 | | |
60 | | /// The differentiation invoker. |
61 | | DifferentiationInvoker invoker; |
62 | | |
63 | | /// Info from activity analysis on the original function. |
64 | | const DifferentiableActivityInfo &activityInfo; |
65 | | |
66 | | /// The loop info. |
67 | | SILLoopInfo *loopInfo; |
68 | | |
69 | | /// The differential info. |
70 | | LinearMapInfo differentialInfo; |
71 | | |
72 | | bool errorOccurred = false; |
73 | | |
74 | | //--------------------------------------------------------------------------// |
75 | | // Differential generation related fields |
76 | | //--------------------------------------------------------------------------// |
77 | | |
78 | | /// The builder for the differential function. |
79 | | TangentBuilder differentialBuilder; |
80 | | |
81 | | /// Mapping from original basic blocks to corresponding differential basic |
82 | | /// blocks. |
83 | | llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> diffBBMap; |
84 | | |
85 | | /// Mapping from original values to corresponding tangent values. The map is |
86 | | /// keyed by the original `SILValue` only. |
87 | | llvm::DenseMap<SILValue, AdjointValue> tangentValueMap; |
88 | | |
89 | | /// Mapping from original basic blocks and original buffers to corresponding |
90 | | /// tangent buffers. |
91 | | llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap; |
92 | | |
93 | | /// Mapping from differential basic blocks to differential struct arguments. |
94 | | llvm::DenseMap<SILBasicBlock *, SILArgument *> differentialStructArguments; |
95 | | |
96 | | /// Mapping from basic blocks to the differential tuple elements |
97 | | /// destructured from the linear map basic block argument. In the |
98 | | /// beginning of each differential basic block, the block's differential |
99 | | /// tuple is destructured into the individual elements stored here. |
100 | | llvm::DenseMap<SILBasicBlock *, SILInstructionResultArray> differentialTupleElements; |
101 | | |
102 | | /// An auxiliary differential local allocation builder. |
103 | | TangentBuilder diffLocalAllocBuilder; |
104 | | |
105 | | /// Stack buffers allocated for storing local tangent values. |
106 | | SmallVector<SILValue, 8> differentialLocalAllocations; |
107 | | |
108 | | /// Mapping from original blocks to differential values. Used to build |
109 | | /// differential struct instances. |
110 | | llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> differentialValues; |
111 | | |
112 | | //--------------------------------------------------------------------------// |
113 | | // Getters |
114 | | //--------------------------------------------------------------------------// |
115 | | |
/// Returns the AST context of the JVP function.
116 | 3.20k | ASTContext &getASTContext() const { return jvp->getASTContext(); } |
/// Returns the SIL module of the JVP function.
117 | 13.5k | SILModule &getModule() const { return jvp->getModule(); } |
/// Returns the autodiff configuration of the differentiability witness. The
/// configuration is returned by value.
118 | 9.85k | const AutoDiffConfig getConfig() const { return witness->getConfig(); } |
/// Returns a reference to the builder for the differential function. Bind the
/// result with `auto &` to keep operating on this builder; plain `auto` makes
/// a copy.
119 | 19.6k | TangentBuilder &getDifferentialBuilder() { return differentialBuilder; } |
/// Returns the differential function, i.e. the function that the differential
/// builder emits into.
120 | 24.4k | SILFunction &getDifferential() { return differentialBuilder.getFunction(); } |
/// Returns the entry of `differentialStructArguments` for the given block.
// NOTE(review): `DenseMap::operator[]` default-inserts a null entry when the
// key is missing, so a failed lookup silently mutates the map and returns
// null. Also, the parameter is named `origBB` while the field comment says the
// map is keyed by differential blocks - confirm which is intended.
121 | 0 | SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) { |
122 | 0 | return differentialStructArguments[origBB]; |
123 | 0 | } |
124 | | |
125 | | //--------------------------------------------------------------------------// |
126 | | // Differential tuple mapping |
127 | | //--------------------------------------------------------------------------// |
128 | | |
129 | | void initializeDifferentialTupleElements(SILBasicBlock *origBB, |
130 | | SILInstructionResultArray values); |
131 | | |
132 | | SILValue getDifferentialTupleElement(ApplyInst *ai); |
133 | | |
134 | | //--------------------------------------------------------------------------// |
135 | | // General utilities |
136 | | //--------------------------------------------------------------------------// |
137 | | |
138 | | /// Get the lowered SIL type of the given AST type. |
/// Lowering is performed through the JVP function, using an abstraction
/// pattern built from the JVP's substituted generic signature and the given
/// type reduced with respect to that signature.
139 | 1.60k | SILType getLoweredType(Type type) { |
140 | 1.60k | auto jvpGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature(); |
141 | 1.60k | Lowering::AbstractionPattern pattern(jvpGenSig, |
142 | 1.60k | type->getReducedType(jvpGenSig)); |
143 | 1.60k | return jvp->getLoweredType(pattern, type); |
144 | 1.60k | } |
145 | | |
146 | | /// Get the lowered SIL type of the given nominal type declaration. |
/// The declaration's canonical declared interface type is first passed through
/// `getOpASTType` and the result is then lowered with `getLoweredType`.
147 | 0 | SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) { |
148 | 0 | auto nominalType = |
149 | 0 | getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); |
150 | 0 | return getLoweredType(nominalType); |
151 | 0 | } |
152 | | |
153 | | /// Build a differential struct value for the original block corresponding to |
154 | | /// the given terminator. |
/// Despite the name, the value built here is a tuple (`TupleInst`). Its type
/// is the remapped linear map tuple type of the terminator's parent block, and
/// its elements are the values recorded in `differentialValues` for that
/// block. For non-entry blocks, the last argument of the corresponding JVP
/// block is prepended as the first element.
///
/// \param termInst a terminator in the original function (asserted).
/// \returns the tuple instruction created with the JVP builder.
155 | 1.33k | TupleInst *buildDifferentialValueStructValue(TermInst *termInst) { |
156 | 1.33k | assert(termInst->getFunction() == original); |
// The location used is that of the enclosing function, not of the terminator.
157 | 0 | auto loc = termInst->getFunction()->getLocation(); |
158 | 1.33k | auto *origBB = termInst->getParent(); |
159 | 1.33k | auto *jvpBB = BBMap[origBB]; |
160 | 1.33k | assert(jvpBB && "Basic block mapping should exist"); |
161 | 0 | auto tupleLoweredTy = |
162 | 1.33k | remapType(differentialInfo.getLinearMapTupleLoweredType(origBB)); |
// `auto` takes a copy of the recorded values, so the insertion below does not
// modify the `differentialValues` entry itself.
163 | 1.33k | auto bbDifferentialValues = differentialValues[origBB]; |
164 | 1.33k | if (!origBB->isEntry()) { |
165 | 0 | auto *enumArg = jvpBB->getArguments().back(); |
166 | 0 | bbDifferentialValues.insert(bbDifferentialValues.begin(), enumArg); |
167 | 0 | } |
168 | 1.33k | return getBuilder().createTuple(loc, tupleLoweredTy, |
169 | 1.33k | bbDifferentialValues); |
170 | 1.33k | } |
171 | | |
172 | | //--------------------------------------------------------------------------// |
173 | | // Tangent value factory methods |
174 | | //--------------------------------------------------------------------------// |
175 | | |
/// Creates a zero tangent value of the given type. The type is remapped into
/// the differential function's context and the value is allocated from this
/// cloner's allocator.
176 | 3.95k | AdjointValue makeZeroTangentValue(SILType type) { |
177 | 3.95k | return AdjointValue::createZero(allocator, |
178 | 3.95k | remapSILTypeInDifferential(type)); |
179 | 3.95k | } |
180 | | |
/// Wraps the given SIL value as a concrete tangent value, allocated from this
/// cloner's allocator.
181 | 3.64k | AdjointValue makeConcreteTangentValue(SILValue value) { |
182 | 3.64k | return AdjointValue::createConcrete(allocator, value); |
183 | 3.64k | } |
184 | | |
185 | | //--------------------------------------------------------------------------// |
186 | | // Tangent materialization |
187 | | //--------------------------------------------------------------------------// |
188 | | |
/// Emits code that initializes `buffer` with the zero tangent of `type`.
///
/// The tangent space of `type` must exist (asserted). For a `TangentVector`
/// tangent space, zero is emitted into the buffer as an initialization. For a
/// tuple tangent space, the function recurses into the address of each tuple
/// element.
///
/// \param type the canonical type whose tangent space selects the strategy.
/// \param buffer the address to initialize.
/// \param loc the location attached to the emitted instructions.
189 | 92 | void emitZeroIndirect(CanType type, SILValue buffer, SILLocation loc) { |
// NOTE(review): `auto` without `&` copies the differential builder here,
// whereas `visitInstructionsInBlock` binds it by reference - confirm the copy
// is intended.
190 | 92 | auto builder = getDifferentialBuilder(); |
191 | 92 | auto tangentSpace = getTangentSpace(type); |
192 | 92 | assert(tangentSpace && "No tangent space for this type"); |
193 | 0 | switch (tangentSpace->getKind()) { |
194 | 92 | case TangentSpace::Kind::TangentVector: |
195 | 92 | builder.emitZeroIntoBuffer(loc, buffer, IsInitialization); |
196 | 92 | return; |
197 | 0 | case TangentSpace::Kind::Tuple: { |
198 | 0 | auto tupleType = tangentSpace->getTuple(); |
// NOTE(review): `zeroElements` is declared but never used in this function.
199 | 0 | SmallVector<SILValue, 8> zeroElements; |
200 | 0 | for (unsigned i : range(tupleType->getNumElements())) { |
201 | 0 | auto eltAddr = builder.createTupleElementAddr(loc, buffer, i); |
202 | 0 | emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), |
203 | 0 | eltAddr, loc); |
204 | 0 | } |
205 | 0 | return; |
206 | 0 | } |
207 | 92 | } |
208 | 92 | } |
209 | | |
/// Emits code that produces the zero tangent of `type` as a loaded value.
///
/// A temporary stack buffer of the lowered loadable type is allocated in the
/// differential, zero-initialized via `emitZeroIndirect`, loaded with `take`
/// ownership, and then deallocated.
///
/// \returns the loaded zero value.
210 | 64 | SILValue emitZeroDirect(CanType type, SILLocation loc) { |
// NOTE(review): this is a by-value copy of the differential builder.
211 | 64 | auto diffBuilder = getDifferentialBuilder(); |
212 | 64 | auto silType = getModule().Types.getLoweredLoadableType( |
213 | 64 | type, TypeExpansionContext::minimal(), getModule()); |
214 | 64 | auto *buffer = diffBuilder.createAllocStack(loc, silType); |
215 | 64 | emitZeroIndirect(type, buffer, loc); |
216 | 64 | auto loaded = diffBuilder.emitLoadValueOperation( |
217 | 64 | loc, buffer, LoadOwnershipQualifier::Take); |
218 | 64 | diffBuilder.createDeallocStack(loc, buffer); |
219 | 64 | return loaded; |
220 | 64 | } |
221 | | |
/// Materializes an object-typed tangent value (asserted) as a SIL value.
///
/// - Zero: emits a zero value of the tangent's Swift type.
/// - Concrete: returns the wrapped value.
/// - Aggregate / AddElement: not supported; hits `llvm_unreachable`.
222 | 60 | SILValue materializeTangentDirect(AdjointValue val, SILLocation loc) { |
223 | 60 | assert(val.getType().isObject()); |
224 | 60 | LLVM_DEBUG(getADDebugStream() |
225 | 60 | << "Materializing tangents for " << val << '\n'); |
226 | 60 | switch (val.getKind()) { |
227 | 60 | case AdjointValueKind::Zero: { |
228 | 60 | auto zeroVal = emitZeroDirect(val.getSwiftType(), loc); |
229 | 60 | return zeroVal; |
230 | 0 | } |
231 | 0 | case AdjointValueKind::Concrete: |
232 | 0 | return val.getConcreteValue(); |
233 | 0 | case AdjointValueKind::Aggregate: |
234 | 0 | case AdjointValueKind::AddElement: |
235 | 0 | llvm_unreachable( |
236 | 60 | "Tuples and structs are not supported in forward mode yet."); |
237 | 60 | } |
238 | 0 | llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715 |
239 | 0 | } |
240 | | |
/// Materializes a tangent value as a SIL value. Concrete values are returned
/// as-is; any other kind is delegated to `materializeTangentDirect`.
241 | 3.92k | SILValue materializeTangent(AdjointValue val, SILLocation loc) { |
242 | 3.92k | if (val.isConcrete()) { |
243 | 3.86k | LLVM_DEBUG(getADDebugStream() |
244 | 3.86k | << "Materializing tangent: Value is concrete.\n"); |
245 | 3.86k | return val.getConcreteValue(); |
246 | 3.86k | } |
247 | 60 | LLVM_DEBUG(getADDebugStream() << "Materializing tangent: Value is " |
248 | 60 | "non-concrete. Materializing directly.\n"); |
249 | 60 | return materializeTangentDirect(val, loc); |
250 | 3.92k | } |
251 | | |
252 | | //--------------------------------------------------------------------------// |
253 | | // Tangent value mapping |
254 | | //--------------------------------------------------------------------------// |
255 | | |
256 | | /// Get the tangent for an original value. The given value must be in the |
257 | | /// original function. |
258 | | /// |
259 | | /// This method first tries to find an entry in `tangentValueMap`. If an entry |
260 | | /// doesn't exist, create a zero tangent. |
/// The original value must have an object type (asserted). A newly created
/// zero tangent is also stored in `tangentValueMap`.
261 | 3.95k | AdjointValue getTangentValue(SILValue originalValue) { |
262 | 3.95k | assert(originalValue->getType().isObject()); |
263 | 0 | assert(originalValue->getFunction() == original); |
// NOTE(review): the zero tangent argument is evaluated before `try_emplace`
// runs, so it is constructed even when an entry already exists.
264 | 0 | auto insertion = tangentValueMap.try_emplace( |
265 | 3.95k | originalValue, |
266 | 3.95k | makeZeroTangentValue(getRemappedTangentType(originalValue->getType()))); |
267 | 3.95k | return insertion.first->getSecond(); |
268 | 3.95k | } |
269 | | |
270 | | /// Map the tangent value to the given original value. |
/// Both values must have object types, the original value must belong to the
/// original function, and no tangent may already be recorded for it (all
/// asserted).
// NOTE(review): `origBB` is not referenced in the body; `tangentValueMap` is
// keyed by the original value alone.
271 | | void setTangentValue(SILBasicBlock *origBB, SILValue originalValue, |
272 | 3.64k | AdjointValue newTangentValue) { |
273 | 3.64k | #ifndef NDEBUG |
// Debug-only check: tuple-typed `apply` results must not get a tangent
// directly.
274 | 3.64k | if (auto *defInst = originalValue->getDefiningInstruction()) { |
275 | 1.88k | bool isTupleTypedApplyResult = |
276 | 1.88k | isa<ApplyInst>(defInst) && originalValue->getType().is<TupleType>(); |
277 | 1.88k | assert(!isTupleTypedApplyResult && |
278 | 1.88k | "Should not set tangent value for tuple-typed result from `apply` " |
279 | 1.88k | "instruction; use `destructure_tuple` on `apply` result and set " |
280 | 1.88k | "tangent value for `destructure_tuple` results instead."); |
281 | 1.88k | } |
282 | 0 | #endif |
283 | 0 | assert(originalValue->getType().isObject()); |
284 | 0 | assert(newTangentValue.getType().isObject()); |
285 | 0 | assert(originalValue->getFunction() == original); |
286 | 3.64k | LLVM_DEBUG(getADDebugStream() |
287 | 3.64k | << "Setting tangent value for " << originalValue); |
288 | | // The tangent value must be in the tangent space. |
289 | 3.64k | assert(newTangentValue.getType() == |
290 | 3.64k | getRemappedTangentType(originalValue->getType())); |
291 | 0 | auto insertion = |
292 | 3.64k | tangentValueMap.try_emplace(originalValue, newTangentValue); |
// The cast silences the unused-variable warning when asserts are compiled out.
293 | 3.64k | (void)insertion; |
294 | 3.64k | assert(insertion.second && "The tangent value should not already exist."); |
295 | 3.64k | } |
296 | | |
297 | | //--------------------------------------------------------------------------// |
298 | | // Tangent buffer mapping |
299 | | //--------------------------------------------------------------------------// |
300 | | |
301 | | /// Sets the tangent buffer for the original buffer. Asserts that the |
302 | | /// original buffer does not already have a tangent buffer. |
/// The entry is keyed by the pair (origBB, originalBuffer). The original
/// buffer must have an address type (asserted).
303 | | void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer, |
304 | 2.59k | SILValue tangentBuffer) { |
305 | 2.59k | assert(originalBuffer->getType().isAddress()); |
306 | 0 | auto insertion = |
307 | 2.59k | bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer); |
308 | 2.59k | assert(insertion.second && "Tangent buffer already exists"); |
309 | 0 | (void)insertion; |
310 | 2.59k | } |
311 | | |
312 | | /// Returns the tangent buffer for the original buffer. Asserts that the |
313 | | /// original buffer has a tangent buffer. |
/// The lookup key is the pair (origBB, originalBuffer). The original buffer
/// must have an address type and belong to the original function (asserted).
/// The result is a reference to the entry stored in `bufferMap`.
314 | 6.06k | SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { |
315 | 6.06k | assert(originalBuffer->getType().isAddress()); |
316 | 0 | assert(originalBuffer->getFunction() == original); |
317 | 0 | auto it = bufferMap.find({origBB, originalBuffer}); |
318 | 6.06k | assert(it != bufferMap.end() && "Tangent buffer should already exist"); |
319 | 0 | return it->getSecond(); |
320 | 6.06k | } |
321 | | |
322 | | //--------------------------------------------------------------------------// |
323 | | // Differential type calculations |
324 | | //--------------------------------------------------------------------------// |
325 | | |
326 | | /// Substitutes all replacement types of the given substitution map using the |
327 | | /// tangent function's substitution map. |
/// The substitution map used is the forwarding substitution map of the
/// differential function.
328 | 0 | SubstitutionMap remapSubstitutionMapInDifferential(SubstitutionMap substMap) { |
329 | 0 | return substMap.subst(getDifferential().getForwardingSubstitutionMap()); |
330 | 0 | } |
331 | | |
332 | | /// Remap any archetypes into the differential function's context. |
/// AST type overload. A type that contains archetypes is first mapped out of
/// its context; in either case the type is then mapped into the differential
/// function's context.
333 | 0 | Type remapTypeInDifferential(Type ty) { |
334 | 0 | if (ty->hasArchetype()) |
335 | 0 | return getDifferential().mapTypeIntoContext(ty->mapTypeOutOfContext()); |
336 | 0 | return getDifferential().mapTypeIntoContext(ty); |
337 | 0 | } |
338 | | |
339 | | /// Remap any archetypes into the differential function's context. |
/// SIL type overload. A type that contains archetypes is first mapped out of
/// its context; in either case the type is then mapped into the differential
/// function's context.
340 | 16.3k | SILType remapSILTypeInDifferential(SILType ty) { |
341 | 16.3k | if (ty.hasArchetype()) |
342 | 592 | return getDifferential().mapTypeIntoContext(ty.mapTypeOutOfContext()); |
343 | 15.7k | return getDifferential().mapTypeIntoContext(ty); |
344 | 16.3k | } |
345 | | |
346 | | /// Find the tangent space of a given canonical type. |
/// The type is first reduced with the witness's derivative generic signature.
/// Conformances are looked up in the Swift module of the current SIL module.
/// \returns the result of `getAutoDiffTangentSpace`; callers in this file
/// treat an empty optional as "no tangent space".
347 | 9.33k | llvm::Optional<TangentSpace> getTangentSpace(CanType type) { |
348 | | // Use witness generic signature to remap types. |
349 | 9.33k | type = witness->getDerivativeGenericSignature().getReducedType( |
350 | 9.33k | type); |
351 | 9.33k | return type->getAutoDiffTangentSpace( |
352 | 9.33k | LookUpConformanceInModule(getModule().getSwiftModule())); |
353 | 9.33k | } |
354 | | |
355 | | /// Assuming the given type conforms to `Differentiable` after remapping, |
356 | | /// returns the associated tangent space SIL type. |
/// The result has the same category (object or address) as the input type.
// NOTE(review): the optional returned by `getTangentSpace` is dereferenced
// without a check, so a type with no tangent space is not handled here.
357 | 9.07k | SILType getRemappedTangentType(SILType type) { |
358 | 9.07k | return SILType::getPrimitiveType( |
359 | 9.07k | getTangentSpace(remapSILTypeInDifferential(type).getASTType()) |
360 | 9.07k | ->getCanonicalType(), |
361 | 9.07k | type.getCategory()); |
362 | 9.07k | } |
363 | | |
364 | | /// Set up the differential function. This includes: |
365 | | /// - Creating all differential blocks. |
366 | | /// - Creating differential entry block arguments based on the function type. |
367 | | /// - Creating tangent value mapping for original/differential parameters. |
368 | | /// - Checking for unvaried result and emitting related warnings. |
369 | | void prepareForDifferentialGeneration(); |
370 | | |
371 | | public: |
372 | | explicit Implementation(ADContext &context, |
373 | | SILDifferentiabilityWitness *witness, |
374 | | SILFunction *jvp, DifferentiationInvoker invoker); |
375 | | |
376 | | static SILFunction * |
377 | | createEmptyDifferential(ADContext &context, |
378 | | SILDifferentiabilityWitness *witness, |
379 | | LinearMapInfo *linearMapInfo); |
380 | | |
381 | | /// Run JVP generation. Returns true on error. |
382 | | bool run(); |
383 | | |
/// Returns the JVP function being generated.
384 | 1.33k | SILFunction &getJVP() const { return *jvp; } |
385 | | |
/// Post-processing hook for a cloned instruction. Does nothing once an error
/// has occurred; otherwise forwards to `SILClonerWithScopes::postProcess`.
386 | 14.2k | void postProcess(SILInstruction *orig, SILInstruction *cloned) { |
387 | 14.2k | if (errorOccurred) |
388 | 0 | return; |
389 | 14.2k | SILClonerWithScopes::postProcess(orig, cloned); |
390 | 14.2k | } |
391 | | |
392 | | /// Remap original basic blocks. |
/// Returns the `BBMap` entry for the given original block.
393 | 0 | SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { |
394 | 0 | auto *jvpBB = BBMap[bb]; |
395 | 0 | return jvpBB; |
396 | 0 | } |
397 | | |
398 | | /// General visitor for all instructions. If any error is emitted by previous |
399 | | /// visits, bail out. |
/// Every instruction is forwarded to `TypeSubstCloner::visit`. For
/// instructions that `differentialInfo` marks as to-be-differentiated, debug
/// output additionally prints the original instruction and the instructions
/// emitted into the differential during the visit.
400 | 17.2k | void visit(SILInstruction *inst) { |
401 | 17.2k | if (errorOccurred) |
402 | 60 | return; |
403 | 17.2k | if (differentialInfo.shouldDifferentiateInstruction(inst)) { |
404 | 7.58k | LLVM_DEBUG(getADDebugStream() << "JVPCloner visited:\n[ORIG]" << *inst); |
405 | 7.58k | #ifndef NDEBUG |
// Record the position just before the insertion point so that the
// instructions emitted by the visit can be printed afterwards.
// NOTE(review): `diffBuilder` is a by-value copy of the differential builder,
// so the insertion point read below is the copy's - confirm this is intended.
406 | 7.58k | auto diffBuilder = getDifferentialBuilder(); |
407 | 7.58k | auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint()); |
408 | 7.58k | #endif |
409 | 7.58k | TypeSubstCloner::visit(inst); |
410 | 7.58k | LLVM_DEBUG({ |
411 | 7.58k | auto &s = llvm::dbgs() << "[TAN] Emitted in differential:\n"; |
412 | 7.58k | auto afterInsertion = diffBuilder.getInsertionPoint(); |
413 | 7.58k | for (auto it = ++beforeInsertion; it != afterInsertion; ++it) |
414 | 7.58k | s << *it; |
415 | 7.58k | }); |
416 | 9.65k | } else { |
417 | 9.65k | TypeSubstCloner::visit(inst); |
418 | 9.65k | } |
419 | 17.2k | } |
420 | | |
/// Visitor for generic `SILInstruction`s. Emits an "expression is not
/// differentiable" diagnostic for the instruction and sets `errorOccurred`, so
/// that subsequent visits bail out.
421 | 0 | void visitSILInstruction(SILInstruction *inst) { |
422 | 0 | context.emitNondifferentiabilityError( |
423 | 0 | inst, invoker, diag::autodiff_expression_not_differentiable_note); |
424 | 0 | errorOccurred = true; |
425 | 0 | } |
426 | | |
/// Visits the instructions of an original basic block. Before delegating to
/// `TypeSubstCloner::visitInstructionsInBlock`, the last argument of the
/// corresponding differential block is destructured (`destructure_tuple`) at
/// that block's insertion point and the resulting elements are registered for
/// `bb` via `initializeDifferentialTupleElements`.
427 | 1.35k | void visitInstructionsInBlock(SILBasicBlock *bb) { |
428 | | // Destructure the differential struct to get the elements. |
429 | 1.35k | auto &diffBuilder = getDifferentialBuilder(); |
430 | 1.35k | auto diffLoc = getDifferential().getLocation(); |
// NOTE(review): `lookup` returns null for a missing key and the result is
// dereferenced unchecked; this presumably relies on all differential blocks
// having been created beforehand - confirm.
431 | 1.35k | auto *diffBB = diffBBMap.lookup(bb); |
432 | 1.35k | auto *mainDifferentialStruct = diffBB->getArguments().back(); |
433 | 1.35k | diffBuilder.setInsertionPoint(diffBB); |
434 | 1.35k | auto *dsi = |
435 | 1.35k | diffBuilder.createDestructureTuple(diffLoc, mainDifferentialStruct); |
436 | 1.35k | initializeDifferentialTupleElements(bb, dsi->getResults()); |
437 | 1.35k | TypeSubstCloner::visitInstructionsInBlock(bb); |
438 | 1.35k | } |
439 | | |
440 | | // If an `apply` has active results or active inout parameters, replace it |
441 | | // with an `apply` of its JVP. |
442 | 2.06k | void visitApplyInst(ApplyInst *ai) { |
443 | 2.06k | bool shouldDifferentiate = |
444 | 2.06k | differentialInfo.shouldDifferentiateApplySite(ai); |
445 | | // If the function has no active arguments or results, zero-initialize the |
446 | | // tangent buffers of the active indirect results. |
447 | 2.06k | if (!shouldDifferentiate) { |
448 | 460 | for (auto indResult : ai->getIndirectSILResults()) |
449 | 44 | if (activityInfo.isActive(indResult, getConfig())) { |
450 | 20 | auto &tanBuf = getTangentBuffer(ai->getParent(), indResult); |
451 | 20 | emitZeroIndirect(tanBuf->getType().getASTType(), tanBuf, |
452 | 20 | tanBuf.getLoc()); |
453 | 20 | } |
454 | 460 | } |
455 | | // If the function should not be differentiated or it's the array literal |
456 | | // initialization intrinsic, just do standard cloning. |
457 | 2.06k | if (!shouldDifferentiate || |
458 | 2.06k | ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) { |
459 | 460 | LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); |
460 | 460 | TypeSubstCloner::visitApplyInst(ai); |
461 | 460 | return; |
462 | 460 | } |
463 | | |
464 | 1.60k | auto loc = ai->getLoc(); |
465 | 1.60k | auto &builder = getBuilder(); |
466 | 1.60k | auto origCallee = getOpValue(ai->getCallee()); |
467 | 1.60k | auto originalFnTy = origCallee->getType().castTo<SILFunctionType>(); |
468 | | |
469 | 1.60k | LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n'); |
470 | | |
471 | | // Get the minimal parameter and result indices required for differentiating |
472 | | // this `apply`. |
473 | 1.60k | SmallVector<SILValue, 4> allResults; |
474 | 1.60k | SmallVector<unsigned, 8> activeParamIndices; |
475 | 1.60k | SmallVector<unsigned, 8> activeResultIndices; |
476 | 1.60k | collectMinimalIndicesForFunctionCall(ai, getConfig(), activityInfo, |
477 | 1.60k | allResults, activeParamIndices, |
478 | 1.60k | activeResultIndices); |
479 | 1.60k | assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); |
480 | 0 | assert(!activeResultIndices.empty() && "Result indices cannot be empty"); |
481 | 1.60k | LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; |
482 | 1.60k | llvm::interleave( |
483 | 1.60k | activeParamIndices.begin(), activeParamIndices.end(), |
484 | 1.60k | [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
485 | 1.60k | s << "}, results={"; llvm::interleave( |
486 | 1.60k | activeResultIndices.begin(), activeResultIndices.end(), |
487 | 1.60k | [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
488 | 1.60k | s << "}\n";); |
489 | | |
490 | | // Form expected indices. |
491 | 1.60k | auto numResults = |
492 | 1.60k | ai->getSubstCalleeType()->getNumResults() + |
493 | 1.60k | ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); |
494 | 1.60k | AutoDiffConfig config( |
495 | 1.60k | IndexSubset::get(getASTContext(), |
496 | 1.60k | ai->getArgumentsWithoutIndirectResults().size(), |
497 | 1.60k | activeParamIndices), |
498 | 1.60k | IndexSubset::get(getASTContext(), numResults, activeResultIndices)); |
499 | | |
500 | | // Emit the JVP. |
501 | 1.60k | SILValue jvpValue; |
502 | | // If functionSource is a `@differentiable` function, just extract it. |
503 | 1.60k | if (originalFnTy->isDifferentiable()) { |
504 | 24 | auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices(); |
505 | 36 | for (auto i : config.parameterIndices->getIndices()) { |
506 | 36 | if (!paramIndices->contains(i)) { |
507 | 0 | context.emitNondifferentiabilityError( |
508 | 0 | origCallee, invoker, |
509 | 0 | diag:: |
510 | 0 | autodiff_function_noderivative_parameter_not_differentiable); |
511 | 0 | errorOccurred = true; |
512 | 0 | return; |
513 | 0 | } |
514 | 36 | } |
515 | 24 | builder.emitScopedBorrowOperation( |
516 | 24 | loc, origCallee, [&](SILValue borrowedDiffFunc) { |
517 | 24 | jvpValue = builder.createDifferentiableFunctionExtract( |
518 | 24 | loc, NormalDifferentiableFunctionTypeComponent::JVP, |
519 | 24 | borrowedDiffFunc); |
520 | 24 | jvpValue = builder.emitCopyValueOperation(loc, jvpValue); |
521 | 24 | }); |
522 | 24 | } |
523 | | |
524 | | // If JVP has not yet been found, emit a `differentiable_function` |
525 | | // instruction on the remapped function operand and |
526 | | // a `differentiable_function_extract` instruction to get the JVP. |
527 | | // The `differentiable_function` instruction will be canonicalized during |
528 | | // the transform main loop. |
529 | 1.60k | if (!jvpValue) { |
530 | | // FIXME: Handle indirect differentiation invokers. This may require some |
531 | | // redesign: currently, each original function + witness pair is mapped |
532 | | // only to one invoker. |
533 | | /* |
534 | | DifferentiationInvoker indirect(ai, attr); |
535 | | auto insertion = |
536 | | context.getInvokers().try_emplace({original, attr}, indirect); |
537 | | auto &invoker = insertion.first->getSecond(); |
538 | | invoker = indirect; |
539 | | */ |
540 | | |
541 | | // If the original `apply` instruction has a substitution map, then the |
542 | | // applied function is specialized. |
543 | | // In the JVP, specialization is also necessary for parity. The original |
544 | | // function operand is specialized with a remapped version of same |
545 | | // substitution map using an argument-less `partial_apply`. |
546 | 1.58k | if (ai->getSubstitutionMap().empty()) { |
547 | 984 | origCallee = builder.emitCopyValueOperation(loc, origCallee); |
548 | 984 | } else { |
549 | 596 | auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); |
550 | 596 | auto jvpPartialApply = getBuilder().createPartialApply( |
551 | 596 | ai->getLoc(), origCallee, substMap, {}, |
552 | 596 | ParameterConvention::Direct_Guaranteed); |
553 | 596 | origCallee = jvpPartialApply; |
554 | 596 | } |
555 | | |
556 | | // Check and diagnose non-differentiable original function type. |
557 | 1.58k | auto diagnoseNondifferentiableOriginalFunctionType = |
558 | 1.58k | [&](CanSILFunctionType origFnTy) { |
559 | | // Check and diagnose non-differentiable arguments. |
560 | 2.52k | for (auto paramIndex : config.parameterIndices->getIndices()) { |
561 | 2.52k | if (!originalFnTy->getParameters()[paramIndex] |
562 | 2.52k | .getSILStorageInterfaceType() |
563 | 2.52k | .isDifferentiable(getModule())) { |
564 | 0 | auto arg = ai->getArgumentsWithoutIndirectResults()[paramIndex]; |
565 | 0 | auto startLoc = arg.getLoc().getStartSourceLoc(); |
566 | 0 | auto endLoc = arg.getLoc().getEndSourceLoc(); |
567 | 0 | context |
568 | 0 | .emitNondifferentiabilityError( |
569 | 0 | arg, invoker, diag::autodiff_nondifferentiable_argument) |
570 | 0 | .fixItInsert(startLoc, "withoutDerivative(at: ") |
571 | 0 | .fixItInsertAfter(endLoc, ")"); |
572 | 0 | errorOccurred = true; |
573 | 0 | return true; |
574 | 0 | } |
575 | 2.52k | } |
576 | | // Check and diagnose non-differentiable results. |
577 | 1.59k | for (auto resultIndex : config.resultIndices->getIndices()) { |
578 | 1.59k | SILType remappedResultType; |
579 | 1.59k | if (resultIndex >= originalFnTy->getNumResults()) { |
580 | 92 | auto inoutArgIdx = resultIndex - originalFnTy->getNumResults(); |
581 | 92 | auto inoutArg = |
582 | 92 | *std::next(ai->getInoutArguments().begin(), inoutArgIdx); |
583 | 92 | remappedResultType = inoutArg->getType(); |
584 | 1.50k | } else { |
585 | 1.50k | remappedResultType = originalFnTy->getResults()[resultIndex] |
586 | 1.50k | .getSILStorageInterfaceType(); |
587 | 1.50k | } |
588 | 1.59k | if (!remappedResultType.isDifferentiable(getModule())) { |
589 | 0 | auto startLoc = ai->getLoc().getStartSourceLoc(); |
590 | 0 | auto endLoc = ai->getLoc().getEndSourceLoc(); |
591 | 0 | context |
592 | 0 | .emitNondifferentiabilityError( |
593 | 0 | origCallee, invoker, |
594 | 0 | diag::autodiff_nondifferentiable_result) |
595 | 0 | .fixItInsert(startLoc, "withoutDerivative(at: ") |
596 | 0 | .fixItInsertAfter(endLoc, ")"); |
597 | 0 | errorOccurred = true; |
598 | 0 | return true; |
599 | 0 | } |
600 | 1.59k | } |
601 | 1.58k | return false; |
602 | 1.58k | }; |
603 | 1.58k | if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) |
604 | 0 | return; |
605 | | |
606 | 1.58k | auto *diffFuncInst = context.createDifferentiableFunction( |
607 | 1.58k | builder, loc, config.parameterIndices, config.resultIndices, |
608 | 1.58k | origCallee); |
609 | | |
610 | | // Record the `differentiable_function` instruction. |
611 | 1.58k | context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst); |
612 | | |
613 | 1.58k | builder.emitScopedBorrowOperation( |
614 | 1.58k | loc, diffFuncInst, [&](SILValue borrowedADFunc) { |
615 | 1.58k | auto extractedJVP = builder.createDifferentiableFunctionExtract( |
616 | 1.58k | loc, NormalDifferentiableFunctionTypeComponent::JVP, |
617 | 1.58k | borrowedADFunc); |
618 | 1.58k | jvpValue = builder.emitCopyValueOperation(loc, extractedJVP); |
619 | 1.58k | }); |
620 | 1.58k | builder.emitDestroyValueOperation(loc, diffFuncInst); |
621 | 1.58k | } |
622 | | |
623 | | // Call the JVP using the original parameters. |
624 | 1.60k | SmallVector<SILValue, 8> jvpArgs; |
625 | 1.60k | auto jvpFnTy = getOpType(jvpValue->getType()).castTo<SILFunctionType>(); |
626 | 1.60k | auto numJVPArgs = |
627 | 1.60k | jvpFnTy->getNumParameters() + jvpFnTy->getNumIndirectFormalResults(); |
628 | 1.60k | jvpArgs.reserve(numJVPArgs); |
629 | | // Collect substituted arguments. |
630 | 1.60k | for (auto origArg : ai->getArguments()) |
631 | 4.43k | jvpArgs.push_back(getOpValue(origArg)); |
632 | 1.60k | assert(jvpArgs.size() == numJVPArgs); |
633 | | // Apply the JVP. |
634 | | // The JVP should be specialized, so no substitution map is necessary. |
635 | 0 | auto *jvpCall = getBuilder().createApply(loc, jvpValue, SubstitutionMap(), |
636 | 1.60k | jvpArgs, ai->getApplyOptions()); |
637 | 1.60k | LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall); |
638 | | |
639 | | // Release the differentiable function. |
640 | 1.60k | builder.emitDestroyValueOperation(loc, jvpValue); |
641 | | |
642 | | // Get the JVP results (original results and differential). |
643 | 1.60k | SmallVector<SILValue, 8> jvpDirectResults; |
644 | 1.60k | extractAllElements(jvpCall, builder, jvpDirectResults); |
645 | 1.60k | auto originalDirectResults = |
646 | 1.60k | ArrayRef<SILValue>(jvpDirectResults).drop_back(1); |
647 | 1.60k | auto originalDirectResult = |
648 | 1.60k | joinElements(originalDirectResults, getBuilder(), jvpCall->getLoc()); |
649 | | |
650 | 1.60k | mapValue(ai, originalDirectResult); |
651 | | |
652 | | // Some instructions that produce the callee may have been cloned. |
653 | | // If the original callee did not have any users beyond this `apply`, |
654 | | // recursively kill the cloned callee. |
655 | 1.60k | if (auto *origCallee = cast_or_null<SingleValueInstruction>( |
656 | 1.60k | ai->getCallee()->getDefiningInstruction())) |
657 | 1.58k | if (origCallee->hasOneUse()) |
658 | 1.58k | recursivelyDeleteTriviallyDeadInstructions( |
659 | 1.58k | getOpValue(origCallee)->getDefiningInstruction()); |
660 | | |
661 | | // Record the callee's differential; it is later stored in the struct that |
662 | | // is partially applied to the differential we are generating. |
663 | 1.60k | auto differential = jvpDirectResults.back(); |
664 | 1.60k | auto differentialType = differentialInfo.lookUpLinearMapType(ai); |
665 | 1.60k | auto originalDifferentialType = |
666 | 1.60k | getOpType(differential->getType()).getAs<SILFunctionType>(); |
667 | 1.60k | auto loweredDifferentialType = |
668 | 1.60k | getOpType(getLoweredType(differentialType)).castTo<SILFunctionType>(); |
669 | | // If actual differential type does not match lowered differential type, |
670 | | // reabstract the differential using a thunk. |
671 | 1.60k | if (!loweredDifferentialType->isEqual(originalDifferentialType)) { |
672 | 388 | SILOptFunctionBuilder fb(context.getTransform()); |
673 | 388 | differential = reabstractFunction( |
674 | 388 | builder, fb, loc, differential, loweredDifferentialType, |
675 | 388 | [this](SubstitutionMap subs) -> SubstitutionMap { |
676 | 388 | return this->getOpSubstitutionMap(subs); |
677 | 388 | }); |
678 | 388 | } |
679 | 1.60k | differentialValues[ai->getParent()].push_back(differential); |
680 | | |
681 | | // Differential emission. |
682 | 1.60k | emitTangentForApplyInst(ai, config, originalDifferentialType); |
683 | 1.60k | } |
684 | | |
685 | 1.33k | void visitReturnInst(ReturnInst *ri) { |
686 | 1.33k | auto loc = ri->getOperand().getLoc(); |
687 | 1.33k | auto *origExit = ri->getParent(); |
688 | 1.33k | auto &builder = getBuilder(); |
689 | 1.33k | auto *diffStructVal = buildDifferentialValueStructValue(ri); |
690 | | |
691 | | // Get the JVP value corresponding to the original function's return value. |
692 | 1.33k | auto *origRetInst = cast<ReturnInst>(origExit->getTerminator()); |
693 | 1.33k | auto origResult = getOpValue(origRetInst->getOperand()); |
694 | 1.33k | SmallVector<SILValue, 8> origResults; |
695 | 1.33k | extractAllElements(origResult, builder, origResults); |
696 | | |
697 | | // Get and partially apply the differential. |
698 | 1.33k | auto jvpSubstMap = jvp->getForwardingSubstitutionMap(); |
699 | 1.33k | auto *differentialRef = builder.createFunctionRef(loc, &getDifferential()); |
700 | 1.33k | auto *differentialPartialApply = builder.createPartialApply( |
701 | 1.33k | loc, differentialRef, jvpSubstMap, {diffStructVal}, |
702 | 1.33k | ParameterConvention::Direct_Guaranteed); |
703 | | |
704 | 1.33k | auto differentialType = jvp->mapTypeIntoContext( |
705 | 1.33k | jvp->getConventions().getSILType( |
706 | 1.33k | jvp->getLoweredFunctionType()->getResults().back(), |
707 | 1.33k | jvp->getTypeExpansionContext())); |
708 | 1.33k | auto differentialFnType = differentialType.castTo<SILFunctionType>(); |
709 | 1.33k | auto differentialSubstType = |
710 | 1.33k | differentialPartialApply->getType().castTo<SILFunctionType>(); |
711 | | |
712 | | // If necessary, convert the differential value to the returned differential |
713 | | // function type. |
714 | 1.33k | SILValue differentialValue; |
715 | 1.33k | if (differentialSubstType == differentialFnType) { |
716 | 1.20k | differentialValue = differentialPartialApply; |
717 | 1.20k | } else if (differentialSubstType |
718 | 132 | ->isABICompatibleWith(differentialFnType, *jvp) |
719 | 132 | .isCompatible()) { |
720 | 132 | differentialValue = builder.createConvertFunction( |
721 | 132 | loc, differentialPartialApply, differentialType, |
722 | 132 | /*withoutActuallyEscaping*/ false); |
723 | 132 | } else { |
724 | 0 | llvm::report_fatal_error("Differential value type is not ABI-compatible " |
725 | 0 | "with the returned differential type"); |
726 | 0 | } |
727 | | |
728 | | // Return a tuple of the original result and differential. |
729 | 1.33k | SmallVector<SILValue, 8> directResults; |
730 | 1.33k | directResults.append(origResults.begin(), origResults.end()); |
731 | 1.33k | directResults.push_back(differentialValue); |
732 | 1.33k | builder.createReturn(ri->getLoc(), |
733 | 1.33k | joinElements(directResults, builder, loc)); |
734 | 1.33k | } |
735 | | |
736 | 0 | void visitBranchInst(BranchInst *bi) { |
737 | 0 | llvm_unreachable("Unsupported SIL instruction."); |
738 | 0 | } |
739 | | |
740 | 0 | void visitCondBranchInst(CondBranchInst *cbi) { |
741 | 0 | llvm_unreachable("Unsupported SIL instruction."); |
742 | 0 | } |
743 | | |
744 | 0 | void visitSwitchEnumInst(SwitchEnumInst *sei) { |
745 | 0 | llvm_unreachable("Unsupported SIL instruction."); |
746 | 0 | } |
747 | | |
748 | 56 | void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { |
749 | | // Clone `differentiable_function` from original to JVP, then add the cloned |
750 | | // instruction to the `differentiable_function` worklist. |
751 | 56 | TypeSubstCloner::visitDifferentiableFunctionInst(dfi); |
752 | 56 | auto *newDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi)); |
753 | 56 | context.getDifferentiableFunctionInstWorklist().push_back(newDFI); |
754 | 56 | } |
755 | | |
756 | 0 | void visitLinearFunctionInst(LinearFunctionInst *lfi) { |
757 | | // Clone `linear_function` from original to JVP, then add the cloned |
758 | | // instruction to the `linear_function` worklist. |
759 | 0 | TypeSubstCloner::visitLinearFunctionInst(lfi); |
760 | 0 | auto *newLFI = cast<LinearFunctionInst>(getOpValue(lfi)); |
761 | 0 | context.getLinearFunctionInstWorklist().push_back(newLFI); |
762 | 0 | } |
763 | | |
764 | | //--------------------------------------------------------------------------// |
765 | | // Tangent emission helpers |
766 | | //--------------------------------------------------------------------------// |
767 | | |
768 | | #define CLONE_AND_EMIT_TANGENT(INST, ID) \ |
769 | 5.21k | void visit##INST##Inst(INST##Inst *inst) { \ |
770 | 5.21k | TypeSubstCloner::visit##INST##Inst(inst); \ |
771 | 5.21k | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ |
772 | 5.21k | emitTangentFor##INST##Inst(inst); \ |
773 | 5.21k | } \ _ZN5swift8autodiff9JVPCloner14Implementation19visitAllocStackInstEPNS_14AllocStackInstE Line | Count | Source | 769 | 1.26k | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 1.26k | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 1.26k | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 1.26k | emitTangentFor##INST##Inst(inst); \ | 773 | 1.26k | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation18visitCopyValueInstEPNS_13CopyValueInstE Line | Count | Source | 769 | 112 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 112 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 112 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 112 | emitTangentFor##INST##Inst(inst); \ | 773 | 112 | } \ |
Unexecuted instantiation: _ZN5swift8autodiff9JVPCloner14Implementation19visitLoadBorrowInstEPNS_14LoadBorrowInstE _ZN5swift8autodiff9JVPCloner14Implementation20visitBeginBorrowInstEPNS_15BeginBorrowInstE Line | Count | Source | 769 | 28 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 28 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 28 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 28 | emitTangentFor##INST##Inst(inst); \ | 773 | 28 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation20visitBeginAccessInstEPNS_15BeginAccessInstE Line | Count | Source | 769 | 604 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 604 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 604 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 604 | emitTangentFor##INST##Inst(inst); \ | 773 | 604 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation14visitTupleInstEPNS_9TupleInstE Line | Count | Source | 769 | 196 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 196 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 196 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 196 | emitTangentFor##INST##Inst(inst); \ | 773 | 196 | } \ |
Unexecuted instantiation: _ZN5swift8autodiff9JVPCloner14Implementation21visitTupleExtractInstEPNS_16TupleExtractInstE _ZN5swift8autodiff9JVPCloner14Implementation25visitTupleElementAddrInstEPNS_20TupleElementAddrInstE Line | Count | Source | 769 | 328 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 328 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 328 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 328 | emitTangentFor##INST##Inst(inst); \ | 773 | 328 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation15visitStructInstEPNS_10StructInstE Line | Count | Source | 769 | 24 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 24 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 24 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 24 | emitTangentFor##INST##Inst(inst); \ | 773 | 24 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation22visitStructExtractInstEPNS_17StructExtractInstE Line | Count | Source | 769 | 220 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 220 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 220 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 220 | emitTangentFor##INST##Inst(inst); \ | 773 | 220 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation26visitStructElementAddrInstEPNS_21StructElementAddrInstE Line | Count | Source | 769 | 184 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 184 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 184 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 184 | emitTangentFor##INST##Inst(inst); \ | 773 | 184 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation21visitDeallocStackInstEPNS_16DeallocStackInstE Line | Count | Source | 769 | 1.24k | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 1.24k | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 1.24k | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 1.24k | emitTangentFor##INST##Inst(inst); \ | 773 | 1.24k | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation21visitDestroyValueInstEPNS_16DestroyValueInstE Line | Count | Source | 769 | 188 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 188 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 188 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 188 | emitTangentFor##INST##Inst(inst); \ | 773 | 188 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation18visitEndBorrowInstEPNS_13EndBorrowInstE Line | Count | Source | 769 | 28 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 28 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 28 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 28 | emitTangentFor##INST##Inst(inst); \ | 773 | 28 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation18visitEndAccessInstEPNS_13EndAccessInstE Line | Count | Source | 769 | 592 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 592 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 592 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 592 | emitTangentFor##INST##Inst(inst); \ | 773 | 592 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation20visitDestroyAddrInstEPNS_15DestroyAddrInstE Line | Count | Source | 769 | 176 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 176 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 176 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 176 | emitTangentFor##INST##Inst(inst); \ | 773 | 176 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation37visitUnconditionalCheckedCastAddrInstEPNS_32UnconditionalCheckedCastAddrInstE Line | Count | Source | 769 | 4 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 4 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 4 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 4 | emitTangentFor##INST##Inst(inst); \ | 773 | 4 | } \ |
_ZN5swift8autodiff9JVPCloner14Implementation25visitDestructureTupleInstEPNS_20DestructureTupleInstE Line | Count | Source | 769 | 24 | void visit##INST##Inst(INST##Inst *inst) { \ | 770 | 24 | TypeSubstCloner::visit##INST##Inst(inst); \ | 771 | 24 | if (differentialInfo.shouldDifferentiateInstruction(inst)) \ | 772 | 24 | emitTangentFor##INST##Inst(inst); \ | 773 | 24 | } \ |
|
774 | | void emitTangentFor##INST##Inst(INST##Inst *(ID)) |
775 | | |
776 | 0 | CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) { |
777 | 0 | auto &diffBuilder = getDifferentialBuilder(); |
778 | 0 | auto loc = bbi->getLoc(); |
779 | 0 | auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); |
780 | 0 | auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal); |
781 | 0 | setTangentValue(bbi->getParent(), bbi, |
782 | 0 | makeConcreteTangentValue(tanValBorrow)); |
783 | 0 | } |
784 | | |
785 | 0 | CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) { |
786 | 0 | auto &diffBuilder = getDifferentialBuilder(); |
787 | 0 | auto loc = ebi->getLoc(); |
788 | 0 | auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc); |
789 | 0 | diffBuilder.emitEndBorrowOperation(loc, tanVal); |
790 | 0 | } |
791 | | |
792 | 16 | CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) { |
793 | 16 | auto &diffBuilder = getDifferentialBuilder(); |
794 | 16 | auto loc = dvi->getLoc(); |
795 | 16 | auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); |
796 | 16 | diffBuilder.emitDestroyValueOperation(loc, tanVal); |
797 | 16 | } |
798 | | |
799 | 12 | CLONE_AND_EMIT_TANGENT(CopyValue, cvi) { |
800 | 12 | auto &diffBuilder = getDifferentialBuilder(); |
801 | 12 | auto tan = getTangentValue(cvi->getOperand()); |
802 | 12 | auto tanVal = materializeTangent(tan, cvi->getLoc()); |
803 | 12 | auto tanValCopy = diffBuilder.emitCopyValueOperation(cvi->getLoc(), tanVal); |
804 | 12 | setTangentValue(cvi->getParent(), cvi, |
805 | 12 | makeConcreteTangentValue(tanValCopy)); |
806 | 12 | } |
807 | | |
808 | | /// Handle `load` instruction. |
809 | | /// Original: y = load x |
810 | | /// Tangent: tan[y] = load tan[x] |
811 | 564 | void visitLoadInst(LoadInst *li) { |
812 | 564 | TypeSubstCloner::visitLoadInst(li); |
813 | | // If an active buffer is loaded with take to a non-active value, destroy |
814 | | // the active buffer's tangent buffer. |
815 | 564 | if (!differentialInfo.shouldDifferentiateInstruction(li)) { |
816 | 12 | auto isTake = |
817 | 12 | (li->getOwnershipQualifier() == LoadOwnershipQualifier::Take); |
818 | 12 | if (isTake && activityInfo.isActive(li->getOperand(), getConfig())) { |
819 | 0 | auto &tanBuf = getTangentBuffer(li->getParent(), li->getOperand()); |
820 | 0 | getDifferentialBuilder().emitDestroyOperation(tanBuf.getLoc(), tanBuf); |
821 | 0 | } |
822 | 12 | return; |
823 | 12 | } |
824 | | // Otherwise, do standard differential cloning. |
825 | 552 | auto &diffBuilder = getDifferentialBuilder(); |
826 | 552 | auto *bb = li->getParent(); |
827 | 552 | auto loc = li->getLoc(); |
828 | 552 | auto tanBuf = getTangentBuffer(bb, li->getOperand()); |
829 | 552 | auto tanVal = diffBuilder.emitLoadValueOperation( |
830 | 552 | loc, tanBuf, li->getOwnershipQualifier()); |
831 | 552 | setTangentValue(bb, li, makeConcreteTangentValue(tanVal)); |
832 | 552 | } |
833 | | |
834 | | /// Handle `load_borrow` instruction. |
835 | | /// Original: y = load_borrow x |
836 | | /// Tangent: tan[y] = load_borrow tan[x] |
837 | 0 | CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi) { |
838 | 0 | auto &diffBuilder = getDifferentialBuilder(); |
839 | 0 | auto *bb = lbi->getParent(); |
840 | 0 | auto loc = lbi->getLoc(); |
841 | 0 | auto tanBuf = getTangentBuffer(bb, lbi->getOperand()); |
842 | 0 | auto tanVal = diffBuilder.emitLoadBorrowOperation(loc, tanBuf); |
843 | 0 | setTangentValue(bb, lbi, makeConcreteTangentValue(tanVal)); |
844 | 0 | } |
845 | | |
846 | | /// Handle `store` instruction in the differential. |
847 | | /// Original: store x to y |
848 | | /// Tangent: store tan[x] to tan[y] |
849 | 884 | void visitStoreInst(StoreInst *si) { |
850 | 884 | TypeSubstCloner::visitStoreInst(si); |
851 | | // If a non-active value is stored into an active buffer, zero-initialize |
852 | | // the active buffer's tangent buffer. |
853 | 884 | if (!differentialInfo.shouldDifferentiateInstruction(si)) { |
854 | 116 | if (activityInfo.isActive(si->getDest(), getConfig())) { |
855 | 0 | auto &tanBufDest = getTangentBuffer(si->getParent(), si->getDest()); |
856 | 0 | emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest, |
857 | 0 | tanBufDest.getLoc()); |
858 | 0 | } |
859 | 116 | return; |
860 | 116 | } |
861 | | // Otherwise, do standard differential cloning. |
862 | 768 | auto &diffBuilder = getDifferentialBuilder(); |
863 | 768 | auto loc = si->getLoc(); |
864 | 768 | auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc); |
865 | 768 | auto &tanValDest = getTangentBuffer(si->getParent(), si->getDest()); |
866 | 768 | diffBuilder.emitStoreValueOperation(loc, tanValSrc, tanValDest, |
867 | 768 | si->getOwnershipQualifier()); |
868 | 768 | } |
869 | | |
870 | | /// Handle `store_borrow` instruction in the differential. |
871 | | /// Original: store_borrow x to y |
872 | | /// Tangent: store_borrow tan[x] to tan[y] |
873 | 0 | void visitStoreBorrowInst(StoreBorrowInst *sbi) { |
874 | 0 | TypeSubstCloner::visitStoreBorrowInst(sbi); |
875 | | // If a non-active value is stored into an active buffer, zero-initialize |
876 | | // the active buffer's tangent buffer. |
877 | 0 | if (!differentialInfo.shouldDifferentiateInstruction(sbi)) { |
878 | 0 | if (activityInfo.isActive(sbi->getDest(), getConfig())) { |
879 | 0 | auto &tanBufDest = getTangentBuffer(sbi->getParent(), sbi->getDest()); |
880 | 0 | emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest, |
881 | 0 | tanBufDest.getLoc()); |
882 | 0 | } |
883 | 0 | return; |
884 | 0 | } |
885 | | // Otherwise, do standard differential cloning. |
886 | 0 | auto &diffBuilder = getDifferentialBuilder(); |
887 | 0 | auto loc = sbi->getLoc(); |
888 | 0 | auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc); |
889 | 0 | auto &tanValDest = getTangentBuffer(sbi->getParent(), sbi->getDest()); |
890 | 0 | diffBuilder.createStoreBorrow(loc, tanValSrc, tanValDest); |
891 | 0 | } |
892 | | |
893 | | /// Handle `copy_addr` instruction. |
894 | | /// Original: copy_addr x to y |
895 | | /// Tangent: copy_addr tan[x] to tan[y] |
896 | 248 | void visitCopyAddrInst(CopyAddrInst *cai) { |
897 | 248 | TypeSubstCloner::visitCopyAddrInst(cai); |
898 | | // If a non-active buffer is copied into an active buffer, zero-initialize |
899 | | // the destination buffer's tangent buffer. |
900 | | // If an active buffer is copied with take into a non-active buffer, destroy |
901 | | // the source buffer's tangent buffer. |
902 | 248 | if (!differentialInfo.shouldDifferentiateInstruction(cai)) { |
903 | 4 | if (activityInfo.isActive(cai->getDest(), getConfig())) { |
904 | 0 | auto &tanBufDest = getTangentBuffer(cai->getParent(), cai->getDest()); |
905 | 0 | emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest, |
906 | 0 | tanBufDest.getLoc()); |
907 | 0 | } |
908 | 4 | if (cai->isTakeOfSrc() && |
909 | 4 | activityInfo.isActive(cai->getSrc(), getConfig())) { |
910 | 0 | auto &tanBufSrc = getTangentBuffer(cai->getParent(), cai->getSrc()); |
911 | 0 | getDifferentialBuilder().emitDestroyOperation(tanBufSrc.getLoc(), |
912 | 0 | tanBufSrc); |
913 | 0 | } |
914 | 4 | return; |
915 | 4 | } |
916 | | // Otherwise, do standard differential cloning. |
917 | 244 | auto diffBuilder = getDifferentialBuilder(); |
918 | 244 | auto loc = cai->getLoc(); |
919 | 244 | auto *bb = cai->getParent(); |
920 | 244 | auto &tanSrc = getTangentBuffer(bb, cai->getSrc()); |
921 | 244 | auto tanDest = getTangentBuffer(bb, cai->getDest()); |
922 | 244 | diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(), |
923 | 244 | cai->isInitializationOfDest()); |
924 | 244 | } |
925 | | |
926 | | /// Handle `unconditional_checked_cast_addr` instruction. |
927 | | /// Original: unconditional_checked_cast_addr $X in x to $Y in y |
928 | | /// Tangent: unconditional_checked_cast_addr $X.Tan in tan[x] |
929 | | /// to $Y.Tan in tan[y] |
930 | 4 | CLONE_AND_EMIT_TANGENT(UnconditionalCheckedCastAddr, uccai) { |
931 | 4 | auto diffBuilder = getDifferentialBuilder(); |
932 | 4 | auto loc = uccai->getLoc(); |
933 | 4 | auto *bb = uccai->getParent(); |
934 | 4 | auto &tanSrc = getTangentBuffer(bb, uccai->getSrc()); |
935 | 4 | auto tanDest = getTangentBuffer(bb, uccai->getDest()); |
936 | | |
937 | 4 | diffBuilder.createUnconditionalCheckedCastAddr( |
938 | 4 | loc, tanSrc, tanSrc->getType().getASTType(), tanDest, |
939 | 4 | tanDest->getType().getASTType()); |
940 | 4 | } |
941 | | |
942 | | /// Handle `begin_access` instruction (and do differentiability checks). |
943 | | /// Original: y = begin_access x |
944 | | /// Tangent: tan[y] = begin_access tan[x] |
945 | 588 | CLONE_AND_EMIT_TANGENT(BeginAccess, bai) { |
946 | | // Check for non-differentiable writes. |
947 | 588 | if (bai->getAccessKind() == SILAccessKind::Modify) { |
948 | 260 | if (auto *gai = dyn_cast<GlobalAddrInst>(bai->getSource())) { |
949 | 0 | context.emitNondifferentiabilityError( |
950 | 0 | bai, invoker, |
951 | 0 | diag::autodiff_cannot_differentiate_writes_to_global_variables); |
952 | 0 | errorOccurred = true; |
953 | 0 | return; |
954 | 0 | } |
955 | 260 | if (auto *pbi = dyn_cast<ProjectBoxInst>(bai->getSource())) { |
956 | 0 | context.emitNondifferentiabilityError( |
957 | 0 | bai, invoker, |
958 | 0 | diag::autodiff_cannot_differentiate_writes_to_mutable_captures); |
959 | 0 | errorOccurred = true; |
960 | 0 | return; |
961 | 0 | } |
962 | 260 | } |
963 | | |
964 | 588 | auto &diffBuilder = getDifferentialBuilder(); |
965 | 588 | auto *bb = bai->getParent(); |
966 | | |
967 | 588 | auto tanSrc = getTangentBuffer(bb, bai->getSource()); |
968 | 588 | auto *tanDest = diffBuilder.createBeginAccess( |
969 | 588 | bai->getLoc(), tanSrc, bai->getAccessKind(), bai->getEnforcement(), |
970 | 588 | bai->hasNoNestedConflict(), bai->isFromBuiltin()); |
971 | 588 | setTangentBuffer(bb, bai, tanDest); |
972 | 588 | } |
973 | | |
974 | | /// Handle `end_access` instruction. |
975 | | /// Original: end_access x |
976 | | /// Tangent: end_access tan[x] |
977 | 576 | CLONE_AND_EMIT_TANGENT(EndAccess, eai) { |
978 | 576 | auto &diffBuilder = getDifferentialBuilder(); |
979 | 576 | auto *bb = eai->getParent(); |
980 | 576 | auto loc = eai->getLoc(); |
981 | 576 | auto tanOperand = getTangentBuffer(bb, eai->getOperand()); |
982 | 576 | diffBuilder.createEndAccess(loc, tanOperand, eai->isAborting()); |
983 | 576 | } |
984 | | |
985 | | /// Handle `alloc_stack` instruction. |
986 | | /// Original: y = alloc_stack $T |
987 | | /// Tangent: tan[y] = alloc_stack $T.Tangent |
988 | 1.18k | CLONE_AND_EMIT_TANGENT(AllocStack, asi) { |
989 | 1.18k | auto &diffBuilder = getDifferentialBuilder(); |
990 | 1.18k | auto *mappedAllocStackInst = diffBuilder.createAllocStack( |
991 | 1.18k | asi->getLoc(), getRemappedTangentType(asi->getElementType()), |
992 | 1.18k | asi->getVarInfo()); |
993 | 1.18k | setTangentBuffer(asi->getParent(), asi, mappedAllocStackInst); |
994 | 1.18k | } |
995 | | |
996 | | /// Handle `dealloc_stack` instruction. |
997 | | /// Original: dealloc_stack x |
998 | | /// Tangent: dealloc_stack tan[x] |
999 | 1.16k | CLONE_AND_EMIT_TANGENT(DeallocStack, dsi) { |
1000 | 1.16k | auto &diffBuilder = getDifferentialBuilder(); |
1001 | 1.16k | auto tanBuf = getTangentBuffer(dsi->getParent(), dsi->getOperand()); |
1002 | 1.16k | diffBuilder.createDeallocStack(dsi->getLoc(), tanBuf); |
1003 | 1.16k | } |
1004 | | |
1005 | | /// Handle `destroy_addr` instruction. |
1006 | | /// Original: destroy_addr x |
1007 | | /// Tangent: destroy_addr tan[x] |
1008 | 164 | CLONE_AND_EMIT_TANGENT(DestroyAddr, dai) { |
1009 | 164 | auto &diffBuilder = getDifferentialBuilder(); |
1010 | 164 | auto tanBuf = getTangentBuffer(dai->getParent(), dai->getOperand()); |
1011 | 164 | diffBuilder.createDestroyAddr(dai->getLoc(), tanBuf); |
1012 | 164 | } |
1013 | | |
1014 | | /// Handle `struct` instruction. |
1015 | | /// Original: y = struct $T (x0, x1, x2, ...) |
1016 | | /// Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...) |
1017 | 24 | CLONE_AND_EMIT_TANGENT(Struct, si) { |
1018 | 24 | auto &diffBuilder = getDifferentialBuilder(); |
1019 | 24 | SmallVector<SILValue, 4> tangentElements; |
1020 | 24 | for (auto elem : si->getElements()) |
1021 | 32 | tangentElements.push_back(getTangentValue(elem).getConcreteValue()); |
1022 | 24 | auto tanExtract = diffBuilder.createStruct( |
1023 | 24 | si->getLoc(), getRemappedTangentType(si->getType()), tangentElements); |
1024 | 24 | setTangentValue(si->getParent(), si, makeConcreteTangentValue(tanExtract)); |
1025 | 24 | } |
1026 | | |
1027 | | /// Handle `struct_extract` instruction. |
1028 | | /// Original: y = struct_extract x, #field |
1029 | | /// Tangent: tan[y] = struct_extract tan[x], #field' |
1030 | | /// ^~~~~~~ |
1031 | | /// field in tangent space corresponding to #field |
1032 | 204 | CLONE_AND_EMIT_TANGENT(StructExtract, sei) { |
1033 | 204 | assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() && |
1034 | 204 | "`struct_extract` with `@noDerivative` field should not be " |
1035 | 204 | "differentiated; activity analysis should not marked as varied."); |
1036 | 0 | auto diffBuilder = getDifferentialBuilder(); |
1037 | 204 | auto loc = getValidLocation(sei); |
1038 | | // Find the corresponding field in the tangent space. |
1039 | 204 | auto structType = |
1040 | 204 | remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType(); |
1041 | 204 | auto *tanField = |
1042 | 204 | getTangentStoredProperty(context, sei, structType, invoker); |
1043 | 204 | if (!tanField) { |
1044 | 8 | errorOccurred = true; |
1045 | 8 | return; |
1046 | 8 | } |
1047 | | // Emit tangent `struct_extract`. |
1048 | 196 | auto tanStruct = |
1049 | 196 | materializeTangent(getTangentValue(sei->getOperand()), loc); |
1050 | 196 | auto tangentInst = |
1051 | 196 | diffBuilder.createStructExtract(loc, tanStruct, tanField); |
1052 | | // Update tangent value mapping for `struct_extract` result. |
1053 | 196 | auto tangentResult = makeConcreteTangentValue(tangentInst); |
1054 | 196 | setTangentValue(sei->getParent(), sei, tangentResult); |
1055 | 196 | } |
1056 | | |
1057 | | /// Handle `struct_element_addr` instruction. |
1058 | | /// Original: y = struct_element_addr x, #field |
1059 | | /// Tangent: tan[y] = struct_element_addr tan[x], #field' |
1060 | | /// ^~~~~~~ |
1061 | | /// field in tangent space corresponding to #field |
1062 | 180 | CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) { |
1063 | 180 | assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() && |
1064 | 180 | "`struct_element_addr` with `@noDerivative` field should not be " |
1065 | 180 | "differentiated; activity analysis should not marked as varied."); |
1066 | 0 | auto diffBuilder = getDifferentialBuilder(); |
1067 | 180 | auto *bb = seai->getParent(); |
1068 | 180 | auto loc = getValidLocation(seai); |
1069 | | // Find the corresponding field in the tangent space. |
1070 | 180 | auto structType = |
1071 | 180 | remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType(); |
1072 | 180 | auto *tanField = |
1073 | 180 | getTangentStoredProperty(context, seai, structType, invoker); |
1074 | 180 | if (!tanField) { |
1075 | 12 | errorOccurred = true; |
1076 | 12 | return; |
1077 | 12 | } |
1078 | | // Emit tangent `struct_element_addr`. |
1079 | 168 | auto tanOperand = getTangentBuffer(bb, seai->getOperand()); |
1080 | 168 | auto tangentInst = |
1081 | 168 | diffBuilder.createStructElementAddr(loc, tanOperand, tanField); |
1082 | | // Update tangent buffer map for `struct_element_addr`. |
1083 | 168 | setTangentBuffer(bb, seai, tangentInst); |
1084 | 168 | } |
1085 | | |
1086 | | /// Handle `tuple` instruction. |
1087 | | /// Original: y = tuple (x0, x1, x2, ...) |
1088 | | /// Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...) |
1089 | | /// ^~~ |
1090 | | /// excluding non-differentiable elements |
1091 | 8 | CLONE_AND_EMIT_TANGENT(Tuple, ti) { |
1092 | 8 | auto diffBuilder = getDifferentialBuilder(); |
1093 | | // Get the tangents of all the tuple elements. |
1094 | 8 | SmallVector<SILValue, 8> tangentTupleElements; |
1095 | 16 | for (auto elem : ti->getElements()) { |
1096 | 16 | if (!getTangentSpace(elem->getType().getASTType())) |
1097 | 0 | continue; |
1098 | 16 | tangentTupleElements.push_back( |
1099 | 16 | materializeTangent(getTangentValue(elem), ti->getLoc())); |
1100 | 16 | } |
1101 | | // Emit the instruction and add the tangent mapping. |
1102 | 8 | auto tanTuple = |
1103 | 8 | joinElements(tangentTupleElements, diffBuilder, ti->getLoc()); |
1104 | 8 | setTangentValue(ti->getParent(), ti, makeConcreteTangentValue(tanTuple)); |
1105 | 8 | } |
1106 | | |
1107 | | /// Handle `tuple_extract` instruction. |
1108 | | /// Original: y = tuple_extract x, <n> |
1109 | | /// Tangent: tan[y] = tuple_extract tan[x], <n'> |
1110 | | /// ^~~~ |
1111 | | /// tuple tangent space index corresponding to n |
1112 | 0 | CLONE_AND_EMIT_TANGENT(TupleExtract, tei) { |
1113 | 0 | auto &diffBuilder = getDifferentialBuilder(); |
1114 | 0 | auto loc = tei->getLoc(); |
1115 | 0 | auto origTupleTy = tei->getOperand()->getType().castTo<TupleType>(); |
1116 | 0 | unsigned tanIndex = 0; |
1117 | 0 | for (unsigned i : range(tei->getFieldIndex())) { |
1118 | 0 | if (getTangentSpace( |
1119 | 0 | origTupleTy->getElement(i).getType()->getCanonicalType())) |
1120 | 0 | ++tanIndex; |
1121 | 0 | } |
1122 | 0 | auto tanType = getRemappedTangentType(tei->getType()); |
1123 | 0 | auto tanSource = |
1124 | 0 | materializeTangent(getTangentValue(tei->getOperand()), loc); |
1125 | | // If the tangent value of the source does not have a tuple type, then |
1126 | | // it must represent a "single element tuple type". Use it directly. |
1127 | 0 | if (!tanSource->getType().is<TupleType>()) { |
1128 | 0 | setTangentValue(tei->getParent(), tei, |
1129 | 0 | makeConcreteTangentValue(tanSource)); |
1130 | 0 | } else { |
1131 | 0 | auto tanElt = |
1132 | 0 | diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType); |
1133 | 0 | setTangentValue(tei->getParent(), tei, makeConcreteTangentValue(tanElt)); |
1134 | 0 | } |
1135 | 0 | } |
1136 | | |
1137 | | /// Handle `tuple_element_addr` instruction.
1138 | | /// Original: y = tuple_element_addr x, <n>
1139 | | /// Tangent: tan[y] = tuple_element_addr tan[x], <n'>
1140 | | /// where <n'> is the tuple tangent space index corresponding to <n>: the
1141 | | /// number of elements before <n> for which `getTangentSpace` returns a value.
1142 | 272 | CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai) {
1143 | 272 | auto &diffBuilder = getDifferentialBuilder();
1144 | 272 | auto origTupleTy = teai->getOperand()->getType().castTo<TupleType>();
1145 | 272 | unsigned tanIndex = 0; // Becomes <n'>: count of earlier elements with a tangent space.
1146 | 272 | for (unsigned i : range(teai->getFieldIndex())) {
1147 | 128 | if (getTangentSpace(
1148 | 128 | origTupleTy->getElement(i).getType()->getCanonicalType()))
1149 | 64 | ++tanIndex;
1150 | 128 | }
1151 | 272 | auto tanType = getRemappedTangentType(teai->getType());
1152 | 272 | auto tanSource = getTangentBuffer(teai->getParent(), teai->getOperand());
1153 | 272 | SILValue tanBuf;
1154 | | // If the tangent buffer of the source does not have a tuple type, then
1155 | | // it must represent a "single element tuple type". Use it directly.
1156 | 272 | if (!tanSource->getType().is<TupleType>()) {
1157 | 52 | tanBuf = tanSource;
1158 | 220 | } else {
1159 | 220 | tanBuf = diffBuilder.createTupleElementAddr(teai->getLoc(), tanSource,
1160 | 220 | tanIndex, tanType);
1161 | 220 | }
1162 | 272 | setTangentBuffer(teai->getParent(), teai, tanBuf); // tanBuf is now tan[y].
1163 | 272 | }
1164 | | |
1165 | | /// Handle `destructure_tuple` instruction.
1166 | | /// Original: (y0, y1, ...) = destructure_tuple x
1167 | | /// Tangent: (tan[y0], tan[y1], ...) = destructure_tuple tan[x]
1168 | | /// Original results without a tangent space are skipped, so the tangent
1169 | | /// elements are assigned, in order, only to the results that have one.
1170 | 12 | CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) {
1171 | 12 | assert(llvm::any_of(dti->getResults(),
1172 | 12 | [&](SILValue elt) {
1173 | 12 | return activityInfo.isActive(elt, getConfig());
1174 | 12 | }) &&
1175 | 12 | "`destructure_tuple` should have at least one active result");
1176 | | 
1177 | 0 | auto &diffBuilder = getDifferentialBuilder();
1178 | 12 | auto *bb = dti->getParent();
1179 | 12 | auto loc = dti->getLoc();
1180 | | 
1181 | 12 | auto tanTuple = materializeTangent(getTangentValue(dti->getOperand()), loc);
1182 | 12 | SmallVector<SILValue, 4> tanElts; // Tangent elements, in tangent-tuple order.
1183 | 12 | if (tanTuple->getType().is<TupleType>()) {
1184 | 12 | auto *tanDti = diffBuilder.createDestructureTuple(loc, tanTuple);
1185 | 12 | tanElts.append(tanDti->getResults().begin(), tanDti->getResults().end());
1186 | 12 | } else {
1187 | 0 | tanElts.push_back(tanTuple); // Non-tuple tangent: it is the only tangent element.
1188 | 0 | }
1189 | 12 | unsigned tanIdx = 0; // Next unused entry of tanElts.
1190 | 24 | for (auto i : range(dti->getNumResults())) {
1191 | 24 | auto origElt = dti->getResult(i);
1192 | 24 | if (!getTangentSpace(origElt->getType().getASTType()))
1193 | 0 | continue;
1194 | 24 | setTangentValue(bb, origElt, makeConcreteTangentValue(tanElts[tanIdx++]));
1195 | 24 | }
1196 | 12 | }
1197 | | |
1198 | | #undef CLONE_AND_EMIT_TANGENT |
1199 | | |
1200 | | /// Handle `apply` instruction, given:
1201 | | /// - The minimal indices for differentiating the `apply`.
1202 | | /// - The original non-reabstracted differential type.
1203 | | /// Returns early, before emitting the differential call, if `errorOccurred` is set.
1204 | | /// Original: y = apply f(x0, x1, ...)
1205 | | /// Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...)
1206 | | void emitTangentForApplyInst(ApplyInst *ai, const AutoDiffConfig &applyConfig,
1207 | 1.60k | CanSILFunctionType originalDifferentialType) {
1208 | 1.60k | assert(differentialInfo.shouldDifferentiateApplySite(ai));
1209 | 0 | auto *bb = ai->getParent();
1210 | 1.60k | auto loc = ai->getLoc();
1211 | 1.60k | auto &diffBuilder = getDifferentialBuilder();
1212 | | 
1213 | | // Get the differential value.
1214 | 1.60k | SILValue differential = getDifferentialTupleElement(ai);
1215 | 1.60k | auto differentialType = remapSILTypeInDifferential(differential->getType())
1216 | 1.60k | .castTo<SILFunctionType>();
1217 | | 
1218 | | // Get the differential arguments.
1219 | 1.60k | SmallVector<SILValue, 8> diffArgs;
1220 | | 
1221 | 1.60k | for (auto indRes : ai->getIndirectSILResults()) // Tangent buffers of indirect results come first.
1222 | 452 | diffArgs.push_back(getTangentBuffer(bb, indRes));
1223 | | 
1224 | 1.60k | auto origArgs = ai->getArgumentsWithoutIndirectResults();
1225 | | // Get the tangent value of the original arguments.
1226 | 3.98k | for (auto i : indices(origArgs)) {
1227 | 3.98k | auto origArg = origArgs[i];
1228 | | // If the argument is not active:
1229 | | // - Skip the element, if it is not differentiable.
1230 | | // - Otherwise, add a zero value to that location.
1231 | 3.98k | if (!activityInfo.isActive(origArg, getConfig())) {
1232 | 1.42k | auto origCalleeType = ai->getSubstCalleeType();
1233 | 1.42k | if (!origCalleeType->isDifferentiable())
1234 | 1.41k | continue;
1235 | 8 | auto actualOrigCalleeIndices =
1236 | 8 | origCalleeType->getDifferentiabilityParameterIndices();
1237 | 8 | if (actualOrigCalleeIndices->contains(i)) {
1238 | 4 | SILValue tanParam;
1239 | 4 | if (origArg->getType().isObject()) {
1240 | 4 | tanParam = emitZeroDirect(
1241 | 4 | getRemappedTangentType(origArg->getType()).getASTType(), loc);
1242 | 4 | diffArgs.push_back(tanParam);
1243 | 4 | } else { // NOTE(review): unlike the object case above, this branch never appends `tanParam` to `diffArgs`, and no deallocation of the stack slot is visible in this function - confirm whether that is intended.
1244 | 0 | tanParam = diffBuilder.createAllocStack(
1245 | 0 | loc, getRemappedTangentType(origArg->getType()));
1246 | 0 | emitZeroIndirect(
1247 | 0 | getRemappedTangentType(origArg->getType()).getASTType(),
1248 | 0 | tanParam, loc);
1249 | 0 | }
1250 | 4 | }
1251 | 8 | }
1252 | | // Otherwise, if the argument is active, handle the argument normally by
1253 | | // getting its tangent value.
1254 | 2.56k | else {
1255 | 2.56k | SILValue tanParam;
1256 | 2.56k | if (origArg->getType().isObject()) {
1257 | 1.71k | tanParam = materializeTangent(getTangentValue(origArg), loc);
1258 | 1.71k | } else {
1259 | 844 | tanParam = getTangentBuffer(ai->getParent(), origArg);
1260 | 844 | }
1261 | 2.56k | diffArgs.push_back(tanParam);
1262 | 2.56k | if (errorOccurred) // Checked after the tangent lookup; abandon emission for this apply.
1263 | 0 | return;
1264 | 2.56k | }
1265 | 3.98k | }
1266 | | 
1267 | | // If callee differential was reabstracted in JVP, reabstract the callee
1268 | | // differential.
1269 | 1.60k | if (!differentialType->isEqual(originalDifferentialType)) {
1270 | 388 | SILOptFunctionBuilder fb(context.getTransform());
1271 | 388 | differential = reabstractFunction(
1272 | 388 | diffBuilder, fb, loc, differential, originalDifferentialType,
1273 | 388 | [this](SubstitutionMap subs) -> SubstitutionMap {
1274 | 388 | return this->getOpSubstitutionMap(subs);
1275 | 388 | });
1276 | 388 | }
1277 | | 
1278 | | // Call the differential.
1279 | 1.60k | auto *differentialCall =
1280 | 1.60k | diffBuilder.createApply(loc, differential, SubstitutionMap(), diffArgs);
1281 | 1.60k | diffBuilder.emitDestroyValueOperation(loc, differential); // Emitted right after the call, the last use of `differential` here.
1282 | | 
1283 | | // Get the original `apply` results.
1284 | 1.60k | SmallVector<SILValue, 8> origDirectResults;
1285 | 1.60k | forEachApplyDirectResult(ai, [&](SILValue directResult) {
1286 | 1.08k | origDirectResults.push_back(directResult);
1287 | 1.08k | });
1288 | 1.60k | SmallVector<SILValue, 8> origAllResults;
1289 | 1.60k | collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults);
1290 | | 
1291 | | // Get the callee differential `apply` results.
1292 | 1.60k | SmallVector<SILValue, 8> differentialDirectResults;
1293 | 1.60k | extractAllElements(differentialCall, getDifferentialBuilder(),
1294 | 1.60k | differentialDirectResults);
1295 | 1.60k | SmallVector<SILValue, 8> differentialAllResults;
1296 | 1.60k | collectAllActualResultsInTypeOrder(
1297 | 1.60k | differentialCall, differentialDirectResults, differentialAllResults);
1298 | 1.60k | for (auto inoutArg : ai->getInoutArguments()) // `inout` arguments are appended after the actual results, in both lists.
1299 | 92 | origAllResults.push_back(inoutArg);
1300 | 1.60k | for (auto inoutArg : differentialCall->getInoutArguments())
1301 | 92 | differentialAllResults.push_back(inoutArg);
1302 | 1.60k | assert(applyConfig.resultIndices->getNumIndices() ==
1303 | 1.60k | differentialAllResults.size());
1304 | | 
1305 | | // Set tangent values for original `apply` results.
1306 | 0 | unsigned differentialResultIndex = 0; // Position in differentialAllResults; advances once per result index.
1307 | 1.61k | for (auto resultIndex : applyConfig.resultIndices->getIndices()) {
1308 | 1.61k | auto origResult = origAllResults[resultIndex];
1309 | 1.61k | auto differentialResult =
1310 | 1.61k | differentialAllResults[differentialResultIndex++];
1311 | 1.61k | if (origResult->getType().isObject()) { // Address-typed results get no tangent value here.
1312 | 1.07k | if (!origResult->getType().is<TupleType>()) {
1313 | 1.07k | setTangentValue(bb, origResult,
1314 | 1.07k | makeConcreteTangentValue(differentialResult));
1315 | 1.07k | } else if (auto *dti = getSingleDestructureTupleUser(ai)) {
1316 | 0 | bool notSetValue = true; // Only read by the assert below: at most one active result is expected.
1317 | 0 | for (auto result : dti->getResults()) {
1318 | 0 | if (activityInfo.isActive(result, getConfig())) {
1319 | 0 | assert(notSetValue &&
1320 | 0 | "This was incorrectly set, should only have one active "
1321 | 0 | "result from the tuple.");
1322 | 0 | notSetValue = false;
1323 | 0 | setTangentValue(bb, result,
1324 | 0 | makeConcreteTangentValue(differentialResult));
1325 | 0 | }
1326 | 0 | }
1327 | 0 | }
1328 | 1.07k | }
1329 | 1.61k | }
1330 | 1.60k | }
1331 | | |
1332 | | /// Generate a `return` instruction in the current differential basic block. The returned value joins the materialized tangents of the original direct results selected by the config's result indices.
1333 | 1.35k | void emitReturnInstForDifferential() {
1334 | 1.35k | auto &differential = getDifferential();
1335 | 1.35k | auto diffLoc = differential.getLocation();
1336 | 1.35k | auto &diffBuilder = getDifferentialBuilder();
1337 | | 
1338 | | // Collect original results.
1339 | 1.35k | SmallVector<SILValue, 2> originalResults;
1340 | 1.35k | collectAllDirectResultsInTypeOrder(*original, originalResults);
1341 | | // Collect differential direct results.
1342 | 1.35k | SmallVector<SILValue, 8> retElts;
1343 | 1.35k | for (auto i : range(originalResults.size())) {
1344 | 1.19k | auto origResult = originalResults[i];
1345 | 1.19k | if (!getConfig().resultIndices->contains(i)) // Not a differentiability result: contributes nothing.
1346 | 8 | continue;
1347 | 1.18k | auto tanVal = materializeTangent(getTangentValue(origResult), diffLoc);
1348 | 1.18k | retElts.push_back(tanVal);
1349 | 1.18k | }
1350 | | 
1351 | 1.35k | diffBuilder.createReturn(diffLoc,
1352 | 1.35k | joinElements(retElts, diffBuilder, diffLoc));
1353 | 1.35k | }
1354 | | }; |
1355 | | |
1356 | | //--------------------------------------------------------------------------// |
1357 | | // Initialization |
1358 | | //--------------------------------------------------------------------------// |
1359 | | |
1360 | | /// Initialization helper function.
1361 | | ///
1362 | | /// Returns the substitution map used for type remapping: the original function's forwarding substitution map, or, when the JVP has a generic environment, a map built from that environment's signature and forwarding substitutions.
1363 | | static SubstitutionMap getSubstitutionMap(SILFunction *original,
1364 | 1.35k | SILFunction *jvp) {
1365 | 1.35k | auto substMap = original->getForwardingSubstitutionMap();
1366 | 1.35k | if (auto *jvpGenEnv = jvp->getGenericEnvironment()) { // A generic JVP overrides the default taken from `original`.
1367 | 160 | auto jvpSubstMap = jvpGenEnv->getForwardingSubstitutionMap();
1368 | 160 | substMap = SubstitutionMap::get(
1369 | 160 | jvpGenEnv->getGenericSignature(), QuerySubstitutionMap{jvpSubstMap},
1370 | 160 | LookUpConformanceInSubstitutionMap(jvpSubstMap));
1371 | 160 | }
1372 | 1.35k | return substMap;
1373 | 1.35k | }
1374 | | |
1375 | | /// Initialization helper function.
1376 | | ///
1377 | | /// Returns the activity info for the given original function and the JVP's
1378 | | /// substituted generic signature. `config` is only used for the debug dump.
1379 | | static const DifferentiableActivityInfo &
1380 | | getActivityInfo(ADContext &context, SILFunction *original,
1381 | 1.35k | const AutoDiffConfig &config, SILFunction *jvp) {
1382 | | // Get activity info of the original function.
1383 | 1.35k | auto &passManager = context.getPassManager();
1384 | 1.35k | auto *activityAnalysis =
1385 | 1.35k | passManager.getAnalysis<DifferentiableActivityAnalysis>();
1386 | 1.35k | auto &activityCollection = *activityAnalysis->get(original);
1387 | 1.35k | auto &activityInfo = activityCollection.getActivityInfo(
1388 | 1.35k | jvp->getLoweredFunctionType()->getSubstGenericSignature(),
1389 | 1.35k | AutoDiffDerivativeFunctionKind::JVP);
1390 | 1.35k | LLVM_DEBUG(activityInfo.dump(config, getADDebugStream()));
1391 | 1.35k | return activityInfo; // Reference obtained from `activityCollection`; nothing is copied here.
1392 | 1.35k | }
1393 | | |
1394 | | JVPCloner::Implementation::Implementation(ADContext &context,
1395 | | SILDifferentiabilityWitness *witness,
1396 | | SILFunction *jvp,
1397 | | DifferentiationInvoker invoker)
1398 | | : TypeSubstCloner(*jvp, *witness->getOriginalFunction(),
1399 | | getSubstitutionMap(witness->getOriginalFunction(), jvp)),
1400 | | context(context), original(witness->getOriginalFunction()),
1401 | | witness(witness), jvp(jvp), invoker(invoker),
1402 | | activityInfo( // `original` below is the member, not a parameter - presumably declared before `activityInfo`; TODO confirm declaration order in the header.
1403 | | getActivityInfo(context, original, witness->getConfig(), jvp)),
1404 | | loopInfo(context.getPassManager().getAnalysis<SILLoopAnalysis>()
1405 | | ->get(original)),
1406 | | differentialInfo(context, AutoDiffLinearMapKind::Differential, original,
1407 | | jvp, witness->getConfig(), activityInfo, loopInfo),
1408 | | differentialBuilder(TangentBuilder( // The empty differential function is created here, from `differentialInfo`.
1409 | | *createEmptyDifferential(context, witness, &differentialInfo),
1410 | | context)),
1411 | 1.35k | diffLocalAllocBuilder(getDifferential(), context) {
1412 | | // Register the differential created in the initializer list as a generated function.
1413 | 1.35k | context.recordGeneratedFunction(&getDifferential());
1414 | 1.35k | }
1415 | | |
1416 | | JVPCloner::JVPCloner(ADContext &context, SILDifferentiabilityWitness *witness,
1417 | | SILFunction *jvp, DifferentiationInvoker invoker)
1418 | 1.35k | : impl(*new Implementation(context, witness, jvp, invoker)) {} // Heap-allocates the implementation and keeps a reference to it; ~JVPCloner deletes it.
1419 | | |
1420 | 1.35k | JVPCloner::~JVPCloner() { delete &impl; } // Releases the Implementation allocated with `new` in the constructor.
1421 | | |
1422 | | //--------------------------------------------------------------------------// |
1423 | | // Differential struct mapping |
1424 | | //--------------------------------------------------------------------------// |
1425 | | |
1426 | | void JVPCloner::Implementation::initializeDifferentialTupleElements( // Records `values` as the differential tuple elements of `origBB`; a block may be registered only once.
1427 | 1.35k | SILBasicBlock *origBB, SILInstructionResultArray values) {
1428 | 1.35k | auto *diffTupleTyple = differentialInfo.getLinearMapTupleType(origBB); // Used only by the assert below.
1429 | 1.35k | assert(diffTupleTyple->getNumElements() == values.size() &&
1430 | 1.35k | "The number of differential tuple fields must equal the number of "
1431 | 1.35k | "differential struct element values");
1432 | 0 | auto res = differentialTupleElements.insert({origBB, values});
1433 | 1.35k | (void)res; // Silences the unused-variable warning when asserts are compiled out.
1434 | 1.35k | assert(res.second && "A pullback struct element already exists!"); // NOTE(review): message says "pullback" although this map holds differential tuple elements - confirm wording.
1435 | 1.35k | }
1436 | | |
1437 | | /// Returns the differential tuple element value corresponding to the given
1438 | | /// apply inst, looked up among the elements recorded for its parent block.
1439 | 1.60k | SILValue JVPCloner::Implementation::getDifferentialTupleElement(ApplyInst *ai) {
1440 | 1.60k | unsigned idx = differentialInfo.lookUpLinearMapIndex(ai);
1441 | 1.60k | assert((idx > 0 || (idx == 0 && ai->getParentBlock()->isEntry())) && // Index 0 is accepted only for applies in the entry block.
1442 | 1.60k | "impossible linear map index");
1443 | 0 | auto values = differentialTupleElements.lookup(ai->getParentBlock());
1444 | 1.60k | assert(idx < values.size() &&
1445 | 1.60k | "differential tuple element for this apply does not exist!");
1446 | 0 | return values[idx];
1447 | 1.60k | }
1448 | | |
1449 | | //--------------------------------------------------------------------------// |
1450 | | // Tangent emission helpers |
1451 | | //--------------------------------------------------------------------------// |
1452 | | |
1453 | 1.35k | void JVPCloner::Implementation::prepareForDifferentialGeneration() { // Creates the differential's blocks and entry arguments, then seeds tangent values/buffers for wrt parameters and indirect results.
1454 | | // Create differential blocks and arguments.
1455 | 1.35k | auto &differential = getDifferential();
1456 | 1.35k | auto diffLoc = differential.getLocation();
1457 | 1.35k | auto *origEntry = original->getEntryBlock();
1458 | 1.35k | auto origFnTy = original->getLoweredFunctionType();
1459 | | 
1460 | 1.35k | for (auto &origBB : *original) {
1461 | 1.35k | auto *diffBB = differential.createBasicBlock();
1462 | 1.35k | diffBBMap.insert({&origBB, diffBB});
1463 | | // If the BB is the original entry, then the differential block that we
1464 | | // just created must be the differential function's entry. Create
1465 | | // differential entry arguments and continue.
1466 | 1.35k | if (&origBB == origEntry) {
1467 | 1.35k | assert(diffBB->isEntry());
1468 | 0 | createEntryArguments(&differential);
1469 | 1.35k | auto *lastArg = diffBB->getArguments().back(); // The last entry argument is the linear map tuple (type-checked below in asserts builds).
1470 | 1.35k | #ifndef NDEBUG
1471 | 1.35k | auto diffTupleLoweredType = remapSILTypeInDifferential(
1472 | 1.35k | differentialInfo.getLinearMapTupleLoweredType(&origBB));
1473 | 1.35k | assert(lastArg->getType() == diffTupleLoweredType);
1474 | 0 | #endif
1475 | 0 | differentialStructArguments[&origBB] = lastArg;
1476 | 1.35k | }
1477 | | 
1478 | 1.35k | LLVM_DEBUG({
1479 | 1.35k | auto &s = getADDebugStream()
1480 | 1.35k | << "Original bb" + std::to_string(origBB.getDebugID())
1481 | 1.35k | << ": To differentiate or not to differentiate?\n";
1482 | 1.35k | for (auto &inst : origBB) {
1483 | 1.35k | s << (differentialInfo.shouldDifferentiateInstruction(&inst) ? "[x] "
1484 | 1.35k | : "[ ] ")
1485 | 1.35k | << inst;
1486 | 1.35k | }
1487 | 1.35k | });
1488 | 1.35k | }
1489 | | 
1490 | 1.35k | assert(diffBBMap.size() == 1 &&
1491 | 1.35k | "Can only currently handle single basic block functions");
1492 | | 
1493 | | // The differential function has type:
1494 | | // (arg0', ..., argn', entry_df_struct) -> result'.
1495 | 0 | auto diffParamArgs =
1496 | 1.35k | differential.getArgumentsWithoutIndirectResults().drop_back(); // Drop the trailing entry_df_struct argument.
1497 | 1.35k | assert(diffParamArgs.size() ==
1498 | 1.35k | witness->getConfig().parameterIndices->getNumIndices());
1499 | 0 | auto origParamArgs = original->getArgumentsWithoutIndirectResults();
1500 | | 
1501 | | // TODO(TF-788): Re-enable non-varied result warning.
1502 | | /*
1503 | | // Check if result is not varied.
1504 | | SmallVector<SILValue, 8> origFormalResults;
1505 | | collectAllFormalResultsInTypeOrder(*original, origFormalResults);
1506 | | std::get<0>(pair);
1507 | | for (auto resultIndex : getConfig().results->getIndices()) {
1508 | | auto origResult = origFormalResults[resultIndex];
1509 | | // Emit warning if original result is not varied, because it will always
1510 | | // have a zero derivative.
1511 | | if (!activityInfo.isVaried(origResult, getConfig().parameters)) {
1512 | | // Emit fixit if original result has a valid source location.
1513 | | auto startLoc = origResult.getLoc().getStartSourceLoc();
1514 | | auto endLoc = origResult.getLoc().getEndSourceLoc();
1515 | | if (startLoc.isValid() && endLoc.isValid()) {
1516 | | context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
1517 | | .fixItInsert(startLoc, "withoutDerivative(at:")
1518 | | .fixItInsertAfter(endLoc, ")");
1519 | | }
1520 | | }
1521 | | }
1522 | | */
1523 | | 
1524 | | // Initialize tangent mapping for parameters.
1525 | 1.35k | auto diffParamsIt = getConfig().parameterIndices->begin(); // Walks the wrt parameter indices in step with diffParamArgs.
1526 | 2.00k | for (auto index : range(diffParamArgs.size())) {
1527 | 2.00k | auto *diffArg = diffParamArgs[index];
1528 | 2.00k | auto *origArg = origParamArgs[*diffParamsIt];
1529 | 2.00k | ++diffParamsIt;
1530 | 2.00k | if (diffArg->getType().isAddress()) {
1531 | 248 | setTangentBuffer(origEntry, origArg, diffArg);
1532 | 1.75k | } else {
1533 | 1.75k | setTangentValue(origEntry, origArg, makeConcreteTangentValue(diffArg));
1534 | 1.75k | }
1535 | 2.00k | LLVM_DEBUG(getADDebugStream()
1536 | 2.00k | << "Assigned parameter " << *diffArg
1537 | 2.00k | << " as the tangent of original result " << *origArg); // NOTE(review): `origArg` is an original parameter argument, yet the message says "result" - confirm wording.
1538 | 2.00k | }
1539 | | 
1540 | | // Initialize tangent mapping for original indirect results and non-wrt
1541 | | // `inout` parameters. The tangent buffers of these address values are
1542 | | // differential indirect results.
1543 | | 
1544 | | // Collect original results.
1545 | 1.35k | SmallVector<SILValue, 2> originalResults;
1546 | 1.35k | collectAllFormalResultsInTypeOrder(*original, originalResults);
1547 | | 
1548 | | // Iterate over differentiability results.
1549 | 1.35k | differentialBuilder.setInsertionPoint(differential.getEntryBlock());
1550 | 1.35k | auto diffIndResults = differential.getIndirectResults();
1551 | 1.35k | unsigned differentialIndirectResultIndex = 0; // Next unconsumed entry of diffIndResults.
1552 | 1.36k | for (auto resultIndex : getConfig().resultIndices->getIndices()) {
1553 | 1.36k | auto origResult = originalResults[resultIndex];
1554 | | // Handle original formal indirect result.
1555 | 1.36k | if (resultIndex < origFnTy->getNumResults()) {
1556 | | // Skip original direct results.
1557 | 1.32k | if (origResult->getType().isObject())
1558 | 1.18k | continue;
1559 | 140 | auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1560 | 140 | setTangentBuffer(origEntry, origResult, diffIndResult);
1561 | | // If original indirect result is non-varied, zero-initialize its tangent
1562 | | // buffer.
1563 | 140 | if (!activityInfo.isVaried(origResult, getConfig().parameterIndices))
1564 | 8 | emitZeroIndirect(diffIndResult->getType().getASTType(), diffIndResult,
1565 | 8 | diffLoc);
1566 | 140 | continue;
1567 | 1.32k | }
1568 | | // Handle original non-wrt `inout` parameter.
1569 | | // Only original *non-wrt* `inout` parameters have corresponding
1570 | | // differential indirect results.
1571 | 40 | auto inoutParamIndex = resultIndex - origFnTy->getNumResults(); // Result indices past the formal results index into the `inout` parameters.
1572 | 40 | auto inoutParamIt = std::next(
1573 | 40 | origFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
1574 | 40 | auto paramIndex =
1575 | 40 | std::distance(origFnTy->getParameters().begin(), &*inoutParamIt); // Convert the `inout` ordinal into an index in the full parameter list.
1576 | 40 | if (getConfig().parameterIndices->contains(paramIndex))
1577 | 40 | continue;
1578 | 0 | auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1579 | 0 | setTangentBuffer(origEntry, origResult, diffIndResult);
1580 | | // Original `inout` parameters are initialized, so their tangent buffers
1581 | | // must also be initialized.
1582 | 0 | emitZeroIndirect(diffIndResult->getType().getASTType(), diffIndResult,
1583 | 0 | diffLoc);
1584 | 0 | }
1585 | 1.35k | }
1586 | | |
1587 | | /*static*/ SILFunction *JVPCloner::Implementation::createEmptyDifferential( // Builds the differential's function type from the witness config and returns a new, body-less function with a debug scope.
1588 | | ADContext &context, SILDifferentiabilityWitness *witness,
1589 | 1.35k | LinearMapInfo *linearMapInfo) {
1590 | 1.35k | auto &module = context.getModule();
1591 | 1.35k | auto *original = witness->getOriginalFunction();
1592 | 1.35k | auto *jvp = witness->getJVP();
1593 | 1.35k | auto origTy = original->getLoweredFunctionType();
1594 | | // Get witness generic signature for remapping types.
1595 | | // Witness generic signature may have more requirements than JVP generic
1596 | | // signature: when witness generic signature has same-type requirements
1597 | | // binding all generic parameters to concrete types, JVP function type uses
1598 | | // all the concrete types and JVP generic signature is null.
1599 | 1.35k | auto witnessCanGenSig = witness->getDerivativeGenericSignature().getCanonicalSignature();
1600 | 1.35k | auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
1601 | | 
1602 | | // Parameters of the differential are:
1603 | | // - the tangent values of the wrt parameters.
1604 | | // - the differential struct for the original entry.
1605 | | // Result of the differential is in the tangent space of the original
1606 | | // result.
1607 | 1.35k | SmallVector<SILParameterInfo, 8> dfParams;
1608 | 1.35k | SmallVector<SILResultInfo, 8> dfResults;
1609 | 1.35k | auto origParams = origTy->getParameters();
1610 | 1.35k | auto config = witness->getConfig();
1611 | | 
1612 | 1.36k | for (auto resultIndex : config.resultIndices->getIndices()) {
1613 | 1.36k | if (resultIndex < origTy->getNumResults()) {
1614 | | // Handle formal original result.
1615 | 1.32k | auto origResult = origTy->getResults()[resultIndex];
1616 | 1.32k | origResult = origResult.getWithInterfaceType(
1617 | 1.32k | origResult.getInterfaceType()->getReducedType(witnessCanGenSig));
1618 | 1.32k | dfResults.push_back(
1619 | 1.32k | SILResultInfo(origResult.getInterfaceType()
1620 | 1.32k | ->getAutoDiffTangentSpace(lookupConformance) // NOTE(review): dereferenced without the null check/assert used in the `inout` branch below - confirm a tangent space is guaranteed here.
1621 | 1.32k | ->getType()
1622 | 1.32k | ->getReducedType(witnessCanGenSig),
1623 | 1.32k | origResult.getConvention()));
1624 | 1.32k | } else {
1625 | | // Handle original `inout` parameter.
1626 | 40 | auto inoutParamIndex = resultIndex - origTy->getNumResults();
1627 | 40 | auto inoutParamIt = std::next(
1628 | 40 | origTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
1629 | 40 | auto paramIndex =
1630 | 40 | std::distance(origTy->getParameters().begin(), &*inoutParamIt);
1631 | | // If the original `inout` parameter is a differentiability parameter,
1632 | | // then it already has a corresponding differential parameter. Do not add
1633 | | // a corresponding differential result.
1634 | 40 | if (config.parameterIndices->contains(paramIndex))
1635 | 40 | continue;
1636 | 0 | auto inoutParam = origTy->getParameters()[paramIndex];
1637 | 0 | auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
1638 | 0 | lookupConformance);
1639 | 0 | assert(paramTan && "Parameter type does not have a tangent space?");
1640 | 0 | dfResults.push_back(
1641 | 0 | {paramTan->getCanonicalType(), ResultConvention::Indirect});
1642 | 0 | }
1643 | 1.36k | }
1644 | | 
1645 | | // Add differential parameters for the requested wrt parameters.
1646 | 2.00k | for (auto i : config.parameterIndices->getIndices()) {
1647 | 2.00k | auto origParam = origParams[i];
1648 | 2.00k | origParam = origParam.getWithInterfaceType(
1649 | 2.00k | origParam.getInterfaceType()->getReducedType(witnessCanGenSig));
1650 | 2.00k | dfParams.push_back(
1651 | 2.00k | SILParameterInfo(origParam.getInterfaceType()
1652 | 2.00k | ->getAutoDiffTangentSpace(lookupConformance) // Same unchecked dereference as for formal results above.
1653 | 2.00k | ->getType()
1654 | 2.00k | ->getReducedType(witnessCanGenSig),
1655 | 2.00k | origParam.getConvention()));
1656 | 2.00k | }
1657 | | 
1658 | | // Accept a differential struct in the differential parameter list. This is
1659 | | // the returned differential's closure context.
1660 | 1.35k | auto *origEntry = original->getEntryBlock();
1661 | 1.35k | auto dfTupleType =
1662 | 1.35k | linearMapInfo->getLinearMapTupleLoweredType(origEntry).getASTType();
1663 | 1.35k | dfParams.push_back({dfTupleType, ParameterConvention::Direct_Owned}); // Always the last parameter.
1664 | | 
1665 | 1.35k | Mangle::DifferentiationMangler mangler;
1666 | 1.35k | auto diffName = mangler.mangleLinearMap(
1667 | 1.35k | witness->getOriginalFunction()->getName(),
1668 | 1.35k | AutoDiffLinearMapKind::Differential, witness->getConfig());
1669 | | // Set differential generic signature equal to JVP generic signature.
1670 | | // Do not use witness generic signature, which may have same-type requirements
1671 | | // binding all generic parameters to concrete types.
1672 | 1.35k | auto diffGenericSig =
1673 | 1.35k | jvp->getLoweredFunctionType()->getSubstGenericSignature();
1674 | 1.35k | auto *diffGenericEnv = diffGenericSig.getGenericEnvironment();
1675 | 1.35k | auto diffType = SILFunctionType::get(
1676 | 1.35k | diffGenericSig, SILExtInfo::getThin(), origTy->getCoroutineKind(),
1677 | 1.35k | origTy->getCalleeConvention(), dfParams, {}, dfResults, llvm::None,
1678 | 1.35k | origTy->getPatternSubstitutions(), origTy->getInvocationSubstitutions(),
1679 | 1.35k | original->getASTContext());
1680 | | 
1681 | 1.35k | SILOptFunctionBuilder fb(context.getTransform());
1682 | 1.35k | auto linkage = jvp->isSerialized() ? SILLinkage::Public : SILLinkage::Private; // Public only when the JVP is serialized.
1683 | 1.35k | auto *differential = fb.createFunction(
1684 | 1.35k | linkage, context.getASTContext().getIdentifier(diffName).str(), diffType,
1685 | 1.35k | diffGenericEnv, original->getLocation(), original->isBare(),
1686 | 1.35k | IsNotTransparent, jvp->isSerialized(),
1687 | 1.35k | original->isDynamicallyReplaceable(),
1688 | 1.35k | original->isDistributed(),
1689 | 1.35k | original->isRuntimeAccessible());
1690 | 1.35k | differential->setDebugScope(
1691 | 1.35k | new (module) SILDebugScope(original->getLocation(), differential));
1692 | | 
1693 | 1.35k | return differential;
1694 | 1.35k | }
1695 | | |
1696 | 1.35k | bool JVPCloner::Implementation::run() { // Clones `original` into the JVP while emitting the differential; returns true if an error occurred.
1697 | 1.35k | PrettyStackTraceSILFunction trace("generating JVP and differential for",
1698 | 1.35k | original);
1699 | 1.35k | LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName()
1700 | 1.35k | << " to jvp @" << jvp->getName() << '\n');
1701 | | // Create JVP and differential entry and arguments.
1702 | 1.35k | auto *entry = jvp->createBasicBlock();
1703 | 1.35k | createEntryArguments(jvp);
1704 | 1.35k | prepareForDifferentialGeneration();
1705 | | // Clone.
1706 | 1.35k | SmallVector<SILValue, 4> entryArgs(entry->getArguments().begin(),
1707 | 1.35k | entry->getArguments().end());
1708 | 1.35k | cloneFunctionBody(original, entry, entryArgs);
1709 | 1.35k | emitReturnInstForDifferential(); // Runs before the error check below, i.e. also when cloning reported an error.
1710 | | // If errors occurred, back out.
1711 | 1.35k | if (errorOccurred)
1712 | 20 | return true;
1713 | 1.33k | LLVM_DEBUG(getADDebugStream()
1714 | 1.33k | << "Generated JVP for " << original->getName() << ":\n"
1715 | 1.33k | << *jvp);
1716 | 1.33k | LLVM_DEBUG(getADDebugStream()
1717 | 1.33k | << "Generated differential for " << original->getName() << ":\n"
1718 | 1.33k | << getDifferential());
1719 | 1.33k | return errorOccurred; // Only reached after the `errorOccurred` early return above.
1720 | 1.35k | }
1721 | | |
1722 | | } // end namespace autodiff |
1723 | | } // end namespace swift |
1724 | | |
1725 | 1.35k | bool JVPCloner::run() { // Runs the implementation; returns true if an error was found.
1726 | 1.35k | bool foundError = impl.run();
1727 | 1.35k | #ifndef NDEBUG
1728 | 1.35k | if (!foundError) // Verify the generated JVP only in asserts builds and only on success.
1729 | 1.33k | getJVP().verify();
1730 | 1.35k | #endif
1731 | 1.35k | return foundError;
1732 | 1.35k | }
1733 | | |
1734 | 1.33k | SILFunction &JVPCloner::getJVP() const { return impl.getJVP(); } // Forwards to the implementation object.