/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- PullbackCloner.cpp - Pullback function generation ---*- C++ -*----===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // This file defines a helper class for generating pullback functions for |
14 | | // automatic differentiation. |
15 | | // |
16 | | //===----------------------------------------------------------------------===// |
17 | | |
18 | | #include "swift/Basic/STLExtras.h" |
19 | | #define DEBUG_TYPE "differentiation" |
20 | | |
21 | | #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" |
22 | | #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" |
23 | | #include "swift/SILOptimizer/Differentiation/ADContext.h" |
24 | | #include "swift/SILOptimizer/Differentiation/AdjointValue.h" |
25 | | #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" |
26 | | #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" |
27 | | #include "swift/SILOptimizer/Differentiation/Thunk.h" |
28 | | #include "swift/SILOptimizer/Differentiation/VJPCloner.h" |
29 | | |
30 | | #include "swift/AST/Expr.h" |
31 | | #include "swift/AST/PropertyWrappers.h" |
32 | | #include "swift/AST/TypeCheckRequests.h" |
33 | | #include "swift/SIL/InstructionUtils.h" |
34 | | #include "swift/SIL/Projection.h" |
35 | | #include "swift/SIL/TypeSubstCloner.h" |
36 | | #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" |
37 | | #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" |
38 | | #include "llvm/ADT/DenseMap.h" |
39 | | #include "llvm/ADT/SmallSet.h" |
40 | | |
41 | | namespace swift { |
42 | | |
43 | | class SILDifferentiabilityWitness; |
44 | | class SILBasicBlock; |
45 | | class SILFunction; |
46 | | class SILInstruction; |
47 | | |
48 | | namespace autodiff { |
49 | | |
50 | | class ADContext; |
51 | | class VJPCloner; |
52 | | |
53 | | /// The implementation class for `PullbackCloner`. |
54 | | /// |
55 | | /// The implementation class is a `SILInstructionVisitor`. Effectively, it acts |
56 | | /// as a `SILCloner` that visits basic blocks in post-order and that visits |
57 | | /// instructions per basic block in reverse order. This visitation order is |
58 | | /// necessary for generating pullback functions, whose control flow graph is |
59 | | /// ~a transposed version of the original function's control flow graph. |
60 | | class PullbackCloner::Implementation final |
61 | | : public SILInstructionVisitor<PullbackCloner::Implementation> { |
62 | | |
63 | | public: |
64 | | explicit Implementation(VJPCloner &vjpCloner); |
65 | | |
66 | | private: |
67 | | /// The parent VJP cloner. |
68 | | VJPCloner &vjpCloner; |
69 | | |
70 | | /// Dominance info for the original function. |
71 | | DominanceInfo *domInfo = nullptr; |
72 | | |
73 | | /// Post-dominance info for the original function. |
74 | | PostDominanceInfo *postDomInfo = nullptr; |
75 | | |
76 | | /// Post-order info for the original function. |
77 | | PostOrderFunctionInfo *postOrderInfo = nullptr; |
78 | | |
79 | | /// Mapping from original basic blocks to corresponding pullback basic blocks. |
80 | | /// Pullback basic blocks always have the predecessor as the single argument. |
81 | | llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> pullbackBBMap; |
82 | | |
83 | | /// Mapping from original basic blocks and original values to corresponding |
84 | | /// adjoint values. |
85 | | llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, AdjointValue> valueMap; |
86 | | |
87 | | /// Mapping from original basic blocks and original values to corresponding |
88 | | /// adjoint buffers. |
89 | | llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap; |
90 | | |
91 | | /// Mapping from original basic blocks to the pullback tuple elements |
92 | | /// destructured from the linear map basic block argument. In the |
93 | | /// beginning of each pullback basic block, the block's pullback tuple is |
94 | | /// destructured into individual elements stored here. |
95 | | llvm::DenseMap<SILBasicBlock*, SmallVector<SILValue, 4>> pullbackTupleElements; |
96 | | |
97 | | /// Mapping from original basic blocks and successor basic blocks to |
98 | | /// corresponding pullback trampoline basic blocks. Trampoline basic blocks |
99 | | /// take additional arguments in addition to the predecessor enum argument. |
100 | | llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, SILBasicBlock *> |
101 | | pullbackTrampolineBBMap; |
102 | | |
103 | | /// Mapping from original basic blocks to dominated active values. |
104 | | llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues; |
105 | | |
106 | | /// Mapping from original basic blocks and original active values to |
107 | | /// corresponding pullback block arguments. |
108 | | llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *> |
109 | | activeValuePullbackBBArgumentMap; |
110 | | |
111 | | /// Mapping from original basic blocks to local temporary values to be cleaned |
112 | | /// up. This is populated when pullback emission is run on one basic block and |
113 | | /// cleaned before processing another basic block. |
114 | | llvm::DenseMap<SILBasicBlock *, llvm::SmallSetVector<SILValue, 32>> |
115 | | blockTemporaries; |
116 | | |
117 | | /// The scope cloner. |
118 | | ScopeCloner scopeCloner; |
119 | | |
120 | | /// The main builder. |
121 | | TangentBuilder builder; |
122 | | |
123 | | /// An auxiliary local allocation builder. |
124 | | TangentBuilder localAllocBuilder; |
125 | | |
126 | | /// The original function's exit block. |
127 | | SILBasicBlock *originalExitBlock = nullptr; |
128 | | |
129 | | /// Stack buffers allocated for storing local adjoint values. |
130 | | SmallVector<AllocStackInst *, 64> functionLocalAllocations; |
131 | | |
132 | | /// A set used to remember local allocations that were destroyed. |
133 | | llvm::SmallDenseSet<SILValue> destroyedLocalAllocations; |
134 | | |
135 | | /// The seed arguments of the pullback function. |
136 | | SmallVector<SILArgument *, 4> seeds; |
137 | | |
138 | | /// The `AutoDiffLinearMapContext` object, if any. |
139 | | SILValue contextValue = nullptr; |
140 | | |
141 | | llvm::BumpPtrAllocator allocator; |
142 | | |
143 | | bool errorOccurred = false; |
144 | | |
145 | 185k | ADContext &getContext() const { return vjpCloner.getContext(); } |
146 | 153k | SILModule &getModule() const { return getContext().getModule(); } |
147 | 7.58k | ASTContext &getASTContext() const { return getPullback().getASTContext(); } |
148 | 175k | SILFunction &getOriginal() const { return vjpCloner.getOriginal(); } |
149 | 144k | SILDifferentiabilityWitness *getWitness() const { |
150 | 144k | return vjpCloner.getWitness(); |
151 | 144k | } |
152 | 3.81k | DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); } |
153 | 134k | LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); } |
154 | 173k | const AutoDiffConfig &getConfig() const { return vjpCloner.getConfig(); } |
155 | 152k | const DifferentiableActivityInfo &getActivityInfo() const { |
156 | 152k | return vjpCloner.getActivityInfo(); |
157 | 152k | } |
158 | | |
159 | | //--------------------------------------------------------------------------// |
160 | | // Pullback struct mapping |
161 | | //--------------------------------------------------------------------------// |
162 | | |
163 | | void initializePullbackTupleElements(SILBasicBlock *origBB, |
164 | 1.85k | SILInstructionResultArray values) { |
165 | 1.85k | auto *pbTupleTyple = getPullbackInfo().getLinearMapTupleType(origBB); |
166 | 1.85k | assert(pbTupleTyple->getNumElements() == values.size() && |
167 | 1.85k | "The number of pullback tuple fields must equal the number of " |
168 | 1.85k | "pullback tuple element values"); |
169 | 0 | auto res = pullbackTupleElements.insert({origBB, { values.begin(), values.end() }}); |
170 | 1.85k | (void)res; |
171 | 1.85k | assert(res.second && "A pullback tuple element already exists!"); |
172 | 1.85k | } |
173 | | |
174 | | void initializePullbackTupleElements(SILBasicBlock *origBB, |
175 | 4.94k | const llvm::ArrayRef<SILArgument *> &values) { |
176 | 4.94k | auto *pbTupleTyple = getPullbackInfo().getLinearMapTupleType(origBB); |
177 | 4.94k | assert(pbTupleTyple->getNumElements() == values.size() && |
178 | 4.94k | "The number of pullback tuple fields must equal the number of " |
179 | 4.94k | "pullback tuple element values"); |
180 | 0 | auto res = pullbackTupleElements.insert({origBB, { values.begin(), values.end() }}); |
181 | 4.94k | (void)res; |
182 | 4.94k | assert(res.second && "A pullback struct element already exists!"); |
183 | 4.94k | } |
184 | | |
185 | | /// Returns the pullback tuple element value corresponding to the given |
186 | | /// original block and apply inst. |
187 | 5.65k | SILValue getPullbackTupleElement(ApplyInst *ai) { |
188 | 5.65k | unsigned idx = getPullbackInfo().lookUpLinearMapIndex(ai); |
189 | 5.65k | assert((idx > 0 || (idx == 0 && ai->getParentBlock()->isEntry())) && |
190 | 5.65k | "impossible linear map index"); |
191 | 0 | auto values = pullbackTupleElements.lookup(ai->getParentBlock()); |
192 | 5.65k | assert(idx < values.size() && |
193 | 5.65k | "pullback tuple element for this apply does not exist!"); |
194 | 0 | return values[idx]; |
195 | 5.65k | } |
196 | | |
197 | | /// Returns the pullback tuple element value corresponding to the predecessor |
198 | | /// for the given original block. |
199 | 1.75k | SILValue getPullbackPredTupleElement(SILBasicBlock *origBB) { |
200 | 1.75k | assert(!origBB->isEntry() && "no predecessors for entry block"); |
201 | 0 | auto values = pullbackTupleElements.lookup(origBB); |
202 | 1.75k | assert(values.size() && "pullback tuple cannot be empty"); |
203 | 0 | return values[0]; |
204 | 1.75k | } |
205 | | |
206 | | //--------------------------------------------------------------------------// |
207 | | // Type transformer |
208 | | //--------------------------------------------------------------------------// |
209 | | |
210 | | /// Get the type lowering for the given AST type. |
211 | 61.2k | const Lowering::TypeLowering &getTypeLowering(Type type) { |
212 | 61.2k | auto pbGenSig = |
213 | 61.2k | getPullback().getLoweredFunctionType()->getSubstGenericSignature(); |
214 | 61.2k | Lowering::AbstractionPattern pattern(pbGenSig, |
215 | 61.2k | type->getReducedType(pbGenSig)); |
216 | 61.2k | return getPullback().getTypeLowering(pattern, type); |
217 | 61.2k | } |
218 | | |
219 | | /// Remap any archetypes into the current function's context. |
220 | 191k | SILType remapType(SILType ty) { |
221 | 191k | if (ty.hasArchetype()) |
222 | 15.1k | ty = ty.mapTypeOutOfContext(); |
223 | 191k | auto remappedType = ty.getASTType()->getReducedType( |
224 | 191k | getPullback().getLoweredFunctionType()->getSubstGenericSignature()); |
225 | 191k | auto remappedSILType = |
226 | 191k | SILType::getPrimitiveType(remappedType, ty.getCategory()); |
227 | 191k | return getPullback().mapTypeIntoContext(remappedSILType); |
228 | 191k | } |
229 | | |
230 | 144k | llvm::Optional<TangentSpace> getTangentSpace(CanType type) { |
231 | | // Use witness generic signature to remap types. |
232 | 144k | type = |
233 | 144k | getWitness()->getDerivativeGenericSignature().getReducedType( |
234 | 144k | type); |
235 | 144k | return type->getAutoDiffTangentSpace( |
236 | 144k | LookUpConformanceInModule(getModule().getSwiftModule())); |
237 | 144k | } |
238 | | |
239 | | /// Returns the tangent value category of the given value. |
240 | 124k | SILValueCategory getTangentValueCategory(SILValue v) { |
241 | | // Tangent value category table: |
242 | | // |
243 | | // Let $L be a loadable type and $*A be an address-only type. |
244 | | // |
245 | | // Original type | Tangent type loadable? | Tangent value category and type |
246 | | // --------------|------------------------|-------------------------------- |
247 | | // $L | loadable | object, $L' (no mismatch) |
248 | | // $*A | loadable | address, $*L' (create a buffer) |
249 | | // $L | address-only | address, $*A' (no alternative) |
250 | | // $*A | address-only | address, $*A' (no alternative) |
251 | | |
252 | | // TODO(https://github.com/apple/swift/issues/55523): Make "tangent value category" depend solely on whether the tangent type is loadable or address-only. |
253 | | // |
254 | | // For loadable tangent types, using symbolic adjoint values instead of |
255 | | // concrete adjoint buffers is more efficient. |
256 | | |
257 | | // Quick check: if the value has an address type, the tangent value category |
258 | | // is currently always "address". |
259 | 124k | if (v->getType().isAddress()) |
260 | 63.3k | return SILValueCategory::Address; |
261 | | // If the value has an object type and the tangent type is not address-only, |
262 | | // then the tangent value category is "object". |
263 | 61.1k | auto tanSpace = getTangentSpace(remapType(v->getType()).getASTType()); |
264 | 61.1k | auto tanASTType = tanSpace->getCanonicalType(); |
265 | 61.1k | if (v->getType().isObject() && getTypeLowering(tanASTType).isLoadable()) |
266 | 58.7k | return SILValueCategory::Object; |
267 | | // Otherwise, the tangent value category is "address". |
268 | 2.44k | return SILValueCategory::Address; |
269 | 61.1k | } |
270 | | |
271 | | /// Assuming the given type conforms to `Differentiable` after remapping, |
272 | | /// returns the associated tangent space type. |
273 | 58.4k | SILType getRemappedTangentType(SILType type) { |
274 | 58.4k | return SILType::getPrimitiveType( |
275 | 58.4k | getTangentSpace(remapType(type).getASTType())->getCanonicalType(), |
276 | 58.4k | type.getCategory()); |
277 | 58.4k | } |
278 | | |
279 | | /// Substitutes all replacement types of the given substitution map using the |
280 | | /// pullback function's substitution map. |
281 | 1.46k | SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) { |
282 | 1.46k | return substMap.subst(getPullback().getForwardingSubstitutionMap()); |
283 | 1.46k | } |
284 | | |
285 | | //--------------------------------------------------------------------------// |
286 | | // Temporary value management |
287 | | //--------------------------------------------------------------------------// |
288 | | |
289 | | /// Record a temporary value for cleanup before its block's terminator. |
290 | 18.0k | SILValue recordTemporary(SILValue value) { |
291 | 18.0k | assert(value->getType().isObject()); |
292 | 0 | assert(value->getFunction() == &getPullback()); |
293 | 0 | auto inserted = blockTemporaries[value->getParentBlock()].insert(value); |
294 | 18.0k | (void)inserted; |
295 | 18.0k | LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value); |
296 | 18.0k | assert(inserted && "Temporary already recorded?"); |
297 | 0 | return value; |
298 | 18.0k | } |
299 | | |
300 | | /// Clean up all temporary values for the given pullback block. |
301 | 6.74k | void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) { |
302 | 6.74k | assert(bb->getParent() == &getPullback()); |
303 | 6.74k | LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb" |
304 | 6.74k | << bb->getDebugID() << '\n'); |
305 | 6.74k | for (auto temp : blockTemporaries[bb]) |
306 | 18.4k | builder.emitDestroyValueOperation(loc, temp); |
307 | 6.74k | blockTemporaries[bb].clear(); |
308 | 6.74k | } |
309 | | |
310 | | //--------------------------------------------------------------------------// |
311 | | // Adjoint value factory methods |
312 | | //--------------------------------------------------------------------------// |
313 | | |
314 | 19.1k | AdjointValue makeZeroAdjointValue(SILType type) { |
315 | 19.1k | return AdjointValue::createZero(allocator, remapType(type)); |
316 | 19.1k | } |
317 | | |
318 | 23.1k | AdjointValue makeConcreteAdjointValue(SILValue value) { |
319 | 23.1k | return AdjointValue::createConcrete(allocator, value); |
320 | 23.1k | } |
321 | | |
322 | | AdjointValue makeAggregateAdjointValue(SILType type, |
323 | 336 | ArrayRef<AdjointValue> elements) { |
324 | 336 | return AdjointValue::createAggregate(allocator, remapType(type), elements); |
325 | 336 | } |
326 | | |
327 | | AdjointValue makeAddElementAdjointValue(AdjointValue baseAdjoint, |
328 | | AdjointValue eltToAdd, |
329 | 860 | FieldLocator fieldLocator) { |
330 | 860 | auto *addElementValue = |
331 | 860 | new AddElementValue(baseAdjoint, eltToAdd, fieldLocator); |
332 | 860 | return AdjointValue::createAddElement(allocator, baseAdjoint.getType(), |
333 | 860 | addElementValue); |
334 | 860 | } |
335 | | |
336 | | //--------------------------------------------------------------------------// |
337 | | // Adjoint value materialization |
338 | | //--------------------------------------------------------------------------// |
339 | | |
340 | | /// Materializes an adjoint value. The type of the given adjoint value must be |
341 | | /// loadable. |
342 | 17.6k | SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) { |
343 | 17.6k | assert(val.getType().isObject()); |
344 | 17.6k | LLVM_DEBUG(getADDebugStream() |
345 | 17.6k | << "Materializing adjoint for " << val << '\n'); |
346 | 17.6k | SILValue result; |
347 | 17.6k | switch (val.getKind()) { |
348 | 2.36k | case AdjointValueKind::Zero: |
349 | 2.36k | result = recordTemporary(builder.emitZero(loc, val.getSwiftType())); |
350 | 2.36k | break; |
351 | 112 | case AdjointValueKind::Aggregate: { |
352 | 112 | SmallVector<SILValue, 8> elements; |
353 | 176 | for (auto i : range(val.getNumAggregateElements())) { |
354 | 176 | auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc); |
355 | 176 | elements.push_back(builder.emitCopyValueOperation(loc, eltVal)); |
356 | 176 | } |
357 | 112 | if (val.getType().is<TupleType>()) |
358 | 0 | result = recordTemporary( |
359 | 0 | builder.createTuple(loc, val.getType(), elements)); |
360 | 112 | else |
361 | 112 | result = recordTemporary( |
362 | 112 | builder.createStruct(loc, val.getType(), elements)); |
363 | 112 | break; |
364 | 0 | } |
365 | 14.7k | case AdjointValueKind::Concrete: |
366 | 14.7k | result = val.getConcreteValue(); |
367 | 14.7k | break; |
368 | 384 | case AdjointValueKind::AddElement: { |
369 | 384 | auto adjointSILType = val.getAddElementValue()->baseAdjoint.getType(); |
370 | 384 | auto *baseAdjAlloc = builder.createAllocStack(loc, adjointSILType); |
371 | 384 | materializeAdjointIndirect(val, baseAdjAlloc, loc); |
372 | | |
373 | 384 | auto baseAdjConcrete = recordTemporary(builder.emitLoadValueOperation( |
374 | 384 | loc, baseAdjAlloc, LoadOwnershipQualifier::Take)); |
375 | | |
376 | 384 | builder.createDeallocStack(loc, baseAdjAlloc); |
377 | | |
378 | 384 | result = baseAdjConcrete; |
379 | 384 | break; |
380 | 0 | } |
381 | 17.6k | } |
382 | 17.6k | if (auto debugInfo = val.getDebugInfo()) |
383 | 6.69k | builder.createDebugValue( |
384 | 6.69k | debugInfo->first.getLocation(), result, debugInfo->second); |
385 | 17.6k | return result; |
386 | 17.6k | } |
387 | | |
388 | | /// Materializes an adjoint value indirectly to a SIL buffer. |
389 | | void materializeAdjointIndirect(AdjointValue val, SILValue destAddress, |
390 | 768 | SILLocation loc) { |
391 | 768 | assert(destAddress->getType().isAddress()); |
392 | 0 | switch (val.getKind()) { |
393 | | /// If adjoint value is a symbolic zero, emit a call to |
394 | | /// `AdditiveArithmetic.zero`. |
395 | 328 | case AdjointValueKind::Zero: |
396 | 328 | builder.emitZeroIntoBuffer(loc, destAddress, IsInitialization); |
397 | 328 | break; |
398 | | /// If adjoint value is a symbolic aggregate (tuple or struct), recursively |
399 | | /// materialize the symbolic tuple or struct, filling the |
400 | | /// buffer. |
401 | 0 | case AdjointValueKind::Aggregate: { |
402 | 0 | if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) { |
403 | 0 | for (auto idx : range(val.getNumAggregateElements())) { |
404 | 0 | auto eltTy = SILType::getPrimitiveAddressType( |
405 | 0 | tupTy->getElementType(idx)->getCanonicalType()); |
406 | 0 | auto *eltBuf = |
407 | 0 | builder.createTupleElementAddr(loc, destAddress, idx, eltTy); |
408 | 0 | materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc); |
409 | 0 | } |
410 | 0 | } else if (auto *structDecl = |
411 | 0 | val.getSwiftType()->getStructOrBoundGenericStruct()) { |
412 | 0 | auto fieldIt = structDecl->getStoredProperties().begin(); |
413 | 0 | for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end(); |
414 | 0 | ++fieldIt, ++i) { |
415 | 0 | auto eltBuf = |
416 | 0 | builder.createStructElementAddr(loc, destAddress, *fieldIt); |
417 | 0 | materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc); |
418 | 0 | } |
419 | 0 | } else { |
420 | 0 | llvm_unreachable("Not an aggregate type"); |
421 | 0 | } |
422 | 0 | break; |
423 | 0 | } |
424 | | /// If adjoint value is concrete, it is already materialized. Store it in |
425 | | /// the destination address. |
426 | 56 | case AdjointValueKind::Concrete: { |
427 | 56 | auto concreteVal = val.getConcreteValue(); |
428 | 56 | auto copyOfConcreteVal = builder.emitCopyValueOperation(loc, concreteVal); |
429 | 56 | builder.emitStoreValueOperation(loc, copyOfConcreteVal, destAddress, |
430 | 56 | StoreOwnershipQualifier::Init); |
431 | 56 | break; |
432 | 0 | } |
433 | 384 | case AdjointValueKind::AddElement: { |
434 | 384 | auto baseAdjoint = val; |
435 | 384 | auto baseAdjointType = baseAdjoint.getType(); |
436 | | |
437 | | // Current adjoint may be made up of layers of `AddElement` adjoints. |
438 | | // We can iteratively gather the list of elements to add instead of making |
439 | | // recursive calls to `materializeAdjointIndirect`. |
440 | 384 | SmallVector<AddElementValue *, 4> addEltAdjValues; |
441 | | |
442 | 524 | do { |
443 | 524 | auto addElementValue = baseAdjoint.getAddElementValue(); |
444 | 524 | addEltAdjValues.push_back(addElementValue); |
445 | 524 | baseAdjoint = addElementValue->baseAdjoint; |
446 | 524 | assert(baseAdjointType == baseAdjoint.getType()); |
447 | 524 | } while (baseAdjoint.getKind() == AdjointValueKind::AddElement); |
448 | | |
449 | 0 | materializeAdjointIndirect(baseAdjoint, destAddress, loc); |
450 | | |
451 | 524 | for (auto *addElementValue : addEltAdjValues) { |
452 | 524 | auto eltToAdd = addElementValue->eltToAdd; |
453 | | |
454 | 524 | SILValue baseAdjEltAddr; |
455 | 524 | if (baseAdjoint.getType().is<TupleType>()) { |
456 | 16 | baseAdjEltAddr = builder.createTupleElementAddr( |
457 | 16 | loc, destAddress, addElementValue->getFieldIndex()); |
458 | 508 | } else { |
459 | 508 | baseAdjEltAddr = builder.createStructElementAddr( |
460 | 508 | loc, destAddress, addElementValue->getFieldDecl()); |
461 | 508 | } |
462 | | |
463 | 524 | auto eltToAddMaterialized = materializeAdjointDirect(eltToAdd, loc); |
464 | | // Copy `eltToAddMaterialized` so we have a value with owned ownership |
465 | | // semantics, required for using `eltToAddMaterialized` in a `store` |
466 | | // instruction. |
467 | 524 | auto eltToAddMaterializedCopy = |
468 | 524 | builder.emitCopyValueOperation(loc, eltToAddMaterialized); |
469 | 524 | auto *eltToAddAlloc = builder.createAllocStack(loc, eltToAdd.getType()); |
470 | 524 | builder.emitStoreValueOperation(loc, eltToAddMaterializedCopy, |
471 | 524 | eltToAddAlloc, |
472 | 524 | StoreOwnershipQualifier::Init); |
473 | | |
474 | 524 | builder.emitInPlaceAdd(loc, baseAdjEltAddr, eltToAddAlloc); |
475 | 524 | builder.createDestroyAddr(loc, eltToAddAlloc); |
476 | 524 | builder.createDeallocStack(loc, eltToAddAlloc); |
477 | 524 | } |
478 | | |
479 | 384 | break; |
480 | 0 | } |
481 | 768 | } |
482 | 768 | } |
483 | | |
484 | | //--------------------------------------------------------------------------// |
485 | | // Adjoint value mapping |
486 | | //--------------------------------------------------------------------------// |
487 | | |
488 | | /// Returns true if the given value in the original function has a |
489 | | /// corresponding adjoint value. |
490 | 5.48k | bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const { |
491 | 5.48k | assert(origBB->getParent() == &getOriginal()); |
492 | 0 | assert(originalValue->getType().isObject()); |
493 | 0 | return valueMap.count({origBB, originalValue}); |
494 | 5.48k | } |
495 | | |
496 | | /// Initializes the adjoint value for the original value. Asserts that the |
497 | | /// original value does not already have an adjoint value. |
498 | | void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue, |
499 | 11.4k | AdjointValue adjointValue) { |
500 | 11.4k | LLVM_DEBUG(getADDebugStream() |
501 | 11.4k | << "Setting adjoint value for " << originalValue); |
502 | 11.4k | assert(origBB->getParent() == &getOriginal()); |
503 | 0 | assert(originalValue->getType().isObject()); |
504 | 0 | assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); |
505 | 0 | assert(adjointValue.getType().isObject()); |
506 | 0 | assert(originalValue->getFunction() == &getOriginal()); |
507 | | // The adjoint value must be in the tangent space. |
508 | 0 | assert(adjointValue.getType() == |
509 | 11.4k | getRemappedTangentType(originalValue->getType())); |
510 | | // Try to assign a debug variable. |
511 | 11.4k | if (auto debugInfo = findDebugLocationAndVariable(originalValue)) { |
512 | 5.53k | LLVM_DEBUG({ |
513 | 5.53k | auto &s = getADDebugStream(); |
514 | 5.53k | s << "Found debug variable: \"" << debugInfo->second.Name |
515 | 5.53k | << "\"\nLocation: "; |
516 | 5.53k | debugInfo->first.getLocation().print(s, getASTContext().SourceMgr); |
517 | 5.53k | s << '\n'; |
518 | 5.53k | }); |
519 | 5.53k | adjointValue.setDebugInfo(*debugInfo); |
520 | 5.91k | } else { |
521 | 5.91k | LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n"); |
522 | 5.91k | } |
523 | | // Insert into dictionary. |
524 | 11.4k | auto insertion = |
525 | 11.4k | valueMap.try_emplace({origBB, originalValue}, adjointValue); |
526 | 11.4k | LLVM_DEBUG(getADDebugStream() |
527 | 11.4k | << "The new adjoint value, replacing the existing one, is: " |
528 | 11.4k | << insertion.first->getSecond() << '\n'); |
529 | 11.4k | if (!insertion.second) |
530 | 4.49k | insertion.first->getSecond() = adjointValue; |
531 | 11.4k | } |
532 | | |
533 | | /// Returns the adjoint value for a value in the original function. |
534 | | /// |
535 | | /// This method first tries to find an existing entry in the adjoint value |
536 | | /// mapping. If no entry exists, creates a zero adjoint value. |
537 | 18.5k | AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) { |
538 | 18.5k | assert(origBB->getParent() == &getOriginal()); |
539 | 0 | assert(originalValue->getType().isObject()); |
540 | 0 | assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); |
541 | 0 | assert(originalValue->getFunction() == &getOriginal()); |
542 | 0 | auto insertion = valueMap.try_emplace( |
543 | 18.5k | {origBB, originalValue}, |
544 | 18.5k | makeZeroAdjointValue(getRemappedTangentType(originalValue->getType()))); |
545 | 18.5k | auto it = insertion.first; |
546 | 18.5k | return it->getSecond(); |
547 | 18.5k | } |
548 | | |
549 | | /// Adds `newAdjointValue` to the adjoint value for `originalValue` and sets |
550 | | /// the sum as the new adjoint value. |
551 | | void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue, |
552 | 13.4k | AdjointValue newAdjointValue, SILLocation loc) { |
553 | 13.4k | assert(origBB->getParent() == &getOriginal()); |
554 | 0 | assert(originalValue->getType().isObject()); |
555 | 0 | assert(newAdjointValue.getType().isObject()); |
556 | 0 | assert(originalValue->getFunction() == &getOriginal()); |
557 | 13.4k | LLVM_DEBUG(getADDebugStream() |
558 | 13.4k | << "Adding adjoint value for " << originalValue); |
559 | | // The adjoint value must be in the tangent space. |
560 | 13.4k | assert(newAdjointValue.getType() == |
561 | 13.4k | getRemappedTangentType(originalValue->getType())); |
562 | | // Try to assign a debug variable. |
563 | 13.4k | if (auto debugInfo = findDebugLocationAndVariable(originalValue)) { |
564 | 6.15k | LLVM_DEBUG({ |
565 | 6.15k | auto &s = getADDebugStream(); |
566 | 6.15k | s << "Found debug variable: \"" << debugInfo->second.Name |
567 | 6.15k | << "\"\nLocation: "; |
568 | 6.15k | debugInfo->first.getLocation().print(s, getASTContext().SourceMgr); |
569 | 6.15k | s << '\n'; |
570 | 6.15k | }); |
571 | 6.15k | newAdjointValue.setDebugInfo(*debugInfo); |
572 | 7.26k | } else { |
573 | 7.26k | LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n"); |
574 | 7.26k | } |
575 | 13.4k | auto insertion = |
576 | 13.4k | valueMap.try_emplace({origBB, originalValue}, newAdjointValue); |
577 | 13.4k | auto inserted = insertion.second; |
578 | 13.4k | if (inserted) |
579 | 10.7k | return; |
580 | | // If adjoint already exists, accumulate the adjoint onto the existing |
581 | | // adjoint. |
582 | 2.63k | auto it = insertion.first; |
583 | 2.63k | auto existingValue = it->getSecond(); |
584 | 2.63k | valueMap.erase(it); |
585 | 2.63k | auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue, loc); |
586 | | // If the original value is the `Array` result of an |
587 | | // `array.uninitialized_intrinsic` application, accumulate adjoint buffers |
588 | | // for the array element addresses. |
589 | 2.63k | accumulateArrayLiteralElementAddressAdjoints(origBB, originalValue, adjVal, |
590 | 2.63k | loc); |
591 | 2.63k | setAdjointValue(origBB, originalValue, adjVal); |
592 | 2.63k | } |
593 | | |
594 | | /// Get the pullback block argument corresponding to the given original block |
595 | | /// and active value. |
596 | | SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, |
597 | 3.81k | SILValue activeValue) { |
598 | 3.81k | assert(getTangentValueCategory(activeValue) == SILValueCategory::Object); |
599 | 0 | assert(origBB->getParent() == &getOriginal()); |
600 | 0 | auto pullbackBBArg = |
601 | 3.81k | activeValuePullbackBBArgumentMap[{origBB, activeValue}]; |
602 | 3.81k | assert(pullbackBBArg); |
603 | 0 | assert(pullbackBBArg->getParent() == getPullbackBlock(origBB)); |
604 | 0 | return pullbackBBArg; |
605 | 3.81k | } |
606 | | |
607 | | //--------------------------------------------------------------------------// |
608 | | // Adjoint value accumulation |
609 | | //--------------------------------------------------------------------------// |
610 | | |
611 | | /// Given two adjoint values, accumulates them and returns their sum. |
612 | | AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs, |
613 | | SILLocation loc); |
614 | | |
615 | | //--------------------------------------------------------------------------// |
616 | | // Adjoint buffer mapping |
617 | | //--------------------------------------------------------------------------// |
618 | | |
619 | | /// If the given original value is an address projection, returns a |
620 | | /// corresponding adjoint projection to be used as its adjoint buffer. |
621 | | /// |
622 | | /// Helper function for `getAdjointBuffer`. |
623 | | SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue); |
624 | | |
625 | | /// Returns the adjoint buffer for the original value in the given original block. |
626 | | /// |
627 | | /// Returns the existing `bufferMap` entry for `{origBB, originalValue}` when there is one. Otherwise caches and returns either the |
628 | | /// projection produced by `getAdjointProjection`, or a new zero-initialized function-local allocation of the remapped tangent type. |
629 | 40.5k | SILValue getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) { |
630 | 40.5k | assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); |
631 | 0 | assert(originalValue->getFunction() == &getOriginal()); |
632 | 0 | auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue()); |
633 | 40.5k | if (!insertion.second) // not inserted |
634 | 23.7k | return insertion.first->getSecond(); |
635 | | |
636 | | // If the original buffer is a projection, return a corresponding projection |
637 | | // into the adjoint buffer. The result is stored through `bufferMap[...]` instead of `insertion.first`. NOTE(review): presumably because `getAdjointProjection` can add entries to `bufferMap` and invalidate the iterator - confirm. |
638 | 16.7k | if (auto adjProj = getAdjointProjection(origBB, originalValue)) |
639 | 6.65k | return (bufferMap[{origBB, originalValue}] = adjProj); |
640 | | |
641 | 10.0k | LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for " |
642 | 10.0k | << originalValue |
643 | 10.0k | << "in bb" << origBB->getDebugID() << '\n'); |
644 | | |
645 | 10.0k | auto bufType = getRemappedTangentType(originalValue->getType()); |
646 | | // Take the allocation's location from the original value's debug info when it is available, and |
647 | | // fall back to an auto-generated location otherwise. When debug info is present, its variable is renamed to |
648 | | // "derivative of '<name>'" followed by the original block's location (if any) and scope number, and its `ArgNo` is reset to 0. |
649 | 10.0k | auto debugInfo = findDebugLocationAndVariable(originalValue); |
650 | 10.0k | SILLocation loc = debugInfo ? debugInfo->first.getLocation() |
651 | 10.0k | : RegularLocation::getAutoGeneratedLocation(); |
652 | 10.0k | llvm::SmallString<32> adjName; |
653 | 10.0k | auto *newBuf = createFunctionLocalAllocation( |
654 | 10.0k | bufType, loc, /*zeroInitialize*/ true, |
655 | 10.0k | swift::transform(debugInfo, |
656 | 10.0k | [&](AdjointValue::DebugInfo di) { |
657 | 5.32k | llvm::raw_svector_ostream adjNameStream(adjName); |
658 | 5.32k | SILDebugVariable &dv = di.second; |
659 | 5.32k | dv.ArgNo = 0; |
660 | 5.32k | adjNameStream << "derivative of '" << dv.Name << "'"; |
661 | 5.32k | if (SILDebugLocation origBBLoc = origBB->front().getDebugLocation()) { |
662 | 5.32k | adjNameStream << " in scope at "; |
663 | 5.32k | origBBLoc.getLocation().print(adjNameStream, getASTContext().SourceMgr); |
664 | 5.32k | } |
665 | 5.32k | adjNameStream << " (scope #" << origBB->getDebugID() << ")"; |
666 | 5.32k | dv.Name = adjName; |
667 | 5.32k | return dv; |
668 | 5.32k | })); |
669 | 10.0k | return (insertion.first->getSecond() = newBuf); |
670 | 16.7k | } |
671 | | |
672 | | /// Records `adjointBuffer` in `bufferMap` as the adjoint buffer of the original value in `origBB`. Asserts that the |
673 | | /// original value has address tangent category and does not already have an adjoint buffer. |
674 | | void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, |
675 | 1.74k | SILValue adjointBuffer) { |
676 | 1.74k | assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); |
677 | 0 | auto insertion = |
678 | 1.74k | bufferMap.try_emplace({origBB, originalValue}, adjointBuffer); |
679 | 1.74k | assert(insertion.second && "Adjoint buffer already exists"); |
680 | 0 | (void)insertion; |
681 | 1.74k | } |
682 | | |
683 | | /// Accumulates `rhsAddress` into the adjoint buffer corresponding to the |
684 | | /// original value, using `builder.emitInPlaceAdd`. The buffer comes from `getAdjointBuffer`, so it is created on first use. Asserts that `originalValue` has address tangent category and belongs to the original function, and that `rhsAddress` is an address in the pullback function. |
685 | | void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, |
686 | 7.48k | SILValue rhsAddress, SILLocation loc) { |
687 | 7.48k | assert(getTangentValueCategory(originalValue) == |
688 | 7.48k | SILValueCategory::Address && |
689 | 7.48k | rhsAddress->getType().isAddress()); |
690 | 0 | assert(originalValue->getFunction() == &getOriginal()); |
691 | 0 | assert(rhsAddress->getFunction() == &getPullback()); |
692 | 0 | auto adjointBuffer = getAdjointBuffer(origBB, originalValue); |
693 | | |
694 | 7.48k | LLVM_DEBUG(getADDebugStream() << "Adding" |
695 | 7.48k | << rhsAddress << "to adjoint (" |
696 | 7.48k | << adjointBuffer << ") of " |
697 | 7.48k | << originalValue |
698 | 7.48k | << "in bb" << origBB->getDebugID() << '\n'); |
699 | | |
700 | 7.48k | builder.emitInPlaceAdd(loc, adjointBuffer, rhsAddress); |
701 | 7.48k | } |
702 | | |
703 | | /// Returns the next insertion point for creating a local allocation: either |
704 | | /// before the most recently created local allocation, or at the start of the pullback |
705 | | /// entry if no local allocations exist. |
706 | | /// |
707 | | /// Helper for `createFunctionLocalAllocation`. |
708 | 17.2k | SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() { |
709 | | // If there are no local allocations, insert at the pullback entry start. |
710 | 17.2k | if (functionLocalAllocations.empty()) |
711 | 7.60k | return getPullback().getEntryBlock()->begin(); |
712 | | // Otherwise, insert before the last local allocation. Inserting before |
713 | | // rather than after ensures that allocation and zero initialization |
714 | | // instructions are grouped together. |
715 | 9.61k | auto lastLocalAlloc = functionLocalAllocations.back(); |
716 | 9.61k | return lastLocalAlloc->getDefiningInstruction()->getIterator(); |
717 | 17.2k | } |
718 | | |
719 | | /// Creates an `alloc_stack` of the given type in the pullback entry block, records it in `functionLocalAllocations`, and returns it. |
720 | | /// |
721 | | /// Local allocations are created in the pullback entry (zero-initialized only when `zeroInitialize` is true) and |
722 | | /// deallocated in the pullback exit. All local allocations not in |
723 | | /// `destroyedLocalAllocations` are also destroyed in the pullback exit. |
724 | | /// `varInfo`, when provided, is forwarded to `createAllocStack` together with `loc` and `type`. |
725 | | /// Helper for `getAdjointBuffer`. |
726 | | AllocStackInst *createFunctionLocalAllocation( |
727 | | SILType type, SILLocation loc, bool zeroInitialize = false, |
728 | 12.1k | llvm::Optional<SILDebugVariable> varInfo = llvm::None) { |
729 | | // Set insertion point for local allocation builder: before the last local |
730 | | // allocation, or at the start of the pullback function's entry if no local |
731 | | // allocations exist yet. |
732 | 12.1k | localAllocBuilder.setInsertionPoint( |
733 | 12.1k | getPullback().getEntryBlock(), |
734 | 12.1k | getNextFunctionLocalAllocationInsertionPoint()); |
735 | | // Create the local allocation and remember it as the most recent one. |
736 | 12.1k | auto *alloc = localAllocBuilder.createAllocStack(loc, type, varInfo); |
737 | 12.1k | functionLocalAllocations.push_back(alloc); |
738 | | // Zero-initialize if requested. |
739 | 12.1k | if (zeroInitialize) |
740 | 10.6k | localAllocBuilder.emitZeroIntoBuffer(loc, alloc, IsInitialization); |
741 | 12.1k | return alloc; |
742 | 12.1k | } |
743 | | |
744 | | //--------------------------------------------------------------------------// |
745 | | // Optional differentiation |
746 | | //--------------------------------------------------------------------------// |
747 | | |
748 | | /// Given a `wrappedAdjoint` value of type `T.TangentVector` and `Optional<T>` |
749 | | /// type, creates an `Optional<T>.TangentVector` buffer from it. |
750 | | /// |
751 | | /// `wrappedAdjoint` may be an object or address value, both cases are |
752 | | /// handled. |
753 | | AllocStackInst *createOptionalAdjoint(SILBasicBlock *bb, |
754 | | SILValue wrappedAdjoint, |
755 | | SILType optionalTy); |
756 | | |
757 | | /// Accumulate optional buffer from `wrappedAdjoint`. |
758 | | void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb, |
759 | | SILValue optionalBuffer, |
760 | | SILValue wrappedAdjoint); |
761 | | |
762 | | /// Set optional value from `wrappedAdjoint`. |
763 | | void setAdjointValueForOptional(SILBasicBlock *bb, SILValue optionalValue, |
764 | | SILValue wrappedAdjoint); |
765 | | |
766 | | //--------------------------------------------------------------------------// |
767 | | // Array literal initialization differentiation |
768 | | //--------------------------------------------------------------------------// |
769 | | |
770 | | /// Given the adjoint value of an array initialized from an |
771 | | /// `array.uninitialized_intrinsic` application and an array element index, |
772 | | /// returns an `alloc_stack` containing the adjoint value of the array element |
773 | | /// at the given index by applying `Array.TangentVector.subscript`. |
774 | | AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint, |
775 | | int eltIndex, SILLocation loc); |
776 | | |
777 | | /// Given the adjoint value of an array initialized from an |
778 | | /// `array.uninitialized_intrinsic` application, accumulates the adjoint |
779 | | /// value's elements into the adjoint buffers of its element addresses. |
780 | | void accumulateArrayLiteralElementAddressAdjoints( |
781 | | SILBasicBlock *origBB, SILValue originalValue, |
782 | | AdjointValue arrayAdjointValue, SILLocation loc); |
783 | | |
784 | | //--------------------------------------------------------------------------// |
785 | | // CFG mapping |
786 | | //--------------------------------------------------------------------------// |
787 | | |
/// Returns the pullback block that `pullbackBBMap` associates with `originalBlock`. NOTE(review): `lookup` presumably yields null when no mapping exists - confirm against the map type.
788 | 18.1k | SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) { |
789 | 18.1k | return pullbackBBMap.lookup(originalBlock); |
790 | 18.1k | } |
791 | | |
/// Returns the block that `pullbackTrampolineBBMap` associates with the pair `{originalBlock, successorBlock}`. NOTE(review): `lookup` presumably yields null when no mapping exists - confirm against the map type.
792 | | SILBasicBlock *getPullbackTrampolineBlock(SILBasicBlock *originalBlock, |
793 | 2.35k | SILBasicBlock *successorBlock) { |
794 | 2.35k | return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock}); |
795 | 2.35k | } |
796 | | |
797 | | //--------------------------------------------------------------------------// |
798 | | // Debug info |
799 | | //--------------------------------------------------------------------------// |
800 | | |
/// Returns the scope that `scopeCloner.getOrCreateClonedScope` provides for the debug scope `DS`.
801 | 60.5k | const SILDebugScope *remapScope(const SILDebugScope *DS) { |
802 | 60.5k | return scopeCloner.getOrCreateClonedScope(DS); |
803 | 60.5k | } |
804 | | |
805 | | //--------------------------------------------------------------------------// |
806 | | // Debugging utilities |
807 | | //--------------------------------------------------------------------------// |
808 | | |
809 | 0 | void printAdjointValueMapping() { |
810 | 0 | // Debugging aid: dumps `valueMap` to the AD debug stream. First, group original/adjoint values by basic block. |
811 | 0 | llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, AdjointValue>> tmp; |
812 | 0 | for (auto pair : valueMap) { |
813 | 0 | auto origPair = pair.first; |
814 | 0 | auto *origBB = origPair.first; |
815 | 0 | auto origValue = origPair.second; |
816 | 0 | auto adjValue = pair.second; |
817 | 0 | tmp[origBB].insert({origValue, adjValue}); |
818 | 0 | } |
819 | 0 | // Print original/adjoint values per basic block, walking the original function's blocks in order and skipping blocks that have no entry in `pullbackBBMap`. |
820 | 0 | auto &s = getADDebugStream() << "Adjoint value mapping:\n"; |
821 | 0 | for (auto &origBB : getOriginal()) { |
822 | 0 | if (!pullbackBBMap.count(&origBB)) |
823 | 0 | continue; |
824 | 0 | auto bbValueMap = tmp[&origBB]; |
825 | 0 | s << "bb" << origBB.getDebugID(); |
826 | 0 | s << " (size " << bbValueMap.size() << "):\n"; |
827 | 0 | for (auto valuePair : bbValueMap) { |
828 | 0 | auto origValue = valuePair.first; |
829 | 0 | auto adjValue = valuePair.second; |
830 | 0 | s << "ORIG: " << origValue; |
831 | 0 | s << "ADJ: " << adjValue << '\n'; |
832 | 0 | } |
833 | 0 | s << '\n'; |
834 | 0 | } |
835 | 0 | } |
836 | | |
837 | 0 | void printAdjointBufferMapping() { |
838 | 0 | // Debugging aid: dumps `bufferMap` to the AD debug stream. First, group original/adjoint buffers by basic block. |
839 | 0 | llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, SILValue>> tmp; |
840 | 0 | for (auto pair : bufferMap) { |
841 | 0 | auto origPair = pair.first; |
842 | 0 | auto *origBB = origPair.first; |
843 | 0 | auto origBuf = origPair.second; |
844 | 0 | auto adjBuf = pair.second; |
845 | 0 | tmp[origBB][origBuf] = adjBuf; |
846 | 0 | } |
847 | 0 | // Print original/adjoint buffers per basic block, walking the original function's blocks in order and skipping blocks that have no entry in `pullbackBBMap`. |
848 | 0 | auto &s = getADDebugStream() << "Adjoint buffer mapping:\n"; |
849 | 0 | for (auto &origBB : getOriginal()) { |
850 | 0 | if (!pullbackBBMap.count(&origBB)) |
851 | 0 | continue; |
852 | 0 | auto bbBufferMap = tmp[&origBB]; |
853 | 0 | s << "bb" << origBB.getDebugID(); |
854 | 0 | s << " (size " << bbBufferMap.size() << "):\n"; |
855 | 0 | for (auto valuePair : bbBufferMap) { |
856 | 0 | auto origBuf = valuePair.first; |
857 | 0 | auto adjBuf = valuePair.second; |
858 | 0 | s << "ORIG: " << origBuf; |
859 | 0 | s << "ADJ: " << adjBuf << '\n'; |
860 | 0 | } |
861 | 0 | s << '\n'; |
862 | 0 | } |
863 | 0 | } |
864 | | |
865 | | public: |
866 | | //--------------------------------------------------------------------------// |
867 | | // Entry point |
868 | | //--------------------------------------------------------------------------// |
869 | | |
870 | | /// Performs pullback generation on the empty pullback function. Returns true |
871 | | /// if any error occurs. |
872 | | bool run(); |
873 | | |
874 | | /// Performs pullback generation on the empty pullback function, given that |
875 | | /// the original function is a "semantic member accessor". |
876 | | /// |
877 | | /// "Semantic member accessors" are attached to member properties that have a |
878 | | /// corresponding tangent stored property in the parent `TangentVector` type. |
879 | | /// These accessors have special-case pullback generation based on their |
880 | | /// semantic behavior. |
881 | | /// |
882 | | /// Returns true if any error occurs. |
883 | | bool runForSemanticMemberAccessor(); |
884 | | bool runForSemanticMemberGetter(); |
885 | | bool runForSemanticMemberSetter(); |
886 | | |
887 | | /// If original result is non-varied, it will always have a zero derivative. |
888 | | /// Skip full pullback generation and simply emit zero derivatives for wrt |
889 | | /// parameters. |
890 | | void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult); |
891 | | |
892 | | /// Returns the pullback function obtained from `vjpCloner.getPullback()`. Public helper so that our users can get the underlying newly created |
893 | | /// function. |
894 | 615k | SILFunction &getPullback() const { return vjpCloner.getPullback(); } |
895 | | |
896 | | using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>; |
897 | | |
898 | | /// Determines the pullback successor block for a given original block and one |
899 | | /// of its predecessors. When a trampoline block is necessary, emits code into |
900 | | /// the trampoline block to trampoline the original block's active value's |
901 | | /// adjoint values. |
902 | | /// |
903 | | /// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint |
904 | | /// values to the pullback successor blocks in which they are used. This |
905 | | /// allows us to release those values in pullback successor blocks that do not |
906 | | /// use them. |
907 | | SILBasicBlock * |
908 | | buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB, |
909 | | llvm::SmallDenseMap<SILValue, TrampolineBlockSet> |
910 | | &pullbackTrampolineBlockMap); |
911 | | |
912 | | /// Emits pullback code in the corresponding pullback block. |
913 | | void visitSILBasicBlock(SILBasicBlock *bb); |
914 | | |
/// Visits one original instruction by forwarding to `SILInstructionVisitor::visit`; does nothing once `errorOccurred` is set.
/// Under `LLVM_DEBUG` it logs the original instruction and then every pullback instruction between the position recorded before the visit and the builder's insertion point afterwards.
915 | 34.6k | void visit(SILInstruction *inst) { |
916 | 34.6k | if (errorOccurred) |
917 | 0 | return; |
918 | | |
919 | 34.6k | LLVM_DEBUG(getADDebugStream() |
920 | 34.6k | << "PullbackCloner visited:\n[ORIG]" << *inst); |
921 | 34.6k | #ifndef NDEBUG |
922 | 34.6k | auto beforeInsertion = std::prev(builder.getInsertionPoint()); |
923 | 34.6k | #endif |
924 | 34.6k | SILInstructionVisitor::visit(inst); |
925 | 34.6k | LLVM_DEBUG({ |
926 | 34.6k | auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback (pb bb" << |
927 | 34.6k | builder.getInsertionBB()->getDebugID() << "):\n"; |
928 | 34.6k | auto afterInsertion = builder.getInsertionPoint(); |
929 | 34.6k | for (auto it = ++beforeInsertion; it != afterInsertion; ++it) |
930 | 34.6k | s << *it; |
931 | 34.6k | }); |
932 | 34.6k | } |
933 | | |
934 | | /// Fallback instruction visitor for unhandled instructions. |
935 | | /// Emits the general `autodiff_expression_not_differentiable_note` non-differentiability diagnostic for `inst` and sets `errorOccurred`. |
936 | 20 | void visitSILInstruction(SILInstruction *inst) { |
937 | 20 | LLVM_DEBUG(getADDebugStream() |
938 | 20 | << "Unhandled instruction in PullbackCloner: " << *inst); |
939 | 20 | getContext().emitNondifferentiabilityError( |
940 | 20 | inst, getInvoker(), diag::autodiff_expression_not_differentiable_note); |
941 | 20 | errorOccurred = true; |
942 | 20 | } |
943 | | |
944 | | /// Handle `apply` instruction: call the callee pullback on the adjoints of the original results (the seeds) and accumulate its results into the adjoints of the differentiation parameters. |
945 | | /// Original: (y0, y1, ...) = apply @fn (x0, x1, ...) |
946 | | /// Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...) |
947 | 6.07k | void visitApplyInst(ApplyInst *ai) { |
948 | 6.07k | assert(getPullbackInfo().shouldDifferentiateApplySite(ai)); |
949 | | |
950 | | // Skip `array.uninitialized_intrinsic` applications, which have special |
951 | | // `store` and `copy_addr` support. |
952 | 6.07k | if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) |
953 | 208 | return; |
954 | 5.86k | auto loc = ai->getLoc(); |
955 | 5.86k | auto *bb = ai->getParent(); |
956 | | // Handle `array.finalize_intrinsic` applications. |
957 | | // `array.finalize_intrinsic` semantically behaves like an identity |
958 | | // function. |
959 | 5.86k | if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) { |
960 | 208 | assert(ai->getNumArguments() == 1 && |
961 | 208 | "Expected intrinsic to have one operand"); |
962 | | // Accumulate result's adjoint into argument's adjoint. |
963 | 0 | auto adjResult = getAdjointValue(bb, ai); |
964 | 208 | auto origArg = ai->getArgumentsWithoutIndirectResults().front(); |
965 | 208 | addAdjointValue(bb, origArg, adjResult, loc); |
966 | 208 | return; |
967 | 208 | } |
968 | | // Replace a call to a function with a call to its pullback. |
969 | 5.65k | auto &nestedApplyInfo = getContext().getNestedApplyInfo(); |
970 | 5.65k | auto applyInfoLookup = nestedApplyInfo.find(ai); |
971 | | // If no `NestedApplyInfo` was found, then this `apply` doesn't need to be |
972 | | // differentiated. |
973 | 5.65k | if (applyInfoLookup == nestedApplyInfo.end()) { |
974 | | // Must not be active. |
975 | 0 | assert(!getActivityInfo().isActive(ai, getConfig())); |
976 | 0 | return; |
977 | 0 | } |
978 | 5.65k | auto applyInfo = applyInfoLookup->getSecond(); |
979 | | |
980 | | // Get the original results of the `apply` instruction, in type-defined order. |
981 | 5.65k | SmallVector<SILValue, 8> origDirectResults; |
982 | 5.65k | forEachApplyDirectResult(ai, [&](SILValue directResult) { |
983 | 3.36k | origDirectResults.push_back(directResult); |
984 | 3.36k | }); |
985 | 5.65k | SmallVector<SILValue, 8> origAllResults; |
986 | 5.65k | collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); |
987 | | // Append semantic result arguments after original results. |
988 | 8.75k | for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) { |
989 | 8.75k | auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( |
990 | 8.75k | ai->getNumIndirectResults() + paramIdx); |
991 | 8.75k | if (!paramInfo.isAutoDiffSemanticResult()) |
992 | 8.24k | continue; |
993 | 512 | origAllResults.push_back( |
994 | 512 | ai->getArgumentsWithoutIndirectResults()[paramIdx]); |
995 | 512 | } |
996 | | |
997 | | // Get callee pullback arguments. |
998 | 5.65k | SmallVector<SILValue, 8> args; |
999 | | |
1000 | | // Handle callee pullback indirect results. |
1001 | | // Create local allocations for these and destroy them after the call. |
1002 | 5.65k | auto pullback = getPullbackTupleElement(ai); |
1003 | 5.65k | auto pullbackType = |
1004 | 5.65k | remapType(pullback->getType()).castTo<SILFunctionType>(); |
1005 | | |
1006 | 5.65k | auto actualPullbackType = applyInfo.originalPullbackType |
1007 | 5.65k | ? *applyInfo.originalPullbackType |
1008 | 5.65k | : pullbackType; |
1009 | 5.65k | actualPullbackType = actualPullbackType->getUnsubstitutedType(getModule()); |
1010 | 5.65k | SmallVector<AllocStackInst *, 4> pullbackIndirectResults; |
1011 | 5.65k | for (auto indRes : actualPullbackType->getIndirectFormalResults()) { |
1012 | 2.72k | auto *alloc = builder.createAllocStack( |
1013 | 2.72k | loc, remapType(indRes.getSILStorageInterfaceType())); |
1014 | 2.72k | pullbackIndirectResults.push_back(alloc); |
1015 | 2.72k | args.push_back(alloc); |
1016 | 2.72k | } |
1017 | | |
1018 | | // Collect callee pullback formal arguments: one seed per differentiability result index. |
1019 | 5.75k | for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) { |
1020 | 5.75k | assert(resultIndex < origAllResults.size()); |
1021 | 0 | auto origResult = origAllResults[resultIndex]; |
1022 | | // Get the seed (i.e. adjoint value of the original result). |
1023 | 5.75k | SILValue seed; |
1024 | 5.75k | switch (getTangentValueCategory(origResult)) { |
1025 | 3.26k | case SILValueCategory::Object: |
1026 | 3.26k | seed = materializeAdjointDirect(getAdjointValue(bb, origResult), loc); |
1027 | 3.26k | break; |
1028 | 2.48k | case SILValueCategory::Address: |
1029 | 2.48k | seed = getAdjointBuffer(bb, origResult); |
1030 | 2.48k | break; |
1031 | 5.75k | } |
1032 | 5.75k | args.push_back(seed); |
1033 | 5.75k | } |
1034 | | |
1035 | | // If callee pullback was reabstracted in VJP, reabstract callee pullback. |
1036 | 5.65k | if (applyInfo.originalPullbackType) { |
1037 | 1.46k | SILOptFunctionBuilder fb(getContext().getTransform()); |
1038 | 1.46k | pullback = reabstractFunction( |
1039 | 1.46k | builder, fb, loc, pullback, *applyInfo.originalPullbackType, |
1040 | 1.46k | [this](SubstitutionMap subs) -> SubstitutionMap { |
1041 | 1.46k | return this->remapSubstitutionMap(subs); |
1042 | 1.46k | }); |
1043 | 1.46k | } |
1044 | | |
1045 | | // Call the callee pullback, then destroy the pullback value itself. |
1046 | 5.65k | auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), |
1047 | 5.65k | args); |
1048 | 5.65k | builder.emitDestroyValueOperation(loc, pullback); |
1049 | | |
1050 | | // Extract all results from `pullbackCall`. |
1051 | 5.65k | SmallVector<SILValue, 8> dirResults; |
1052 | 5.65k | extractAllElements(pullbackCall, builder, dirResults); |
1053 | | // Get all results in type-defined order. |
1054 | 5.65k | SmallVector<SILValue, 8> allResults; |
1055 | 5.65k | collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults); |
1056 | | |
1057 | 5.65k | LLVM_DEBUG({ |
1058 | 5.65k | auto &s = getADDebugStream(); |
1059 | 5.65k | s << "All results of the nested pullback call:\n"; |
1060 | 5.65k | llvm::for_each(allResults, [&](SILValue v) { s << v; }); |
1061 | 5.65k | }); |
1062 | | |
1063 | | // Accumulate adjoints for original differentiation parameters. |
1064 | 5.65k | auto allResultsIt = allResults.begin(); |
1065 | 8.75k | for (unsigned i : applyInfo.config.parameterIndices->getIndices()) { |
1066 | 8.75k | auto origArg = ai->getArgument(ai->getNumIndirectResults() + i); |
1067 | | // Skip adjoint accumulation for semantic result arguments: they were appended to the original results above, and no pullback result is consumed for them here. |
1068 | 8.75k | auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( |
1069 | 8.75k | ai->getNumIndirectResults() + i); |
1070 | 8.75k | if (paramInfo.isAutoDiffSemanticResult()) |
1071 | 512 | continue; |
1072 | 8.24k | auto tan = *allResultsIt++; |
1073 | 8.24k | if (tan->getType().isAddress()) { |
1074 | 2.71k | addToAdjointBuffer(bb, origArg, tan, loc); |
1075 | 5.52k | } else { |
1076 | 5.52k | if (origArg->getType().isAddress()) { |
1077 | 0 | auto *tmpBuf = builder.createAllocStack(loc, tan->getType()); |
1078 | 0 | builder.emitStoreValueOperation(loc, tan, tmpBuf, |
1079 | 0 | StoreOwnershipQualifier::Init); |
1080 | 0 | addToAdjointBuffer(bb, origArg, tmpBuf, loc); |
1081 | 0 | builder.emitDestroyAddrAndFold(loc, tmpBuf); |
1082 | 0 | builder.createDeallocStack(loc, tmpBuf); |
1083 | 5.52k | } else { |
1084 | 5.52k | recordTemporary(tan); |
1085 | 5.52k | addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc); |
1086 | 5.52k | } |
1087 | 5.52k | } |
1088 | 8.24k | } |
1089 | | // Destroy unused pullback direct results. Needed for pullback results from |
1090 | | // VJPs extracted from `@differentiable` function callees, where the |
1091 | | // `@differentiable` function's differentiation parameter indices are a |
1092 | | // superset of the active `apply` parameter indices. |
1093 | 5.67k | while (allResultsIt != allResults.end()) { |
1094 | 16 | auto unusedPullbackDirectResult = *allResultsIt++; |
1095 | 16 | if (unusedPullbackDirectResult->getType().isAddress()) |
1096 | 4 | continue; |
1097 | 12 | builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult); |
1098 | 12 | } |
1099 | | // Destroy and deallocate pullback indirect results, in reverse order of allocation. |
1100 | 5.65k | for (auto *alloc : llvm::reverse(pullbackIndirectResults)) { |
1101 | 2.72k | builder.emitDestroyAddrAndFold(loc, alloc); |
1102 | 2.72k | builder.createDeallocStack(loc, alloc); |
1103 | 2.72k | } |
1104 | 5.65k | } |
1105 | | |
1106 | 32 | void visitBeginApplyInst(BeginApplyInst *bai) { |
1107 | | // Diagnose `begin_apply` instructions with `autodiff_coroutines_not_supported` and mark the run as failed. |
1108 | | // Coroutine differentiation is not yet supported. |
1109 | 32 | getContext().emitNondifferentiabilityError( |
1110 | 32 | bai, getInvoker(), diag::autodiff_coroutines_not_supported); |
1111 | 32 | errorOccurred = true; |
1112 | 32 | return; |
1113 | 32 | } |
1114 | | |
1115 | | /// Handle `struct` instruction. Sets `errorOccurred` and returns early if a field has no tangent stored property. |
1116 | | /// Original: y = struct (x0, x1, x2, ...) |
1117 | | /// Adjoint: adj[x0] += struct_extract adj[y], #x0 |
1118 | | /// adj[x1] += struct_extract adj[y], #x1 |
1119 | | /// adj[x2] += struct_extract adj[y], #x2 |
1120 | | /// ... |
1121 | 60 | void visitStructInst(StructInst *si) { |
1122 | 60 | auto *bb = si->getParent(); |
1123 | 60 | auto loc = si->getLoc(); |
1124 | 60 | auto *structDecl = si->getStructDecl(); |
1125 | 60 | switch (getTangentValueCategory(si)) { |
1126 | 60 | case SILValueCategory::Object: { |
1127 | 60 | auto av = getAdjointValue(bb, si); |
1128 | 60 | switch (av.getKind()) { |
1129 | 0 | case AdjointValueKind::Zero: { |
1130 | 0 | for (auto *field : structDecl->getStoredProperties()) { |
1131 | 0 | auto fv = si->getFieldValue(field); |
1132 | 0 | addAdjointValue( |
1133 | 0 | bb, fv, |
1134 | 0 | makeZeroAdjointValue(getRemappedTangentType(fv->getType())), loc); |
1135 | 0 | } |
1136 | 0 | break; |
1137 | 0 | } |
1138 | 60 | case AdjointValueKind::Concrete: { |
1139 | 60 | auto adjStruct = materializeAdjointDirect(std::move(av), loc); |
1140 | 60 | auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct); |
1141 | | |
1142 | | // Find the struct `TangentVector` type (only needed for the assertions below). |
1143 | 60 | auto structTy = remapType(si->getType()).getASTType(); |
1144 | 60 | #ifndef NDEBUG |
1145 | 60 | auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); |
1146 | 60 | assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); |
1147 | 0 | assert(tangentVectorTy->getStructOrBoundGenericStruct()); |
1148 | 0 | #endif |
1149 | | |
1150 | | // Accumulate adjoints for the fields of the `struct` operand, skipping `@noDerivative` fields. |
1151 | 0 | unsigned fieldIndex = 0; |
1152 | 60 | for (auto it = structDecl->getStoredProperties().begin(); |
1153 | 164 | it != structDecl->getStoredProperties().end(); |
1154 | 104 | ++it, ++fieldIndex) { |
1155 | 104 | VarDecl *field = *it; |
1156 | 104 | if (field->getAttrs().hasAttribute<NoDerivativeAttr>()) |
1157 | 0 | continue; |
1158 | | // Find the corresponding field in the tangent space. NOTE(review): `tanField` is only null-checked here; the element is taken from `dti` by `fieldIndex`, which also advances past `@noDerivative` fields - confirm that this index matches the tangent struct's field order. |
1159 | 104 | auto *tanField = getTangentStoredProperty( |
1160 | 104 | getContext(), field, structTy, loc, getInvoker()); |
1161 | 104 | if (!tanField) { |
1162 | 0 | errorOccurred = true; |
1163 | 0 | return; |
1164 | 0 | } |
1165 | 104 | auto tanElt = dti->getResult(fieldIndex); |
1166 | 104 | addAdjointValue(bb, si->getFieldValue(field), |
1167 | 104 | makeConcreteAdjointValue(tanElt), si->getLoc()); |
1168 | 104 | } |
1169 | 60 | break; |
1170 | 60 | } |
1171 | 60 | case AdjointValueKind::Aggregate: { |
1172 | | // Note: All user-called initializations go through the calls to the |
1173 | | // initializer, and synthesized initializers only have one level of |
1174 | | // struct formation which will not result in any aggregate adjoint |
1175 | | // values. |
1176 | 0 | llvm_unreachable( |
1177 | 0 | "Aggregate adjoint values should not occur for `struct` " |
1178 | 0 | "instructions"); |
1179 | 0 | } |
1180 | 0 | case AdjointValueKind::AddElement: { |
1181 | 0 | llvm_unreachable( |
1182 | 0 | "Adjoint of `StructInst` cannot be of kind `AddElement`"); |
1183 | 0 | } |
1184 | 60 | } |
1185 | 60 | break; |
1186 | 60 | } |
1187 | 60 | case SILValueCategory::Address: { |
1188 | 0 | auto adjBuf = getAdjointBuffer(bb, si); |
1189 | | // Find the struct type whose tangent stored properties are looked up below. |
1190 | 0 | auto structTy = remapType(si->getType()).getASTType(); |
1191 | | // Accumulate adjoints for the fields of the `struct` operand, skipping `@noDerivative` fields. |
1192 | 0 | unsigned fieldIndex = 0; |
1193 | 0 | for (auto it = structDecl->getStoredProperties().begin(); |
1194 | 0 | it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) { |
1195 | 0 | VarDecl *field = *it; |
1196 | 0 | if (field->getAttrs().hasAttribute<NoDerivativeAttr>()) |
1197 | 0 | continue; |
1198 | | // Find the corresponding field in the tangent space; it selects the element address projected from the adjoint buffer. |
1199 | 0 | auto *tanField = getTangentStoredProperty(getContext(), field, structTy, |
1200 | 0 | loc, getInvoker()); |
1201 | 0 | if (!tanField) { |
1202 | 0 | errorOccurred = true; |
1203 | 0 | return; |
1204 | 0 | } |
1205 | 0 | auto *adjFieldBuf = |
1206 | 0 | builder.createStructElementAddr(loc, adjBuf, tanField); |
1207 | 0 | auto fieldValue = si->getFieldValue(field); |
1208 | 0 | switch (getTangentValueCategory(fieldValue)) { |
1209 | 0 | case SILValueCategory::Object: { |
1210 | 0 | auto adjField = builder.emitLoadValueOperation( |
1211 | 0 | loc, adjFieldBuf, LoadOwnershipQualifier::Copy); |
1212 | 0 | recordTemporary(adjField); |
1213 | 0 | addAdjointValue(bb, fieldValue, makeConcreteAdjointValue(adjField), |
1214 | 0 | loc); |
1215 | 0 | break; |
1216 | 0 | } |
1217 | 0 | case SILValueCategory::Address: { |
1218 | 0 | addToAdjointBuffer(bb, fieldValue, adjFieldBuf, loc); |
1219 | 0 | break; |
1220 | 0 | } |
1221 | 0 | } |
1222 | 0 | } |
1223 | 0 | } break; |
1224 | 60 | } |
1225 | 60 | } |
1226 | | |
1227 | | /// Handle `struct_extract` instruction. |
1228 | | /// Original: y = struct_extract x, #field |
1229 | | /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) |
1230 | | /// ^~~~~~~ |
1231 | | /// field in tangent space corresponding to #field |
1232 | 508 | void visitStructExtractInst(StructExtractInst *sei) { |
1233 | 508 | auto *bb = sei->getParent(); |
1234 | 508 | auto loc = getValidLocation(sei); |
1235 | | // Find the corresponding field in the tangent space. |
1236 | 508 | auto structTy = remapType(sei->getOperand()->getType()).getASTType(); |
1237 | 508 | auto *tanField = |
1238 | 508 | getTangentStoredProperty(getContext(), sei, structTy, getInvoker()); |
1239 | 508 | assert(tanField && "Invalid projections should have been diagnosed"); |
1240 | | // Check the `struct_extract` operand's value tangent category. Object operands accumulate an `AddElement` adjoint over a zero base (or a plain zero adjoint when the element adjoint is zero); address operands add in place into the tangent field of the operand's adjoint buffer. |
1241 | 0 | switch (getTangentValueCategory(sei->getOperand())) { |
1242 | 508 | case SILValueCategory::Object: { |
1243 | 508 | auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); |
1244 | 508 | auto tangentVectorSILTy = |
1245 | 508 | SILType::getPrimitiveObjectType(tangentVectorTy); |
1246 | 508 | auto eltAdj = getAdjointValue(bb, sei); |
1247 | | |
1248 | 508 | switch (eltAdj.getKind()) { |
1249 | 0 | case AdjointValueKind::Zero: { |
1250 | 0 | addAdjointValue(bb, sei->getOperand(), |
1251 | 0 | makeZeroAdjointValue(tangentVectorSILTy), loc); |
1252 | 0 | break; |
1253 | 0 | } |
1254 | 0 | case AdjointValueKind::Aggregate: |
1255 | 504 | case AdjointValueKind::Concrete: |
1256 | 508 | case AdjointValueKind::AddElement: { |
1257 | 508 | auto baseAdj = makeZeroAdjointValue(tangentVectorSILTy); |
1258 | 508 | addAdjointValue(bb, sei->getOperand(), |
1259 | 508 | makeAddElementAdjointValue(baseAdj, eltAdj, tanField), |
1260 | 508 | loc); |
1261 | 508 | break; |
1262 | 504 | } |
1263 | 508 | } |
1264 | 508 | break; |
1265 | 508 | } |
1266 | 508 | case SILValueCategory::Address: { |
1267 | 0 | auto adjBase = getAdjointBuffer(bb, sei->getOperand()); |
1268 | 0 | auto *adjBaseElt = |
1269 | 0 | builder.createStructElementAddr(loc, adjBase, tanField); |
1270 | | // Check the `struct_extract`'s value tangent category. An object adjoint is copied into a temporary stack buffer so that it can be added in place; an address adjoint is added directly. |
1271 | 0 | switch (getTangentValueCategory(sei)) { |
1272 | 0 | case SILValueCategory::Object: { |
1273 | 0 | auto adjElt = getAdjointValue(bb, sei); |
1274 | 0 | auto concreteAdjElt = materializeAdjointDirect(adjElt, loc); |
1275 | 0 | auto concreteAdjEltCopy = |
1276 | 0 | builder.emitCopyValueOperation(loc, concreteAdjElt); |
1277 | 0 | auto *alloc = builder.createAllocStack(loc, adjElt.getType()); |
1278 | 0 | builder.emitStoreValueOperation(loc, concreteAdjEltCopy, alloc, |
1279 | 0 | StoreOwnershipQualifier::Init); |
1280 | 0 | builder.emitInPlaceAdd(loc, adjBaseElt, alloc); |
1281 | 0 | builder.createDestroyAddr(loc, alloc); |
1282 | 0 | builder.createDeallocStack(loc, alloc); |
1283 | 0 | break; |
1284 | 0 | } |
1285 | 0 | case SILValueCategory::Address: { |
1286 | 0 | auto adjElt = getAdjointBuffer(bb, sei); |
1287 | 0 | builder.emitInPlaceAdd(loc, adjBaseElt, adjElt); |
1288 | 0 | break; |
1289 | 0 | } |
1290 | 0 | } |
1291 | 0 | break; |
1292 | 0 | } |
1293 | 508 | } |
1294 | 508 | } |
1295 | | |
1296 | | /// Handle `ref_element_addr` instruction.
1297 | | /// Original: y = ref_element_addr x, <n>
1298 | | /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
1299 | | /// ^~~~~~~
1300 | | /// field in tangent space corresponding to #field. adj[y] is always read from an adjoint buffer; adj[x] is updated as a value or as a buffer depending on the class operand's tangent category.
1301 | 128 | void visitRefElementAddrInst(RefElementAddrInst *reai) {
1302 | 128 | auto *bb = reai->getParent();
1303 | 128 | auto loc = reai->getLoc();
1304 | 128 | auto adjBuf = getAdjointBuffer(bb, reai); // adj[y]
1305 | 128 | auto classOperand = reai->getOperand();
1306 | 128 | auto classType = remapType(reai->getOperand()->getType()).getASTType();
1307 | 128 | auto *tanField =
1308 | 128 | getTangentStoredProperty(getContext(), reai, classType, getInvoker());
1309 | 128 | assert(tanField && "Invalid projections should have been diagnosed");
1310 | 0 | switch (getTangentValueCategory(classOperand)) {
1311 | 36 | case SILValueCategory::Object: {
1312 | 36 | auto classTy = remapType(classOperand->getType()).getASTType(); // same expression as `classType` above
1313 | 36 | auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType();
1314 | 36 | auto tangentVectorSILTy =
1315 | 36 | SILType::getPrimitiveObjectType(tangentVectorTy);
1316 | 36 | auto *tangentVectorDecl =
1317 | 36 | tangentVectorTy->getStructOrBoundGenericStruct(); // NOTE(review): dereferenced below without a null check — presumably the tangent type is always a struct here; confirm
1318 | | // Accumulate adjoint for the `ref_element_addr` operand.
1319 | 36 | SmallVector<AdjointValue, 8> eltVals; // one entry per stored property of the tangent struct, in declaration order
1320 | 36 | for (auto *field : tangentVectorDecl->getStoredProperties()) {
1321 | 36 | if (field == tanField) { // the projected field: load a copy of adj[y]
1322 | 36 | auto adjElt = builder.emitLoadValueOperation(
1323 | 36 | reai->getLoc(), adjBuf, LoadOwnershipQualifier::Copy);
1324 | 36 | eltVals.push_back(makeConcreteAdjointValue(adjElt));
1325 | 36 | recordTemporary(adjElt);
1326 | 36 | } else { // every other field contributes a zero of its lowered type
1327 | 0 | auto substMap = tangentVectorTy->getMemberSubstitutionMap(
1328 | 0 | field->getModuleContext(), field);
1329 | 0 | auto fieldTy = field->getInterfaceType().subst(substMap);
1330 | 0 | auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
1331 | 0 | assert(fieldSILTy.isObject());
1332 | 0 | eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
1333 | 0 | }
1334 | 36 | }
1335 | 36 | addAdjointValue(bb, classOperand,
1336 | 36 | makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
1337 | 36 | loc);
1338 | 36 | break;
1339 | 0 | }
1340 | 92 | case SILValueCategory::Address: {
1341 | 92 | auto adjBufClass = getAdjointBuffer(bb, classOperand);
1342 | 92 | auto adjBufElt =
1343 | 92 | builder.createStructElementAddr(loc, adjBufClass, tanField); // address of the tangent field inside adj[x]
1344 | 92 | builder.emitInPlaceAdd(loc, adjBufElt, adjBuf);
1345 | 92 | break;
1346 | 0 | }
1347 | 128 | }
1348 | 128 | }
1349 | | |
1350 | | /// Handle `tuple` instruction.
1351 | | /// Original: y = tuple (x0, x1, x2, ...)
1352 | | /// Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y]
1353 | | /// ^~~
1354 | | /// excluding non-differentiable elements (those for which `getTangentSpace` returns null)
1355 | 68 | void visitTupleInst(TupleInst *ti) {
1356 | 68 | auto *bb = ti->getParent();
1357 | 68 | auto loc = ti->getLoc();
1358 | 68 | switch (getTangentValueCategory(ti)) {
1359 | 68 | case SILValueCategory::Object: {
1360 | 68 | auto av = getAdjointValue(bb, ti); // adj[y]
1361 | 68 | switch (av.getKind()) {
1362 | 0 | case AdjointValueKind::Zero: // adj[y] is zero: every differentiable operand receives a zero adjoint
1363 | 0 | for (auto elt : ti->getElements()) {
1364 | 0 | if (!getTangentSpace(elt->getType().getASTType())) // NOTE(review): type is not passed through remapType here, unlike visitDestructureTupleInst — confirm intended
1365 | 0 | continue;
1366 | 0 | addAdjointValue(
1367 | 0 | bb, elt,
1368 | 0 | makeZeroAdjointValue(getRemappedTangentType(elt->getType())),
1369 | 0 | loc);
1370 | 0 | }
1371 | 0 | break;
1372 | 0 | case AdjointValueKind::Concrete: {
1373 | 0 | auto adjVal = av.getConcreteValue();
1374 | 0 | auto adjValCopy = builder.emitCopyValueOperation(loc, adjVal);
1375 | 0 | SmallVector<SILValue, 4> adjElts; // adjoint pieces, indexed by position among differentiable operands
1376 | 0 | if (!adjVal->getType().getAs<TupleType>()) { // adjoint is not a tuple: the copy itself is the only piece
1377 | 0 | recordTemporary(adjValCopy);
1378 | 0 | adjElts.push_back(adjValCopy);
1379 | 0 | } else { // adjoint is a tuple: split it into its elements
1380 | 0 | auto *dti = builder.createDestructureTuple(loc, adjValCopy);
1381 | 0 | for (auto adjElt : dti->getResults())
1382 | 0 | recordTemporary(adjElt);
1383 | 0 | adjElts.append(dti->getResults().begin(), dti->getResults().end());
1384 | 0 | }
1385 | | // Accumulate adjoints for `tuple` operands, skipping the
1386 | | // non-`Differentiable` ones.
1387 | 0 | unsigned adjIndex = 0; // advances only for differentiable operands
1388 | 0 | for (auto i : range(ti->getNumOperands())) {
1389 | 0 | if (!getTangentSpace(ti->getOperand(i)->getType().getASTType()))
1390 | 0 | continue;
1391 | 0 | auto adjElt = adjElts[adjIndex++];
1392 | 0 | addAdjointValue(bb, ti->getOperand(i),
1393 | 0 | makeConcreteAdjointValue(adjElt), loc);
1394 | 0 | }
1395 | 0 | break;
1396 | 0 | }
1397 | 68 | case AdjointValueKind::Aggregate: {
1398 | 68 | unsigned adjIndex = 0; // index into the aggregate's elements; advances only for differentiable elements
1399 | 136 | for (auto i : range(ti->getElements().size())) {
1400 | 136 | if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
1401 | 0 | continue;
1402 | 136 | addAdjointValue(bb, ti->getElement(i),
1403 | 136 | av.getAggregateElement(adjIndex++), loc);
1404 | 136 | }
1405 | 68 | break;
1406 | 0 | }
1407 | 0 | case AdjointValueKind::AddElement: {
1408 | 0 | llvm_unreachable(
1409 | 0 | "Adjoint of `TupleInst` cannot be of kind `AddElement`");
1410 | 0 | }
1411 | 68 | }
1412 | 68 | break;
1413 | 68 | }
1414 | 68 | case SILValueCategory::Address: {
1415 | 0 | auto adjBuf = getAdjointBuffer(bb, ti); // adj[y] as a buffer
1416 | | // Accumulate adjoints for `tuple` operands, skipping the
1417 | | // non-`Differentiable` ones.
1418 | 0 | unsigned adjIndex = 0; // tuple element index within adj[y]; advances only for differentiable operands
1419 | 0 | for (auto i : range(ti->getNumOperands())) {
1420 | 0 | if (!getTangentSpace(ti->getOperand(i)->getType().getASTType()))
1421 | 0 | continue;
1422 | 0 | auto adjBufElt =
1423 | 0 | builder.createTupleElementAddr(loc, adjBuf, adjIndex++);
1424 | 0 | auto adjElt = getAdjointBuffer(bb, ti->getOperand(i));
1425 | 0 | builder.emitInPlaceAdd(loc, adjElt, adjBufElt); // presumably adj[xi] += adj[y].i (destination first) — TODO confirm argument order
1426 | 0 | }
1427 | 0 | break;
1428 | 68 | }
1429 | 68 | }
1430 | 68 | }
1431 | | |
1432 | | /// Handle `tuple_extract` instruction.
1433 | | /// Original: y = tuple_extract x, <n>
1434 | | /// Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0)
1435 | | /// ^~~~~~
1436 | | /// n'-th element, where n' is tuple tangent space
1437 | | /// index corresponding to n. Only adjoint values are handled here; no adjoint-buffer (address category) path exists in this visitor.
1438 | 16 | void visitTupleExtractInst(TupleExtractInst *tei) {
1439 | 16 | auto *bb = tei->getParent();
1440 | 16 | auto loc = tei->getLoc();
1441 | 16 | auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
1442 | 16 | auto eltAdj = getAdjointValue(bb, tei); // adj[y]
1443 | 16 | switch (eltAdj.getKind()) {
1444 | 0 | case AdjointValueKind::Zero: { // adj[y] is zero: accumulate a zero of the tuple's tangent type
1445 | 0 | addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
1446 | 0 | loc);
1447 | 0 | break;
1448 | 0 | }
1449 | 0 | case AdjointValueKind::Aggregate:
1450 | 12 | case AdjointValueKind::Concrete:
1451 | 16 | case AdjointValueKind::AddElement: {
1452 | 16 | auto tupleTy = tei->getTupleType();
1453 | 16 | auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
1454 | 16 | if (!tupleTanTupleTy) { // tangent type is not a tuple: adj[y] is forwarded unchanged as adj[x]
1455 | 0 | addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
1456 | 0 | break;
1457 | 0 | }
1458 | |
1459 | 16 | unsigned elements = 0; // number of original tuple elements that have a tangent space
1460 | 32 | for (unsigned i : range(tupleTy->getNumElements())) {
1461 | 32 | if (!getTangentSpace(
1462 | 32 | tupleTy->getElement(i).getType()->getCanonicalType()))
1463 | 0 | continue;
1464 | 32 | elements++;
1465 | 32 | }
1466 | |
1467 | 16 | if (elements == 1) { // a single differentiable element: adj[y] is forwarded unchanged
1468 | 0 | addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
1469 | 16 | } else {
1470 | 16 | auto baseAdj = makeZeroAdjointValue(tupleTanTy);
1471 | 16 | addAdjointValue(
1472 | 16 | bb, tei->getOperand(),
1473 | 16 | makeAddElementAdjointValue(baseAdj, eltAdj, tei->getFieldIndex()), // NOTE(review): this is the original tuple index n, whereas the header describes the tangent-space index n'; the two could differ when a non-differentiable element precedes n — TODO confirm
1474 | 16 | loc);
1475 | 16 | }
1476 | 16 | break;
1477 | 16 | }
1478 | 16 | }
1479 | 16 | }
1480 | | |
1481 | | /// Handle `destructure_tuple` instruction.
1482 | | /// Original: (y0, ..., yn) = destructure_tuple x
1483 | | /// Adjoint: adj[x].0 += adj[y0]
1484 | | /// ...
1485 | | /// adj[x].n += adj[yn]
1486 | 308 | void visitDestructureTupleInst(DestructureTupleInst *dti) {
1487 | 308 | auto *bb = dti->getParent();
1488 | 308 | auto loc = dti->getLoc();
1489 | 308 | auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType());
1490 | | // Check the `destructure_tuple` operand's value tangent category.
1491 | 308 | switch (getTangentValueCategory(dti->getOperand())) {
1492 | 308 | case SILValueCategory::Object: {
1493 | 308 | SmallVector<AdjointValue, 8> adjValues; // adjoints of the differentiable results, in result order
1494 | 616 | for (auto origElt : dti->getResults()) {
1495 | | // Skip non-`Differentiable` tuple elements.
1496 | 616 | if (!getTangentSpace(remapType(origElt->getType()).getASTType()))
1497 | 208 | continue;
1498 | 408 | adjValues.push_back(getAdjointValue(bb, origElt));
1499 | 408 | }
1500 | | // Handle tuple tangent type.
1501 | | // Add adjoints for every tuple element that has a tangent space.
1502 | 308 | if (tupleTanTy.is<TupleType>()) {
1503 | 100 | assert(adjValues.size() > 1);
1504 | 0 | addAdjointValue(bb, dti->getOperand(),
1505 | 100 | makeAggregateAdjointValue(tupleTanTy, adjValues), loc);
1506 | 100 | }
1507 | | // Handle non-tuple tangent type.
1508 | | // Add adjoint for the single tuple element that has a tangent space.
1509 | 208 | else {
1510 | 208 | assert(adjValues.size() == 1); // checked only by this assert; `front()` below relies on it
1511 | 0 | addAdjointValue(bb, dti->getOperand(), adjValues.front(), loc);
1512 | 208 | }
1513 | 0 | break;
1514 | 0 | }
1515 | 0 | case SILValueCategory::Address: {
1516 | 0 | auto adjBuf = getAdjointBuffer(bb, dti->getOperand()); // adj[x] as a buffer
1517 | 0 | unsigned adjIndex = 0; // element index within adj[x]; incremented only for differentiable results
1518 | 0 | for (auto origElt : dti->getResults()) {
1519 | | // Skip non-`Differentiable` tuple elements.
1520 | 0 | if (!getTangentSpace(remapType(origElt->getType()).getASTType()))
1521 | 0 | continue;
1522 | | // Handle tuple tangent type.
1523 | | // Add adjoints for every tuple element that has a tangent space.
1524 | 0 | if (tupleTanTy.is<TupleType>()) {
1525 | 0 | auto adjEltBuf = getAdjointBuffer(bb, origElt);
1526 | 0 | auto adjBufElt =
1527 | 0 | builder.createTupleElementAddr(loc, adjBuf, adjIndex);
1528 | 0 | builder.emitInPlaceAdd(loc, adjBufElt, adjEltBuf);
1529 | 0 | }
1530 | | // Handle non-tuple tangent type.
1531 | | // Add adjoint for the single tuple element that has a tangent space.
1532 | 0 | else {
1533 | 0 | auto adjEltBuf = getAdjointBuffer(bb, origElt);
1534 | 0 | addToAdjointBuffer(bb, dti->getOperand(), adjEltBuf, loc);
1535 | 0 | }
1536 | 0 | ++adjIndex;
1537 | 0 | }
1538 | 0 | break;
1539 | 0 | }
1540 | 308 | }
1541 | 308 | }
1542 | | |
1543 | | /// Handle `load` or `load_borrow` instruction
1544 | | /// Original: y = load/load_borrow x
1545 | | /// Adjoint: adj[x] += adj[y]. adj[x] is always an adjoint buffer; adj[y] is a value or a buffer depending on the result's tangent category.
1546 | 2.35k | void visitLoadOperation(SingleValueInstruction *inst) {
1547 | 2.35k | assert(isa<LoadInst>(inst) || isa<LoadBorrowInst>(inst));
1548 | 0 | auto *bb = inst->getParent();
1549 | 2.35k | auto loc = inst->getLoc();
1550 | 2.35k | switch (getTangentValueCategory(inst)) {
1551 | 2.30k | case SILValueCategory::Object: {
1552 | 2.30k | auto adjVal = materializeAdjointDirect(getAdjointValue(bb, inst), loc);
1553 | | // Allocate a local buffer and store the adjoint value. This buffer will
1554 | | // be used for accumulation into the adjoint buffer.
1555 | 2.30k | auto adjBuf = builder.createAllocStack(
1556 | 2.30k | loc, adjVal->getType(), SILDebugVariable());
1557 | 2.30k | auto copy = builder.emitCopyValueOperation(loc, adjVal); // the store below takes this copy; adjVal itself is left untouched
1558 | 2.30k | builder.emitStoreValueOperation(loc, copy, adjBuf,
1559 | 2.30k | StoreOwnershipQualifier::Init);
1560 | | // Accumulate the adjoint value in the local buffer into the adjoint
1561 | | // buffer.
1562 | 2.30k | addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc);
1563 | 2.30k | builder.emitDestroyAddr(loc, adjBuf); // clean up the temporary: destroy its contents, then free the stack slot
1564 | 2.30k | builder.createDeallocStack(loc, adjBuf);
1565 | 2.30k | break;
1566 | 0 | }
1567 | 52 | case SILValueCategory::Address: { // adj[y] is already a buffer: accumulate it directly
1568 | 52 | auto adjBuf = getAdjointBuffer(bb, inst);
1569 | 52 | addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc);
1570 | 52 | break;
1571 | 0 | }
1572 | 2.35k | }
1573 | 2.35k | }
1574 | 2.29k | void visitLoadInst(LoadInst *li) { visitLoadOperation(li); } // `load` shares the common load handler
1575 | 64 | void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); } // `load_borrow` shares the common load handler
1576 | | |
1577 | | /// Handle `store` or `store_borrow` instruction.
1578 | | /// Original: store/store_borrow x to y
1579 | | /// Adjoint: adj[x] += load adj[y]; adj[y] = 0. `origSrc` is x and `origDest` is y; adj[y] is always an adjoint buffer.
1580 | | void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc,
1581 | 2.69k | SILValue origDest) {
1582 | 2.69k | auto adjBuf = getAdjointBuffer(bb, origDest); // adj[y]
1583 | 2.69k | switch (getTangentValueCategory(origSrc)) {
1584 | 2.63k | case SILValueCategory::Object: {
1585 | 2.63k | auto adjVal = builder.emitLoadValueOperation(
1586 | 2.63k | loc, adjBuf, LoadOwnershipQualifier::Take); // `Take` load, so the buffer is re-zeroed below as an initialization
1587 | 2.63k | recordTemporary(adjVal);
1588 | 2.63k | addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc);
1589 | 2.63k | builder.emitZeroIntoBuffer(loc, adjBuf, IsInitialization);
1590 | 2.63k | break;
1591 | 0 | }
1592 | 60 | case SILValueCategory::Address: {
1593 | 60 | addToAdjointBuffer(bb, origSrc, adjBuf, loc);
1594 | 60 | builder.emitZeroIntoBuffer(loc, adjBuf, IsNotInitialization); // nothing was taken out of adjBuf in this branch, hence IsNotInitialization
1595 | 60 | break;
1596 | 0 | }
1597 | 2.69k | }
1598 | 2.69k | }
1599 | 2.69k | void visitStoreInst(StoreInst *si) { // `store`: forwards source and destination to the common store handler
1600 | 2.69k | visitStoreOperation(si->getParent(), si->getLoc(), si->getSrc(),
1601 | 2.69k | si->getDest());
1602 | 2.69k | }
1603 | 0 | void visitStoreBorrowInst(StoreBorrowInst *sbi) { // `store_borrow`: same handler as `store`
1604 | 0 | visitStoreOperation(sbi->getParent(), sbi->getLoc(), sbi->getSrc(),
1605 | 0 | sbi); // the instruction itself is passed as the destination, not sbi->getDest() as in visitStoreInst
1606 | 0 | }
1607 | | |
1608 | | /// Handle `copy_addr` instruction.
1609 | | /// Original: copy_addr x to y
1610 | | /// Adjoint: adj[x] += adj[y]; adj[y] = 0. Both adjoints are adjoint buffers.
1611 | 2.01k | void visitCopyAddrInst(CopyAddrInst *cai) {
1612 | 2.01k | auto *bb = cai->getParent();
1613 | 2.01k | auto adjDest = getAdjointBuffer(bb, cai->getDest()); // adj[y]
1614 | 2.01k | addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc()); // adj[x] += adj[y]
1615 | 2.01k | builder.emitZeroIntoBuffer(cai->getLoc(), adjDest, IsNotInitialization); // adj[y] = 0
1616 | 2.01k | }
1617 | | |
1618 | | /// Handle any ownership instruction that deals with values: copy_value,
1619 | | /// move_value, begin_borrow. (`visitEndInitLetRefInst` in this class forwards here as well.)
1620 | | /// Original: y = copy_value x
1621 | | /// Adjoint: adj[x] += adj[y]
1622 | 500 | void visitValueOwnershipInst(SingleValueInstruction *svi) {
1623 | 500 | assert(svi->getNumOperands() == 1);
1624 | 0 | auto *bb = svi->getParent();
1625 | 500 | switch (getTangentValueCategory(svi)) {
1626 | 348 | case SILValueCategory::Object: {
1627 | 348 | auto adj = getAdjointValue(bb, svi);
1628 | 348 | addAdjointValue(bb, svi->getOperand(0), adj, svi->getLoc());
1629 | 348 | break;
1630 | 0 | }
1631 | 152 | case SILValueCategory::Address: {
1632 | 152 | auto adjDest = getAdjointBuffer(bb, svi);
1633 | 152 | addToAdjointBuffer(bb, svi->getOperand(0), adjDest, svi->getLoc());
1634 | 152 | builder.emitZeroIntoBuffer(svi->getLoc(), adjDest, IsNotInitialization); // buffer path additionally resets adj[y] to zero after accumulating it
1635 | 152 | break;
1636 | 0 | }
1637 | 500 | }
1638 | 500 | }
1639 | | |
1640 | | /// Handle `copy_value` instruction.
1641 | | /// Original: y = copy_value x
1642 | | /// Adjoint: adj[x] += adj[y]
1643 | 308 | void visitCopyValueInst(CopyValueInst *cvi) { visitValueOwnershipInst(cvi); } // delegates to the shared ownership-instruction handler
1644 | | |
1645 | | /// Handle `begin_borrow` instruction.
1646 | | /// Original: y = begin_borrow x
1647 | | /// Adjoint: adj[x] += adj[y]
1648 | 152 | void visitBeginBorrowInst(BeginBorrowInst *bbi) {
1649 | 152 | visitValueOwnershipInst(bbi); // delegates to the shared ownership-instruction handler
1650 | 152 | }
1651 | | |
1652 | | /// Handle `move_value` instruction.
1653 | | /// Original: y = move_value x
1654 | | /// Adjoint: adj[x] += adj[y]
1655 | 0 | void visitMoveValueInst(MoveValueInst *mvi) { visitValueOwnershipInst(mvi); } // delegates to the shared ownership-instruction handler
1656 | | |
1657 | 40 | void visitEndInitLetRefInst(EndInitLetRefInst *eir) { visitValueOwnershipInst(eir); } // `end_init_let_ref` is handled like the other single-operand ownership instructions: adj[x] += adj[y]
1658 | | |
1659 | | /// Handle `begin_access` instruction.
1660 | | /// Original: y = begin_access x
1661 | | /// Adjoint: nothing. The visitor emits no pullback code; it only diagnoses `modify` accesses whose source is a `global_addr` or a `project_box`.
1662 | 3.13k | void visitBeginAccessInst(BeginAccessInst *bai) {
1663 | | // Check for non-differentiable writes.
1664 | 3.13k | if (bai->getAccessKind() == SILAccessKind::Modify) {
1665 | 1.27k | if (isa<GlobalAddrInst>(bai->getSource())) { // write to a global variable
1666 | 4 | getContext().emitNondifferentiabilityError(
1667 | 4 | bai, getInvoker(),
1668 | 4 | diag::autodiff_cannot_differentiate_writes_to_global_variables);
1669 | 4 | errorOccurred = true;
1670 | 4 | return;
1671 | 4 | }
1672 | 1.27k | if (isa<ProjectBoxInst>(bai->getSource())) { // write through a box projection, diagnosed as a mutable capture
1673 | 0 | getContext().emitNondifferentiabilityError(
1674 | 0 | bai, getInvoker(),
1675 | 0 | diag::autodiff_cannot_differentiate_writes_to_mutable_captures);
1676 | 0 | errorOccurred = true;
1677 | 0 | return;
1678 | 0 | }
1679 | 1.27k | }
1680 | 3.13k | }
1681 | | |
1682 | | /// Handle `unconditional_checked_cast_addr` instruction.
1683 | | /// Original: y = unconditional_checked_cast_addr x
1684 | | /// Adjoint: adj[x] += unconditional_checked_cast_addr adj[y]
1685 | | void visitUnconditionalCheckedCastAddrInst(
1686 | 16 | UnconditionalCheckedCastAddrInst *uccai) {
1687 | 16 | auto *bb = uccai->getParent();
1688 | 16 | auto adjDest = getAdjointBuffer(bb, uccai->getDest()); // adj[y]
1689 | 16 | auto adjSrc = getAdjointBuffer(bb, uccai->getSrc()); // adj[x]
1690 | 16 | auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType()); // temporary with adj[x]'s type that receives the cast result
1691 | 16 | builder.createUnconditionalCheckedCastAddr(
1692 | 16 | uccai->getLoc(), adjDest, adjDest->getType().getASTType(), castBuf,
1693 | 16 | adjSrc->getType().getASTType()); // cast adj[y] (as its own type) into castBuf (as adj[x]'s type)
1694 | 16 | addToAdjointBuffer(bb, uccai->getSrc(), castBuf, uccai->getLoc()); // adj[x] += castBuf
1695 | 16 | builder.emitDestroyAddrAndFold(uccai->getLoc(), castBuf);
1696 | 16 | builder.createDeallocStack(uccai->getLoc(), castBuf);
1697 | 16 | builder.emitZeroIntoBuffer(uccai->getLoc(), adjDest, IsInitialization); // re-zeroed as an initialization, presumably because the cast above consumed adj[y]'s value — TODO confirm
1698 | 16 | }
1699 | | |
1700 | | /// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
1701 | | /// instructions.
1702 | | ///
1703 | | /// Original: y = init_enum_data_addr x
1704 | | /// inject_enum_addr y
1705 | | ///
1706 | | /// Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]. Only `Optional` operands with exactly one `init_enum_data_addr` user in the same block are handled; other shapes are diagnosed as non-differentiable.
1707 | 8 | void visitInjectEnumAddrInst(InjectEnumAddrInst *inject) {
1708 | 8 | SILBasicBlock *bb = inject->getParent();
1709 | 8 | SILValue origEnum = inject->getOperand();
1710 | |
1711 | | // Only `Optional`-typed operands are supported for now. Diagnose all other
1712 | | // enum operand types.
1713 | 8 | auto *optionalEnumDecl = getASTContext().getOptionalDecl();
1714 | 8 | if (origEnum->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
1715 | 0 | LLVM_DEBUG(getADDebugStream()
1716 | 0 | << "Unsupported enum type in PullbackCloner: " << *inject);
1717 | 0 | getContext().emitNondifferentiabilityError(
1718 | 0 | inject, getInvoker(),
1719 | 0 | diag::autodiff_expression_not_differentiable_note);
1720 | 0 | errorOccurred = true;
1721 | 0 | return;
1722 | 0 | }
1723 | |
1724 | 8 | InitEnumDataAddrInst *origData = nullptr; // the unique init_enum_data_addr user of the enum address, once found
1725 | 16 | for (auto use : origEnum->getUses()) {
1726 | 16 | if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser())) {
1727 | | // We need a more complicated analysis when init_enum_data_addr and
1728 | | // inject_enum_addr are in different blocks, or there is more than one
1729 | | // such instruction. Bail out for now.
1730 | 8 | if (origData || init->getParent() != bb) {
1731 | 0 | LLVM_DEBUG(getADDebugStream()
1732 | 0 | << "Could not find a matching init_enum_data_addr for: "
1733 | 0 | << *inject);
1734 | 0 | getContext().emitNondifferentiabilityError(
1735 | 0 | inject, getInvoker(),
1736 | 0 | diag::autodiff_expression_not_differentiable_note);
1737 | 0 | errorOccurred = true;
1738 | 0 | return;
1739 | 0 | }
1740 | |
1741 | 8 | origData = init;
1742 | 8 | }
1743 | 16 | }
1744 | |
1745 | 8 | SILValue adjStruct = getAdjointBuffer(bb, origEnum); // adjoint buffer of the Optional, expected to hold a struct (see comment below)
1746 | 8 | StructDecl *adjStructDecl =
1747 | 8 | adjStruct->getType().getStructOrBoundGenericStruct();
1748 | |
1749 | 8 | VarDecl *adjOptVar = nullptr; // the struct's only stored property, or nullptr if the shape is unexpected
1750 | 8 | if (adjStructDecl) {
1751 | 8 | ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
1752 | 8 | adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
1753 | 8 | }
1754 | |
1755 | 8 | EnumDecl *adjOptDecl =
1756 | 8 | adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
1757 | 8 | : nullptr;
1758 | |
1759 | | // Optional<T>.TangentVector should be a struct with a single
1760 | | // Optional<T.TangentVector> property. This is an implementation detail of
1761 | | // OptionalDifferentiation.swift
1762 | 8 | if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1763 | 0 | llvm_unreachable("Unexpected type of Optional.TangentVector");
1764 | |
1765 | 8 | SILLocation loc = origData->getLoc(); // NOTE(review): origData is still nullptr here if the loop above found no init_enum_data_addr user; there is no null check — confirm this cannot happen (e.g. for payload-less cases)
1766 | 8 | StructElementAddrInst *adjOpt =
1767 | 8 | builder.createStructElementAddr(loc, adjStruct, adjOptVar);
1768 | |
1769 | | // unchecked_take_enum_data_addr is destructive, so copy
1770 | | // Optional<T.TangentVector> to a new alloca.
1771 | 8 | AllocStackInst *adjOptCopy =
1772 | 8 | createFunctionLocalAllocation(adjOpt->getType(), loc);
1773 | 8 | builder.createCopyAddr(loc, adjOpt, adjOptCopy, IsNotTake,
1774 | 8 | IsInitialization);
1775 | |
1776 | 8 | EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
1777 | 8 | UncheckedTakeEnumDataAddrInst *adjData =
1778 | 8 | builder.createUncheckedTakeEnumDataAddr(loc, adjOptCopy, someElemDecl);
1779 | |
1780 | 8 | setAdjointBuffer(bb, origData, adjData); // the projected payload becomes the adjoint buffer of the init_enum_data_addr result
1781 | |
1782 | | // The Optional copy is invalidated, do not attempt to destroy it at the end
1783 | | // of the pullback. The value returned from unchecked_take_enum_data_addr is
1784 | | // destroyed in visitInitEnumDataAddrInst.
1785 | 8 | destroyedLocalAllocations.insert(adjOptCopy);
1786 | 8 | }
1787 | | |
1788 | | /// Handle `init_enum_data_addr` instruction.
1789 | | /// Destroy the value returned from `unchecked_take_enum_data_addr`. Does nothing when no adjoint buffer has been registered for this instruction in its block.
1790 | 8 | void visitInitEnumDataAddrInst(InitEnumDataAddrInst *init) {
1791 | 8 | auto bufIt = bufferMap.find({init->getParent(), SILValue(init)}); // direct map lookup by (block, value) instead of getAdjointBuffer
1792 | 8 | if (bufIt == bufferMap.end())
1793 | 0 | return;
1794 | 8 | SILValue adjData = bufIt->second;
1795 | 8 | builder.emitDestroyAddr(init->getLoc(), adjData);
1796 | 8 | }
1797 | | |
1798 | | /// Handle `unchecked_ref_cast` instruction.
1799 | | /// Original: y = unchecked_ref_cast x
1800 | | /// Adjoint: adj[x] += adj[y]
1801 | | /// (assuming adj[x] and adj[y] have the same type; this is asserted below)
1802 | 8 | void visitUncheckedRefCastInst(UncheckedRefCastInst *urci) {
1803 | 8 | auto *bb = urci->getParent();
1804 | 8 | assert(urci->getOperand()->getType().isObject());
1805 | 0 | assert(getRemappedTangentType(urci->getOperand()->getType()) ==
1806 | 8 | getRemappedTangentType(urci->getType()) &&
1807 | 8 | "Operand/result must have the same `TangentVector` type");
1808 | 0 | switch (getTangentValueCategory(urci)) {
1809 | 0 | case SILValueCategory::Object: { // forward the adjoint value unchanged
1810 | 0 | auto adj = getAdjointValue(bb, urci);
1811 | 0 | addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc());
1812 | 0 | break;
1813 | 0 | }
1814 | 8 | case SILValueCategory::Address: { // accumulate the adjoint buffer into the operand's, then reset it to zero
1815 | 8 | auto adjDest = getAdjointBuffer(bb, urci);
1816 | 8 | addToAdjointBuffer(bb, urci->getOperand(), adjDest, urci->getLoc());
1817 | 8 | builder.emitZeroIntoBuffer(urci->getLoc(), adjDest, IsNotInitialization);
1818 | 8 | break;
1819 | 0 | }
1820 | 8 | }
1821 | 8 | }
1822 | | |
1823 | | /// Handle `upcast` instruction.
1824 | | /// Original: y = upcast x
1825 | | /// Adjoint: adj[x] += adj[y]
1826 | | /// (assuming adj[x] and adj[y] have the same type; this is asserted below)
1827 | 24 | void visitUpcastInst(UpcastInst *ui) {
1828 | 24 | auto *bb = ui->getParent();
1829 | 24 | assert(ui->getOperand()->getType().isObject());
1830 | 0 | assert(getRemappedTangentType(ui->getOperand()->getType()) ==
1831 | 24 | getRemappedTangentType(ui->getType()) &&
1832 | 24 | "Operand/result must have the same `TangentVector` type");
1833 | 0 | switch (getTangentValueCategory(ui)) {
1834 | 8 | case SILValueCategory::Object: { // forward the adjoint value unchanged
1835 | 8 | auto adj = getAdjointValue(bb, ui);
1836 | 8 | addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc());
1837 | 8 | break;
1838 | 0 | }
1839 | 16 | case SILValueCategory::Address: { // accumulate the adjoint buffer into the operand's, then reset it to zero
1840 | 16 | auto adjDest = getAdjointBuffer(bb, ui);
1841 | 16 | addToAdjointBuffer(bb, ui->getOperand(), adjDest, ui->getLoc());
1842 | 16 | builder.emitZeroIntoBuffer(ui->getLoc(), adjDest, IsNotInitialization);
1843 | 16 | break;
1844 | 0 | }
1845 | 24 | }
1846 | 24 | }
1847 | | |
1848 | | /// Handle `unchecked_take_enum_data_addr` instruction.
1849 | | /// Currently, only `Optional`-typed operands are supported.
1850 | | /// Original: y = unchecked_take_enum_data_addr x : $*Enum, #Enum.Case
1851 | | /// Adjoint: adj[x] += $Enum.TangentVector(adj[y]), followed by resetting adj[y] to zero
1852 | | void
1853 | 112 | visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *utedai) {
1854 | 112 | auto *bb = utedai->getParent();
1855 | 112 | auto adjDest = getAdjointBuffer(bb, utedai); // adj[y]; fetched before the enum-type check below
1856 | 112 | auto enumTy = utedai->getOperand()->getType();
1857 | 112 | auto *optionalEnumDecl = getASTContext().getOptionalDecl();
1858 | | // Only `Optional`-typed operands are supported for now. Diagnose all other
1859 | | // enum operand types.
1860 | 112 | if (enumTy.getASTType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
1861 | 0 | LLVM_DEBUG(getADDebugStream()
1862 | 0 | << "Unhandled instruction in PullbackCloner: " << *utedai);
1863 | 0 | getContext().emitNondifferentiabilityError(
1864 | 0 | utedai, getInvoker(),
1865 | 0 | diag::autodiff_expression_not_differentiable_note);
1866 | 0 | errorOccurred = true;
1867 | 0 | return;
1868 | 0 | }
1869 | 112 | accumulateAdjointForOptionalBuffer(bb, utedai->getOperand(), adjDest);
1870 | 112 | builder.emitZeroIntoBuffer(utedai->getLoc(), adjDest, IsNotInitialization);
1871 | 112 | }
1872 | | |
1873 | | #define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst); // NOTE(review): declares a visitor per instruction, but the macro is undefined on the next line without being invoked in between
1874 | | #undef NOT_DIFFERENTIABLE
1875 | | |
1876 | | #define NO_ADJOINT(INST) \ |
1877 | 16.5k | void visit##INST##Inst(INST##Inst *inst) {}
_ZN5swift8autodiff14PullbackCloner14Implementation19visitAllocStackInstEPNS_14AllocStackInstE Line | Count | Source | 1877 | 4.24k | void visit##INST##Inst(INST##Inst *inst) {} |
_ZN5swift8autodiff14PullbackCloner14Implementation18visitIndexAddrInstEPNS_13IndexAddrInstE Line | Count | Source | 1877 | 124 | void visit##INST##Inst(INST##Inst *inst) {} |
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation25visitPointerToAddressInstEPNS_20PointerToAddressInstE
_ZN5swift8autodiff14PullbackCloner14Implementation25visitTupleElementAddrInstEPNS_20TupleElementAddrInstE Line | Count | Source | 1877 | 1.22k | void visit##INST##Inst(INST##Inst *inst) {} |
_ZN5swift8autodiff14PullbackCloner14Implementation26visitStructElementAddrInstEPNS_21StructElementAddrInstE Line | Count | Source | 1877 | 964 | void visit##INST##Inst(INST##Inst *inst) {} |
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation15visitReturnInstEPNS_10ReturnInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation15visitBranchInstEPNS_10BranchInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation19visitCondBranchInstEPNS_14CondBranchInstE
_ZN5swift8autodiff14PullbackCloner14Implementation21visitDeallocStackInstEPNS_16DeallocStackInstE Line | Count | Source | 1877 | 4.70k | void visit##INST##Inst(INST##Inst *inst) {} |
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation21visitStrongRetainInstEPNS_16StrongRetainInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation22visitStrongReleaseInstEPNS_17StrongReleaseInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation28visitStrongRetainUnownedInstEPNS_23StrongRetainUnownedInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation22visitUnownedRetainInstEPNS_17UnownedRetainInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation23visitUnownedReleaseInstEPNS_18UnownedReleaseInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation20visitRetainValueInstEPNS_15RetainValueInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation24visitRetainValueAddrInstEPNS_19RetainValueAddrInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation21visitReleaseValueInstEPNS_16ReleaseValueInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation25visitReleaseValueAddrInstEPNS_20ReleaseValueAddrInstE
_ZN5swift8autodiff14PullbackCloner14Implementation21visitDestroyValueInstEPNS_16DestroyValueInstE Line | Count | Source | 1877 | 492 | void visit##INST##Inst(INST##Inst *inst) {} |
_ZN5swift8autodiff14PullbackCloner14Implementation18visitEndBorrowInstEPNS_13EndBorrowInstE Line | Count | Source | 1877 | 228 | void visit##INST##Inst(INST##Inst *inst) {} |
_ZN5swift8autodiff14PullbackCloner14Implementation18visitEndAccessInstEPNS_13EndAccessInstE Line | Count | Source | 1877 | 3.16k | void visit##INST##Inst(INST##Inst *inst) {} |
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation19visitDebugValueInstEPNS_14DebugValueInstE
_ZN5swift8autodiff14PullbackCloner14Implementation20visitDestroyAddrInstEPNS_15DestroyAddrInstE Line | Count | Source | 1877 | 1.44k | void visit##INST##Inst(INST##Inst *inst) {} |
|
1878 | | // Terminators. Each NO_ADJOINT(X) line in this section defines `visitXInst` with an empty body.
1879 | | NO_ADJOINT(Return)
1880 | | NO_ADJOINT(Branch)
1881 | | NO_ADJOINT(CondBranch)
1882 | |
1883 | | // Address projections.
1884 | | NO_ADJOINT(StructElementAddr)
1885 | | NO_ADJOINT(TupleElementAddr)
1886 | |
1887 | | // Array literal initialization address projections.
1888 | | NO_ADJOINT(PointerToAddress)
1889 | | NO_ADJOINT(IndexAddr)
1890 | |
1891 | | // Memory allocation/access.
1892 | | NO_ADJOINT(AllocStack)
1893 | | NO_ADJOINT(DeallocStack)
1894 | | NO_ADJOINT(EndAccess)
1895 | |
1896 | | // Debugging/reference counting instructions.
1897 | | NO_ADJOINT(DebugValue)
1898 | | NO_ADJOINT(RetainValue)
1899 | | NO_ADJOINT(RetainValueAddr)
1900 | | NO_ADJOINT(ReleaseValue)
1901 | | NO_ADJOINT(ReleaseValueAddr)
1902 | | NO_ADJOINT(StrongRetain)
1903 | | NO_ADJOINT(StrongRelease)
1904 | | NO_ADJOINT(UnownedRetain)
1905 | | NO_ADJOINT(UnownedRelease)
1906 | | NO_ADJOINT(StrongRetainUnowned)
1907 | | NO_ADJOINT(DestroyValue)
1908 | | NO_ADJOINT(DestroyAddr)
1909 | |
1910 | | // Value ownership.
1911 | | NO_ADJOINT(EndBorrow)
1912 | | #undef NO_ADJOINT
1913 | | }; |
1914 | | // Sets up the builders for the pullback, then records analysis info and the exit block of the original function. |
1915 | | PullbackCloner::Implementation::Implementation(VJPCloner &vjpCloner) |
1916 | | : vjpCloner(vjpCloner), scopeCloner(getPullback()), |
1917 | | builder(getPullback(), getContext()), |
1918 | 5.23k | localAllocBuilder(getPullback(), getContext()) { |
1919 | | // Get dominance, post-dominance, and post-order info for the original function. |
1920 | 5.23k | auto &passManager = getContext().getPassManager(); |
1921 | 5.23k | auto *domAnalysis = passManager.getAnalysis<DominanceAnalysis>(); |
1922 | 5.23k | auto *postDomAnalysis = passManager.getAnalysis<PostDominanceAnalysis>(); |
1923 | 5.23k | auto *postOrderAnalysis = passManager.getAnalysis<PostOrderAnalysis>(); |
1924 | 5.23k | auto *original = &vjpCloner.getOriginal(); |
1925 | 5.23k | domInfo = domAnalysis->get(original); |
1926 | 5.23k | postDomInfo = postDomAnalysis->get(original); |
1927 | 5.23k | postOrderInfo = postOrderAnalysis->get(original); |
1928 | | // Initialize `originalExitBlock` to the original function's return block, and give `localAllocBuilder` the remapped debug scope of that block's terminator. |
1929 | 5.23k | auto origExitIt = original->findReturnBB(); |
1930 | 5.23k | assert(origExitIt != original->end() && |
1931 | 5.23k | "Functions without returns must have been diagnosed"); |
1932 | 0 | originalExitBlock = &*origExitIt; |
1933 | 5.23k | localAllocBuilder.setCurrentDebugScope( |
1934 | 5.23k | remapScope(originalExitBlock->getTerminator()->getDebugScope())); |
1935 | 5.23k | } |
1936 | | // `impl` is allocated with `new` in the constructor and released with `delete` in the destructor below. |
1937 | | PullbackCloner::PullbackCloner(VJPCloner &vjpCloner) |
1938 | 5.23k | : impl(*new Implementation(vjpCloner)) {} |
1939 | | |
1940 | 5.23k | PullbackCloner::~PullbackCloner() { delete &impl; } |
1941 | | |
1942 | | //--------------------------------------------------------------------------// |
1943 | | // Entry point |
1944 | | //--------------------------------------------------------------------------// |
1945 | | // Runs pullback generation and returns true if an error was found. When `NDEBUG` is not defined, the generated pullback is verified if no error was found. |
1946 | 5.23k | bool PullbackCloner::run() { |
1947 | 5.23k | bool foundError = impl.run(); |
1948 | 5.23k | #ifndef NDEBUG |
1949 | 5.23k | if (!foundError) |
1950 | 5.10k | impl.getPullback().verify(); |
1951 | 5.23k | #endif |
1952 | 5.23k | return foundError; |
1953 | 5.23k | } |
1954 | | // Generates the pullback function body. Returns true if an error occurred during generation and false otherwise. |
1955 | 5.23k | bool PullbackCloner::Implementation::run() { |
1956 | 5.23k | PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal()); |
1957 | 5.23k | auto &original = getOriginal(); |
1958 | 5.23k | auto &pullback = getPullback(); |
1959 | 5.23k | auto pbLoc = getPullback().getLocation(); |
1960 | 5.23k | LLVM_DEBUG(getADDebugStream() << "Running PullbackCloner on\n" << original); |
1961 | | |
1962 | | // Collect original formal results. |
1963 | 5.23k | SmallVector<SILValue, 8> origFormalResults; |
1964 | 5.23k | collectAllFormalResultsInTypeOrder(original, origFormalResults); |
1965 | 5.32k | for (auto resultIndex : getConfig().resultIndices->getIndices()) { |
1966 | 5.32k | auto origResult = origFormalResults[resultIndex]; |
1967 | | // If original result is non-varied, it will always have a zero derivative. |
1968 | | // Skip full pullback generation and simply emit zero derivatives for wrt |
1969 | | // parameters. |
1970 | | // |
1971 | | // NOTE(TF-876): This shortcut is currently necessary for functions |
1972 | | // returning non-varied result with >1 basic block where some basic blocks |
1973 | | // have no dominated active values; control flow differentiation does not |
1974 | | // handle this case. See TF-876 for context. |
1975 | 5.32k | if (!getActivityInfo().isVaried(origResult, getConfig().parameterIndices)) { |
1976 | 112 | emitZeroDerivativesForNonvariedResult(origResult); |
1977 | 112 | return false; |
1978 | 112 | } |
1979 | 5.32k | } |
1980 | | |
1981 | | // Collect dominated active values in original basic blocks. |
1982 | | // Adjoint values of dominated active values are passed as pullback block |
1983 | | // arguments. |
1984 | 5.12k | DominanceOrder domOrder(original.getEntryBlock(), domInfo); |
1985 | | // Keep track of visited values. |
1986 | 5.12k | SmallPtrSet<SILValue, 8> visited; |
1987 | 12.0k | while (auto *bb = domOrder.getNext()) { |
1988 | 6.96k | auto &bbActiveValues = activeValues[bb]; |
1989 | | // If the current block has an immediate dominator, append the immediate |
1990 | | // dominator block's active values to the current block's active values. |
1991 | 6.96k | if (auto *domNode = domInfo->getNode(bb)->getIDom()) { |
1992 | 1.83k | auto &domBBActiveValues = activeValues[domNode->getBlock()]; |
1993 | 1.83k | bbActiveValues.append(domBBActiveValues.begin(), domBBActiveValues.end()); |
1994 | 1.83k | } |
1995 | | // If `v` is active and has not been visited, records it as an active value |
1996 | | // in the original basic block. |
1997 | | // For active values unsupported by differentiation, emits a diagnostic and |
1998 | | // returns true. Otherwise, returns false. |
1999 | 146k | auto recordValueIfActive = [&](SILValue v) -> bool { |
2000 | | // If value is not active, skip. |
2001 | 146k | if (!getActivityInfo().isActive(v, getConfig())) |
2002 | 64.3k | return false; |
2003 | | // If active value has already been visited, skip. |
2004 | 81.8k | if (visited.count(v)) |
2005 | 54.4k | return false; |
2006 | | // Mark active value as visited. |
2007 | 27.4k | visited.insert(v); |
2008 | | |
2009 | | // Diagnose unsupported active values. |
2010 | 27.4k | auto type = v->getType(); |
2011 | | // Do not emit remaining activity-related diagnostics for semantic member |
2012 | | // accessors, which have special-case pullback generation. |
2013 | 27.4k | if (isSemanticMemberAccessor(&original)) |
2014 | 1.32k | return false; |
2015 | | // Diagnose active enum values. Differentiation of enum values requires |
2016 | | // special adjoint value handling and is not yet supported. Diagnose |
2017 | | // only the first active enum value to prevent too many diagnostics. |
2018 | | // |
2019 | | // Do not diagnose `Optional`-typed values, which will have special-case |
2020 | | // differentiation support. |
2021 | 26.1k | if (auto *enumDecl = type.getEnumOrBoundGenericEnum()) { |
2022 | 940 | if (!type.getASTType()->isOptional()) { |
2023 | 40 | getContext().emitNondifferentiabilityError( |
2024 | 40 | v, getInvoker(), diag::autodiff_enums_unsupported); |
2025 | 40 | errorOccurred = true; |
2026 | 40 | return true; |
2027 | 40 | } |
2028 | 940 | } |
2029 | | // Diagnose unsupported stored property projections. |
2030 | 26.1k | if (isa<StructExtractInst>(v) || isa<RefElementAddrInst>(v) || |
2031 | 26.1k | isa<StructElementAddrInst>(v)) { |
2032 | 1.63k | auto *inst = cast<SingleValueInstruction>(v); |
2033 | 1.63k | assert(inst->getNumOperands() == 1); |
2034 | 0 | auto baseType = remapType(inst->getOperand(0)->getType()).getASTType(); |
2035 | 1.63k | if (!getTangentStoredProperty(getContext(), inst, baseType, |
2036 | 1.63k | getInvoker())) { |
2037 | 32 | errorOccurred = true; |
2038 | 32 | return true; |
2039 | 32 | } |
2040 | 1.63k | } |
2041 | | // Skip address projections. |
2042 | | // Address projections do not need their own adjoint buffers; they |
2043 | | // become projections into their adjoint base buffer. |
2044 | 26.0k | if (Projection::isAddressProjection(v)) |
2045 | 2.55k | return false; |
2046 | | |
2047 | | // Check that active values are differentiable. Otherwise we may crash |
2048 | | // later when tangent space is required, but not available. |
2049 | 23.5k | if (!getTangentSpace(remapType(type).getASTType())) { |
2050 | 4 | getContext().emitNondifferentiabilityError( |
2051 | 4 | v, getInvoker(), diag::autodiff_expression_not_differentiable_note); |
2052 | 4 | errorOccurred = true; |
2053 | 4 | return true; |
2054 | 4 | } |
2055 | | |
2056 | | // Record active value. |
2057 | 23.5k | bbActiveValues.push_back(v); |
2058 | 23.5k | return false; |
2059 | 23.5k | }; |
2060 | | // Record all active values in the basic block. |
2061 | 6.96k | for (auto *arg : bb->getArguments()) |
2062 | 10.1k | if (recordValueIfActive(arg)) |
2063 | 32 | return true; |
2064 | 83.4k | for (auto &inst : *bb) { |
2065 | 83.4k | for (auto op : inst.getOperandValues()) |
2066 | 87.2k | if (recordValueIfActive(op)) |
2067 | 0 | return true; |
2068 | 83.4k | for (auto result : inst.getResults()) |
2069 | 48.9k | if (recordValueIfActive(result)) |
2070 | 44 | return true; |
2071 | 83.4k | } |
2072 | 6.88k | domOrder.pushChildren(bb); |
2073 | 6.88k | } |
2074 | | |
2075 | | // Create pullback blocks and arguments, visiting original blocks using BFS |
2076 | | // starting from the original exit block. Unvisited original basic blocks |
2077 | | // (e.g. unreachable blocks) are not relevant for pullback generation and thus |
2078 | | // are ignored. |
2079 | | // The original blocks in traversal order for pullback generation. |
2080 | 5.04k | SmallVector<SILBasicBlock *, 8> originalBlocks; |
2081 | | // The workqueue used for bookkeeping during the breadth-first traversal. |
2082 | 5.04k | BasicBlockWorkqueue workqueue = {originalExitBlock}; |
2083 | | |
2084 | | // Perform BFS from the original exit block. |
2085 | 5.04k | { |
2086 | 11.8k | while (auto *BB = workqueue.pop()) { |
2087 | 6.80k | originalBlocks.push_back(BB); |
2088 | | |
2089 | 6.80k | for (auto *nextBB : BB->getPredecessorBlocks()) { |
2090 | 2.36k | workqueue.pushIfNotVisited(nextBB); |
2091 | 2.36k | } |
2092 | 6.80k | } |
2093 | 5.04k | } |
2094 | | |
2095 | 6.80k | for (auto *origBB : originalBlocks) { |
2096 | 6.80k | auto *pullbackBB = pullback.createBasicBlock(); |
2097 | 6.80k | pullbackBBMap.insert({origBB, pullbackBB}); |
2098 | 6.80k | auto pbTupleLoweredType = |
2099 | 6.80k | remapType(getPullbackInfo().getLinearMapTupleLoweredType(origBB)); |
2100 | | // If the BB is the original exit, then the pullback block that we just |
2101 | | // created must be the pullback function's entry. For the pullback entry, |
2102 | | // create entry arguments and continue to the next block. |
2103 | 6.80k | if (origBB == originalExitBlock) { |
2104 | 5.04k | assert(pullbackBB->isEntry()); |
2105 | 0 | createEntryArguments(&pullback); |
2106 | 5.04k | auto *origTerm = originalExitBlock->getTerminator(); |
2107 | 5.04k | builder.setCurrentDebugScope(remapScope(origTerm->getDebugScope())); |
2108 | 5.04k | builder.setInsertionPoint(pullbackBB); |
2109 | | // Obtain the context object, if any, and the top-level subcontext, i.e. |
2110 | | // the main pullback tuple. |
2111 | 5.04k | if (getPullbackInfo().hasHeapAllocatedContext()) { |
2112 | | // The last argument is the context object (`Builtin.NativeObject`). |
2113 | 100 | contextValue = pullbackBB->getArguments().back(); |
2114 | 100 | assert(contextValue->getType() == |
2115 | 100 | SILType::getNativeObjectType(getASTContext())); |
2116 | | // Load the pullback context. |
2117 | 0 | auto subcontextAddr = emitProjectTopLevelSubcontext( |
2118 | 100 | builder, pbLoc, contextValue, pbTupleLoweredType); |
2119 | 100 | SILValue mainPullbackTuple = builder.createLoad( |
2120 | 100 | pbLoc, subcontextAddr, |
2121 | 100 | pbTupleLoweredType.isTrivial(getPullback()) ? |
2122 | 88 | LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy); |
2123 | 100 | auto *dsi = builder.createDestructureTuple(pbLoc, mainPullbackTuple); |
2124 | 100 | initializePullbackTupleElements(origBB, dsi->getAllResults()); |
2125 | 4.94k | } else { |
2126 | | // Otherwise, the pullback tuple elements are the last `numVals` entry block arguments. |
2127 | 4.94k | unsigned numVals = pbTupleLoweredType.getAs<TupleType>()->getNumElements(); |
2128 | 4.94k | initializePullbackTupleElements(origBB, |
2129 | 4.94k | pullbackBB->getArguments().take_back(numVals)); |
2130 | 4.94k | } |
2131 | | |
2132 | 0 | continue; |
2133 | 5.04k | } |
2134 | | |
2135 | | // Get all active values in the original block. |
2136 | | // If the original block has no active values, continue. |
2137 | 1.75k | auto &bbActiveValues = activeValues[origBB]; |
2138 | 1.75k | if (bbActiveValues.empty()) |
2139 | 0 | continue; |
2140 | | |
2141 | | // Otherwise, if the original block has active values: |
2142 | | // - For each active buffer in the original block, allocate a new local |
2143 | | // buffer in the pullback entry. (All adjoint buffers are allocated in |
2144 | | // the pullback entry and deallocated in the pullback exit.) |
2145 | | // - For each active value in the original block, add adjoint value |
2146 | | // arguments to the pullback block. |
2147 | 9.11k | for (auto activeValue : bbActiveValues) { |
2148 | | // Handle the active value based on its value category. |
2149 | 9.11k | switch (getTangentValueCategory(activeValue)) { |
2150 | 4.76k | case SILValueCategory::Address: { |
2151 | | // Allocate and zero initialize a new local buffer using |
2152 | | // `getAdjointBuffer`. |
2153 | 4.76k | builder.setCurrentDebugScope( |
2154 | 4.76k | remapScope(originalExitBlock->getTerminator()->getDebugScope())); |
2155 | 4.76k | builder.setInsertionPoint(pullback.getEntryBlock()); |
2156 | 4.76k | getAdjointBuffer(origBB, activeValue); |
2157 | 4.76k | break; |
2158 | 0 | } |
2159 | 4.35k | case SILValueCategory::Object: { |
2160 | | // Create and register pullback block argument for the active value. |
2161 | 4.35k | auto *pullbackArg = pullbackBB->createPhiArgument( |
2162 | 4.35k | getRemappedTangentType(activeValue->getType()), |
2163 | 4.35k | OwnershipKind::Owned); |
2164 | 4.35k | activeValuePullbackBBArgumentMap[{origBB, activeValue}] = pullbackArg; |
2165 | 4.35k | recordTemporary(pullbackArg); |
2166 | 4.35k | break; |
2167 | 0 | } |
2168 | 9.11k | } |
2169 | 9.11k | } |
2170 | | // Add a pullback tuple argument. |
2171 | 1.75k | auto *pbTupleArg = pullbackBB->createPhiArgument(pbTupleLoweredType, |
2172 | 1.75k | OwnershipKind::Owned); |
2173 | | // Destructure the pullback tuple argument to get the elements. |
2174 | 1.75k | builder.setCurrentDebugScope( |
2175 | 1.75k | remapScope(origBB->getTerminator()->getDebugScope())); |
2176 | 1.75k | builder.setInsertionPoint(pullbackBB); |
2177 | 1.75k | auto *dsi = builder.createDestructureTuple(pbLoc, pbTupleArg); |
2178 | 1.75k | initializePullbackTupleElements(origBB, dsi->getResults()); |
2179 | | |
2180 | | // - Create pullback trampoline blocks for each successor block of the |
2181 | | // original block. Pullback trampoline blocks only have a pullback |
2182 | | // struct argument. They branch from a pullback successor block to the |
2183 | | // pullback original block, passing adjoint values of active values. |
2184 | 2.40k | for (auto *succBB : origBB->getSuccessorBlocks()) { |
2185 | | // Skip generating pullback block for original unreachable blocks. |
2186 | 2.40k | if (!workqueue.isVisited(succBB)) |
2187 | 44 | continue; |
2188 | 2.36k | auto *pullbackTrampolineBB = pullback.createBasicBlockBefore(pullbackBB); |
2189 | 2.36k | pullbackTrampolineBBMap.insert({{origBB, succBB}, pullbackTrampolineBB}); |
2190 | | // Get the enum element type (i.e. the pullback struct type). The enum |
2191 | | // element type may be boxed if the enum is indirect. |
2192 | 2.36k | auto enumLoweredTy = |
2193 | 2.36k | getPullbackInfo().getBranchingTraceEnumLoweredType(succBB); |
2194 | 2.36k | auto *enumEltDecl = |
2195 | 2.36k | getPullbackInfo().lookUpBranchingTraceEnumElement(origBB, succBB); |
2196 | 2.36k | auto enumEltType = remapType(enumLoweredTy.getEnumElementType( |
2197 | 2.36k | enumEltDecl, getModule(), TypeExpansionContext::minimal())); |
2198 | 2.36k | pullbackTrampolineBB->createPhiArgument(enumEltType, |
2199 | 2.36k | OwnershipKind::Owned); |
2200 | 2.36k | } |
2201 | 1.75k | } |
2202 | | |
2203 | 5.04k | auto *pullbackEntry = pullback.getEntryBlock(); |
2204 | 5.04k | auto pbTupleLoweredType = |
2205 | 5.04k | remapType(getPullbackInfo().getLinearMapTupleLoweredType(originalExitBlock)); |
2206 | 5.04k | unsigned numVals = (getPullbackInfo().hasHeapAllocatedContext() ? |
2207 | 4.94k | 1 : pbTupleLoweredType.getAs<TupleType>()->getNumElements()); |
2208 | 5.04k | (void)numVals; |
2209 | | |
2210 | | // The pullback function has type: |
2211 | | // `(seed0, seed1, ..., (exit_pb_tuple_el0, ..., )|context_obj) -> (d_arg0, ..., d_argn)`. |
2212 | 5.04k | auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults(); |
2213 | 5.04k | assert(getConfig().resultIndices->getNumIndices() == pbParamArgs.size() - numVals && |
2214 | 5.04k | pbParamArgs.size() >= 1); |
2215 | | // Assign adjoints for original result. |
2216 | 0 | builder.setCurrentDebugScope( |
2217 | 5.04k | remapScope(originalExitBlock->getTerminator()->getDebugScope())); |
2218 | 5.04k | builder.setInsertionPoint(pullbackEntry, |
2219 | 5.04k | getNextFunctionLocalAllocationInsertionPoint()); |
2220 | 5.04k | unsigned seedIndex = 0; |
2221 | 5.14k | for (auto resultIndex : getConfig().resultIndices->getIndices()) { |
2222 | 5.14k | auto origResult = origFormalResults[resultIndex]; |
2223 | 5.14k | auto *seed = pbParamArgs[seedIndex]; |
2224 | 5.14k | if (seed->getType().isAddress()) { |
2225 | | // If the seed argument is an `inout` parameter, assign it directly as |
2226 | | // the adjoint buffer of the original result. |
2227 | 1.63k | auto seedParamInfo = |
2228 | 1.63k | pullback.getLoweredFunctionType()->getParameters()[seedIndex]; |
2229 | | |
2230 | 1.63k | if (seedParamInfo.isIndirectInOut()) { |
2231 | 376 | setAdjointBuffer(originalExitBlock, origResult, seed); |
2232 | 376 | } |
2233 | | // Otherwise, assign a copy of the seed argument as the adjoint buffer of |
2234 | | // the original result. |
2235 | 1.26k | else { |
2236 | 1.26k | auto *seedBufCopy = |
2237 | 1.26k | createFunctionLocalAllocation(seed->getType(), pbLoc); |
2238 | 1.26k | builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake, |
2239 | 1.26k | IsInitialization); |
2240 | 1.26k | setAdjointBuffer(originalExitBlock, origResult, seedBufCopy); |
2241 | 1.26k | LLVM_DEBUG(getADDebugStream() |
2242 | 1.26k | << "Assigned seed buffer " << *seedBufCopy |
2243 | 1.26k | << " as the adjoint of original indirect result " |
2244 | 1.26k | << origResult); |
2245 | 1.26k | } |
2246 | 3.50k | } else { |
2247 | 3.50k | addAdjointValue(originalExitBlock, origResult, makeConcreteAdjointValue(seed), |
2248 | 3.50k | pbLoc); |
2249 | 3.50k | LLVM_DEBUG(getADDebugStream() |
2250 | 3.50k | << "Assigned seed " << *seed |
2251 | 3.50k | << " as the adjoint of original result " << origResult); |
2252 | 3.50k | } |
2253 | 5.14k | ++seedIndex; |
2254 | 5.14k | } |
2255 | | |
2256 | | // If the original function is an accessor with special-case pullback |
2257 | | // generation logic, do special-case generation. |
2258 | 5.04k | if (isSemanticMemberAccessor(&original)) { |
2259 | 256 | if (runForSemanticMemberAccessor()) |
2260 | 0 | return true; |
2261 | 256 | } |
2262 | | // Otherwise, perform standard pullback generation. |
2263 | | // Visit original blocks in the traversal order computed above (`originalBlocks`) and perform differentiation |
2264 | | // in corresponding pullback blocks. If errors occurred, back out. |
2265 | 4.79k | else { |
2266 | 6.54k | for (auto *bb : originalBlocks) { |
2267 | 6.54k | visitSILBasicBlock(bb); |
2268 | 6.54k | if (errorOccurred) |
2269 | 56 | return true; |
2270 | 6.54k | } |
2271 | 4.79k | } |
2272 | | |
2273 | | // Prepare and emit a `return` in the pullback exit block. |
2274 | 4.99k | auto *origEntry = getOriginal().getEntryBlock(); |
2275 | 4.99k | auto *pbExit = getPullbackBlock(origEntry); |
2276 | 4.99k | builder.setCurrentDebugScope(pbExit->back().getDebugScope()); |
2277 | 4.99k | builder.setInsertionPoint(pbExit); |
2278 | | |
2279 | | // This vector will contain all the materialized return elements. |
2280 | 4.99k | SmallVector<SILValue, 8> retElts; |
2281 | | // This vector will contain all indirect parameter adjoint buffers. |
2282 | 4.99k | SmallVector<SILValue, 4> indParamAdjoints; |
2283 | | // This vector will identify the locations where initialization is needed: entry i is set when the i-th buffer in `indParamAdjoints` belongs to a parameter that is not indirect-mutating. |
2284 | 4.99k | SmallBitVector outputsToInitialize; |
2285 | | |
2286 | 4.99k | auto conv = getOriginal().getConventions(); |
2287 | 4.99k | auto origParams = getOriginal().getArgumentsWithoutIndirectResults(); |
2288 | | |
2289 | | // Materializes the return element corresponding to the parameter |
2290 | | // `parameterIndex` into the `retElts` vector. |
2291 | 6.64k | auto addRetElt = [&](unsigned parameterIndex) -> void { |
2292 | 6.64k | auto origParam = origParams[parameterIndex]; |
2293 | 6.64k | switch (getTangentValueCategory(origParam)) { |
2294 | 4.84k | case SILValueCategory::Object: { |
2295 | 4.84k | auto pbVal = getAdjointValue(origEntry, origParam); |
2296 | 4.84k | auto val = materializeAdjointDirect(pbVal, pbLoc); |
2297 | 4.84k | auto newVal = builder.emitCopyValueOperation(pbLoc, val); |
2298 | 4.84k | retElts.push_back(newVal); |
2299 | 4.84k | break; |
2300 | 0 | } |
2301 | 1.80k | case SILValueCategory::Address: { |
2302 | 1.80k | auto adjBuf = getAdjointBuffer(origEntry, origParam); |
2303 | 1.80k | indParamAdjoints.push_back(adjBuf); |
2304 | 1.80k | outputsToInitialize.push_back( |
2305 | 1.80k | !conv.getParameters()[parameterIndex].isIndirectMutating()); |
2306 | 1.80k | break; |
2307 | 0 | } |
2308 | 6.64k | } |
2309 | 6.64k | }; |
2310 | 4.99k | SmallVector<SILArgument *, 4> pullbackIndirectResults( |
2311 | 4.99k | getPullback().getIndirectResults().begin(), |
2312 | 4.99k | getPullback().getIndirectResults().end()); |
2313 | | |
2314 | | // Collect differentiation parameter adjoints. |
2315 | | // Do a first pass to collect non-inout values. |
2316 | 7.00k | for (auto i : getConfig().parameterIndices->getIndices()) { |
2317 | 7.00k | if (!conv.getParameters()[i].isAutoDiffSemanticResult()) { |
2318 | 6.62k | addRetElt(i); |
2319 | 6.62k | } |
2320 | 7.00k | } |
2321 | | |
2322 | | // Do a second pass for all inout parameters, however this is only necessary |
2323 | | // for functions with multiple basic blocks. For functions with a single |
2324 | | // basic block adjoint accumulation for those parameters is already done by |
2325 | | // per-instruction visitors. |
2326 | 4.99k | if (getOriginal().size() > 1) { |
2327 | 448 | const auto &pullbackConv = pullback.getConventions(); |
2328 | 448 | SmallVector<SILArgument *, 1> pullbackInOutArgs; |
2329 | 912 | for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) { |
2330 | 912 | if (pullbackConv.getParameters()[pullbackArg.index()].isAutoDiffSemanticResult()) |
2331 | 20 | pullbackInOutArgs.push_back(pullbackArg.value()); |
2332 | 912 | } |
2333 | | |
2334 | 448 | unsigned pullbackInoutArgumentIdx = 0; |
2335 | 532 | for (auto i : getConfig().parameterIndices->getIndices()) { |
2336 | | // Skip non-inout parameters. |
2337 | 532 | if (!conv.getParameters()[i].isAutoDiffSemanticResult()) |
2338 | 512 | continue; |
2339 | | |
2340 | | // For functions with multiple basic blocks, accumulation is needed |
2341 | | // for `inout` parameters because pullback basic blocks have different |
2342 | | // adjoint buffers. |
2343 | 20 | pullbackIndirectResults.push_back(pullbackInOutArgs[pullbackInoutArgumentIdx++]); |
2344 | 20 | addRetElt(i); |
2345 | 20 | } |
2346 | 448 | } |
2347 | | |
2348 | | // Copy them to adjoint indirect results. |
2349 | 4.99k | assert(indParamAdjoints.size() == pullbackIndirectResults.size() && |
2350 | 4.99k | "Indirect parameter adjoint count mismatch"); |
2351 | 0 | unsigned currentIndex = 0; |
2352 | 4.99k | for (auto pair : zip(indParamAdjoints, pullbackIndirectResults)) { |
2353 | 1.80k | auto source = std::get<0>(pair); |
2354 | 1.80k | auto *dest = std::get<1>(pair); |
2355 | 1.80k | if (outputsToInitialize[currentIndex]) { |
2356 | 1.78k | builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization); |
2357 | 1.78k | } else { |
2358 | 20 | builder.createCopyAddr(pbLoc, source, dest, IsTake, IsNotInitialization); |
2359 | 20 | } |
2360 | 1.80k | currentIndex++; |
2361 | | // Prevent source buffer from being deallocated, since the underlying |
2362 | | // value is moved. |
2363 | 1.80k | destroyedLocalAllocations.insert(source); |
2364 | 1.80k | } |
2365 | | |
2366 | | // Emit cleanups for all local values. |
2367 | 4.99k | cleanUpTemporariesForBlock(pbExit, pbLoc); |
2368 | | // Deallocate local allocations. |
2369 | 11.9k | for (auto alloc : functionLocalAllocations) { |
2370 | | // Assert that local allocations have at least one use. |
2371 | | // Buffers should not be allocated needlessly. |
2372 | 11.9k | assert(!alloc->use_empty()); |
2373 | 11.9k | if (!destroyedLocalAllocations.count(alloc)) { |
2374 | 10.0k | builder.emitDestroyAddrAndFold(pbLoc, alloc); |
2375 | 10.0k | destroyedLocalAllocations.insert(alloc); |
2376 | 10.0k | } |
2377 | 11.9k | builder.createDeallocStack(pbLoc, alloc); |
2378 | 11.9k | } |
2379 | 4.99k | builder.createReturn(pbLoc, joinElements(retElts, builder, pbLoc)); |
2380 | | |
2381 | 4.99k | #ifndef NDEBUG |
2382 | 4.99k | bool leakFound = false; |
2383 | | // Ensure all temporaries have been cleaned up. |
2384 | 9.02k | for (auto &bb : pullback) { |
2385 | 9.02k | for (auto temp : blockTemporaries[&bb]) { |
2386 | 0 | if (blockTemporaries[&bb].count(temp)) { |
2387 | 0 | leakFound = true; |
2388 | 0 | getADDebugStream() << "Found leaked temporary:\n" << temp; |
2389 | 0 | } |
2390 | 0 | } |
2391 | 9.02k | } |
2392 | | // Ensure all local allocations have been cleaned up. |
2393 | 11.9k | for (auto localAlloc : functionLocalAllocations) { |
2394 | 11.9k | if (!destroyedLocalAllocations.count(localAlloc)) { |
2395 | 0 | leakFound = true; |
2396 | 0 | getADDebugStream() << "Found leaked local buffer:\n" << localAlloc; |
2397 | 0 | } |
2398 | 11.9k | } |
2399 | 4.99k | assert(!leakFound && "Leaks found!"); |
2400 | 0 | #endif |
2401 | | |
2402 | 4.99k | LLVM_DEBUG(getADDebugStream() |
2403 | 4.99k | << "Generated pullback for " << original.getName() << ":\n" |
2404 | 4.99k | << pullback); |
2405 | 4.99k | return errorOccurred; |
2406 | 5.04k | } |
2407 | | // Emits the pullback entry block so that it destroys its owned arguments and returns zero for every pullback result. `origNonvariedResult` is referenced only by the disabled diagnostic below. |
2408 | | void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult( |
2409 | 112 | SILValue origNonvariedResult) { |
2410 | 112 | auto &pullback = getPullback(); |
2411 | 112 | auto pbLoc = getPullback().getLocation(); |
2412 | | /* |
2413 | | // TODO(TF-788): Re-enable non-varied result warning. |
2414 | | // Emit fixit if original non-varied result has a valid source location. |
2415 | | auto startLoc = origNonvariedResult.getLoc().getStartSourceLoc(); |
2416 | | auto endLoc = origNonvariedResult.getLoc().getEndSourceLoc(); |
2417 | | if (startLoc.isValid() && endLoc.isValid()) { |
2418 | | getContext().diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) |
2419 | | .fixItInsert(startLoc, "withoutDerivative(at:") |
2420 | | .fixItInsertAfter(endLoc, ")"); |
2421 | | } |
2422 | | */ |
2423 | 112 | LLVM_DEBUG(getADDebugStream() << getOriginal().getName() |
2424 | 112 | << " has non-varied result, returning zero" |
2425 | 112 | " for all pullback results\n"); |
2426 | 112 | auto *pullbackEntry = pullback.createBasicBlock(); |
2427 | 112 | createEntryArguments(&pullback); |
2428 | 112 | builder.setCurrentDebugScope( |
2429 | 112 | remapScope(originalExitBlock->getTerminator()->getDebugScope())); |
2430 | 112 | builder.setInsertionPoint(pullbackEntry); |
2431 | | // Destroy all owned arguments. |
2432 | 112 | for (auto *arg : pullbackEntry->getArguments()) |
2433 | 172 | if (arg->getOwnershipKind() == OwnershipKind::Owned) |
2434 | 0 | builder.emitDestroyOperation(pbLoc, arg); |
2435 | | // Return zero for each result: formal direct results are emitted as zero values and joined for the `return`; other results are zero-initialized in place through the next indirect result argument. |
2436 | 112 | SmallVector<SILValue, 4> directResults; |
2437 | 112 | auto indirectResultIt = pullback.getIndirectResults().begin(); |
2438 | 132 | for (auto resultInfo : pullback.getLoweredFunctionType()->getResults()) { |
2439 | 132 | auto resultType = |
2440 | 132 | pullback.mapTypeIntoContext(resultInfo.getInterfaceType()) |
2441 | 132 | ->getCanonicalType(); |
2442 | 132 | if (resultInfo.isFormalDirect()) |
2443 | 88 | directResults.push_back(builder.emitZero(pbLoc, resultType)); |
2444 | 44 | else |
2445 | 44 | builder.emitZeroIntoBuffer(pbLoc, *indirectResultIt++, IsInitialization); |
2446 | 132 | } |
2447 | 112 | builder.createReturn(pbLoc, joinElements(directResults, builder, pbLoc)); |
2448 | 112 | LLVM_DEBUG(getADDebugStream() |
2449 | 112 | << "Generated pullback for " << getOriginal().getName() << ":\n" |
2450 | 112 | << pullback); |
2451 | 112 | } |
2452 | | |
2453 | | AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint( |
2454 | 264 | SILBasicBlock *bb, SILValue wrappedAdjoint, SILType optionalTy) { |
2455 | 264 | auto pbLoc = getPullback().getLocation(); |
2456 | | // `Optional<T>` |
2457 | 264 | optionalTy = remapType(optionalTy); |
2458 | 264 | assert(optionalTy.getASTType()->isOptional()); |
2459 | | // `T` |
2460 | 0 | auto wrappedType = optionalTy.getOptionalObjectType(); |
2461 | | // `T.TangentVector` |
2462 | 264 | auto wrappedTanType = remapType(wrappedAdjoint->getType()); |
2463 | | // `Optional<T.TangentVector>` |
2464 | 264 | auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType); |
2465 | | // `Optional<T>.TangentVector` |
2466 | 264 | auto optionalTanTy = getRemappedTangentType(optionalTy); |
2467 | 264 | auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal(); |
2468 | | // Look up the `Optional<T>.TangentVector.init` declaration. |
2469 | 264 | auto initLookup = |
2470 | 264 | optionalTanDecl->lookupDirect(DeclBaseName::createConstructor()); |
2471 | 264 | ConstructorDecl *constructorDecl = nullptr; |
2472 | 264 | for (auto *candidate : initLookup) { |
2473 | 264 | auto candidateModule = candidate->getModuleContext(); |
2474 | 264 | if (candidateModule->getName() == |
2475 | 264 | builder.getASTContext().Id_Differentiation || |
2476 | 264 | candidateModule->isStdlibModule()) { |
2477 | 264 | assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s"); |
2478 | 0 | constructorDecl = cast<ConstructorDecl>(candidate); |
2479 | | #ifdef NDEBUG |
2480 | | break; |
2481 | | #endif |
2482 | 264 | } |
2483 | 264 | } |
2484 | 264 | assert(constructorDecl && "No `Optional.TangentVector.init`"); |
2485 | | |
2486 | | // Allocate a local buffer for the `Optional` adjoint value. |
2487 | 0 | auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy); |
2488 | | // Find `Optional<T.TangentVector>.some` EnumElementDecl. |
2489 | 264 | auto someEltDecl = builder.getASTContext().getOptionalSomeDecl(); |
2490 | | |
2491 | | // Initialize an `Optional<T.TangentVector>` buffer from `wrappedAdjoint` as |
2492 | | // the input for `Optional<T>.TangentVector.init`. |
2493 | 264 | auto *optArgBuf = builder.createAllocStack(pbLoc, optionalOfWrappedTanType); |
2494 | 264 | if (optionalOfWrappedTanType.isLoadableOrOpaque(builder.getFunction())) { |
2495 | | // %enum = enum $Optional<T.TangentVector>, #Optional.some!enumelt, |
2496 | | // %wrappedAdjoint : $T |
2497 | 152 | auto *enumInst = builder.createEnum(pbLoc, wrappedAdjoint, someEltDecl, |
2498 | 152 | optionalOfWrappedTanType); |
2499 | | // store %enum to %optArgBuf |
2500 | 152 | builder.emitStoreValueOperation(pbLoc, enumInst, optArgBuf, |
2501 | 152 | StoreOwnershipQualifier::Init); |
2502 | 152 | } else { |
2503 | | // %enumAddr = init_enum_data_addr %optArgBuf : |
2504 | | // $*Optional<T.TangentVector>, #Optional.some!enumelt |
2505 | 112 | auto *enumAddr = builder.createInitEnumDataAddr( |
2506 | 112 | pbLoc, optArgBuf, someEltDecl, wrappedTanType.getAddressType()); |
2507 | | // copy_addr %wrappedAdjoint to [init] %enumAddr |
2508 | 112 | builder.createCopyAddr(pbLoc, wrappedAdjoint, enumAddr, IsNotTake, |
2509 | 112 | IsInitialization); |
2510 | | // inject_enum_addr %optArgBuf : $*Optional<T.TangentVector>, |
2511 | | // #Optional.some!enumelt |
2512 | 112 | builder.createInjectEnumAddr(pbLoc, optArgBuf, someEltDecl); |
2513 | 112 | } |
2514 | | |
2515 | | // Apply `Optional<T>.TangentVector.init`. |
2516 | 264 | SILOptFunctionBuilder fb(getContext().getTransform()); |
2517 | | // %init_fn = function_ref @Optional<T>.TangentVector.init |
2518 | 264 | auto *initFn = fb.getOrCreateFunction(pbLoc, SILDeclRef(constructorDecl), |
2519 | 264 | NotForDefinition); |
2520 | 264 | auto *initFnRef = builder.createFunctionRef(pbLoc, initFn); |
2521 | 264 | auto *diffProto = |
2522 | 264 | builder.getASTContext().getProtocol(KnownProtocolKind::Differentiable); |
2523 | 264 | auto *swiftModule = getModule().getSwiftModule(); |
2524 | 264 | auto diffConf = |
2525 | 264 | swiftModule->lookupConformance(wrappedType.getASTType(), diffProto); |
2526 | 264 | assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); |
2527 | 0 | auto subMap = SubstitutionMap::get( |
2528 | 264 | initFn->getLoweredFunctionType()->getSubstGenericSignature(), |
2529 | 264 | ArrayRef<Type>(wrappedType.getASTType()), {diffConf}); |
2530 | | // %metatype = metatype $Optional<T>.TangentVector.Type |
2531 | 264 | auto metatypeType = CanMetatypeType::get(optionalTanTy.getASTType(), |
2532 | 264 | MetatypeRepresentation::Thin); |
2533 | 264 | auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType); |
2534 | 264 | auto metatype = builder.createMetatype(pbLoc, metatypeSILType); |
2535 | | // apply %init_fn(%optTanAdjBuf, %optArgBuf, %metatype) |
2536 | 264 | builder.createApply(pbLoc, initFnRef, subMap, |
2537 | 264 | {optTanAdjBuf, optArgBuf, metatype}); |
2538 | 264 | builder.createDeallocStack(pbLoc, optArgBuf); |
2539 | 264 | return optTanAdjBuf; |
2540 | 264 | } |
2541 | | |
2542 | | // Accumulate adjoint for the incoming `Optional` buffer. |
2543 | | void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer( |
2544 | 112 | SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint) { |
2545 | 112 | assert(getTangentValueCategory(optionalBuffer) == SILValueCategory::Address); |
2546 | 0 | auto pbLoc = getPullback().getLocation(); |
2547 | | |
2548 | | // Allocate and initialize Optional<Wrapped>.TangentVector from |
2549 | | // Wrapped.TangentVector |
2550 | 112 | AllocStackInst *optTanAdjBuf = |
2551 | 112 | createOptionalAdjoint(bb, wrappedAdjoint, optionalBuffer->getType()); |
2552 | | |
2553 | | // Accumulate into optionalBuffer |
2554 | 112 | addToAdjointBuffer(bb, optionalBuffer, optTanAdjBuf, pbLoc); |
2555 | 112 | builder.emitDestroyAddr(pbLoc, optTanAdjBuf); |
2556 | 112 | builder.createDeallocStack(pbLoc, optTanAdjBuf); |
2557 | 112 | } |
2558 | | |
2559 | | // Set the adjoint value for the incoming `Optional` value. |
2560 | | void PullbackCloner::Implementation::setAdjointValueForOptional( |
2561 | 152 | SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) { |
2562 | 152 | assert(getTangentValueCategory(optionalValue) == SILValueCategory::Object); |
2563 | 0 | auto pbLoc = getPullback().getLocation(); |
2564 | | |
2565 | | // Allocate and initialize Optional<Wrapped>.TangentVector from |
2566 | | // Wrapped.TangentVector |
2567 | 152 | AllocStackInst *optTanAdjBuf = |
2568 | 152 | createOptionalAdjoint(bb, wrappedAdjoint, optionalValue->getType()); |
2569 | | |
2570 | 152 | auto optTanAdjVal = builder.emitLoadValueOperation( |
2571 | 152 | pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take); |
2572 | 152 | recordTemporary(optTanAdjVal); |
2573 | 152 | builder.createDeallocStack(pbLoc, optTanAdjBuf); |
2574 | | |
2575 | 152 | setAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal)); |
2576 | 152 | } |
2577 | | |
2578 | | SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor( |
2579 | | SILBasicBlock *origBB, SILBasicBlock *origPredBB, |
2580 | 2.35k | SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) { |
2581 | | // Get the pullback block and optional pullback trampoline block of the |
2582 | | // predecessor block. |
2583 | 2.35k | auto *pullbackBB = getPullbackBlock(origPredBB); |
2584 | 2.35k | auto *pullbackTrampolineBB = getPullbackTrampolineBlock(origPredBB, origBB); |
2585 | | // If the predecessor block does not have a corresponding pullback |
2586 | | // trampoline block, then the pullback successor is the pullback block. |
2587 | 2.35k | if (!pullbackTrampolineBB) |
2588 | 0 | return pullbackBB; |
2589 | | |
2590 | | // Otherwise, the pullback successor is the pullback trampoline block, |
2591 | | // which branches to the pullback block and propagates adjoint values of |
2592 | | // active values. |
2593 | 2.35k | assert(pullbackTrampolineBB->getNumArguments() == 1); |
2594 | 0 | auto loc = origBB->getParent()->getLocation(); |
2595 | 2.35k | SmallVector<SILValue, 8> trampolineArguments; |
2596 | | |
2597 | | // Propagate adjoint values/buffers of active values/buffers to |
2598 | | // predecessor blocks. |
2599 | 2.35k | auto &predBBActiveValues = activeValues[origPredBB]; |
2600 | 2.35k | llvm::SmallSet<std::pair<SILValue, SILValue>, 32> propagatedAdjoints; |
2601 | 11.6k | for (auto activeValue : predBBActiveValues) { |
2602 | 11.6k | LLVM_DEBUG(getADDebugStream() |
2603 | 11.6k | << "Propagating adjoint of active value " << activeValue |
2604 | 11.6k | << "from bb" << origBB->getDebugID() |
2605 | 11.6k | << " to predecessors' (bb" << origPredBB->getDebugID() |
2606 | 11.6k | << ") pullback blocks\n"); |
2607 | 11.6k | switch (getTangentValueCategory(activeValue)) { |
2608 | 5.48k | case SILValueCategory::Object: { |
2609 | 5.48k | auto activeValueAdj = getAdjointValue(origBB, activeValue); |
2610 | 5.48k | auto concreteActiveValueAdj = |
2611 | 5.48k | materializeAdjointDirect(activeValueAdj, loc); |
2612 | | |
2613 | 5.48k | if (!pullbackTrampolineBlockMap.count(concreteActiveValueAdj)) { |
2614 | 4.35k | concreteActiveValueAdj = |
2615 | 4.35k | builder.emitCopyValueOperation(loc, concreteActiveValueAdj); |
2616 | 4.35k | setAdjointValue(origBB, activeValue, |
2617 | 4.35k | makeConcreteAdjointValue(concreteActiveValueAdj)); |
2618 | 4.35k | } |
2619 | 5.48k | auto insertion = pullbackTrampolineBlockMap.try_emplace( |
2620 | 5.48k | concreteActiveValueAdj, TrampolineBlockSet()); |
2621 | 5.48k | auto &blockSet = insertion.first->getSecond(); |
2622 | 5.48k | blockSet.insert(pullbackTrampolineBB); |
2623 | 5.48k | trampolineArguments.push_back(concreteActiveValueAdj); |
2624 | | |
2625 | | // If the pullback block does not yet have a registered adjoint |
2626 | | // value for the active value, set the adjoint value to the |
2627 | | // forwarded adjoint value argument. |
2628 | | // TODO: Hoist this logic out of loop over predecessor blocks to |
2629 | | // remove the `hasAdjointValue` check. |
2630 | 5.48k | if (!hasAdjointValue(origPredBB, activeValue)) { |
2631 | 3.81k | auto *pullbackBBArg = |
2632 | 3.81k | getActiveValuePullbackBlockArgument(origPredBB, activeValue); |
2633 | 3.81k | auto forwardedArgAdj = makeConcreteAdjointValue(pullbackBBArg); |
2634 | 3.81k | setAdjointValue(origPredBB, activeValue, forwardedArgAdj); |
2635 | 3.81k | } |
2636 | 5.48k | break; |
2637 | 0 | } |
2638 | 6.17k | case SILValueCategory::Address: { |
2639 | | // Propagate adjoint buffers using `copy_addr`. |
2640 | 6.17k | auto adjBuf = getAdjointBuffer(origBB, activeValue); |
2641 | 6.17k | auto predAdjBuf = getAdjointBuffer(origPredBB, activeValue); |
2642 | 6.17k | if (propagatedAdjoints.insert({adjBuf, predAdjBuf}).second) |
2643 | 5.12k | builder.createCopyAddr(loc, adjBuf, predAdjBuf, IsNotTake, |
2644 | 5.12k | IsNotInitialization); |
2645 | 6.17k | break; |
2646 | 0 | } |
2647 | 11.6k | } |
2648 | 11.6k | } |
2649 | | |
2650 | | // Propagate pullback struct argument. |
2651 | 2.35k | TangentBuilder pullbackTrampolineBBBuilder( |
2652 | 2.35k | pullbackTrampolineBB, getContext()); |
2653 | 2.35k | pullbackTrampolineBBBuilder.setCurrentDebugScope( |
2654 | 2.35k | remapScope(origPredBB->getTerminator()->getDebugScope())); |
2655 | | |
2656 | 2.35k | auto *pullbackTrampolineBBArg = pullbackTrampolineBB->getArguments().front(); |
2657 | 2.35k | if (vjpCloner.getLoopInfo()->getLoopFor(origPredBB)) { |
2658 | 376 | assert(pullbackTrampolineBBArg->getType() == |
2659 | 376 | SILType::getRawPointerType(getASTContext())); |
2660 | 0 | auto pbTupleType = |
2661 | 376 | remapType(getPullbackInfo().getLinearMapTupleLoweredType(origPredBB)); |
2662 | 376 | auto predPbTupleAddr = pullbackTrampolineBBBuilder.createPointerToAddress( |
2663 | 376 | loc, pullbackTrampolineBBArg, pbTupleType.getAddressType(), |
2664 | 376 | /*isStrict*/ true); |
2665 | 376 | auto predPbStructVal = pullbackTrampolineBBBuilder.createLoad( |
2666 | 376 | loc, predPbTupleAddr, |
2667 | 376 | pbTupleType.isTrivial(getPullback()) ? |
2668 | 284 | LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy); |
2669 | 376 | trampolineArguments.push_back(predPbStructVal); |
2670 | 1.98k | } else { |
2671 | 1.98k | trampolineArguments.push_back(pullbackTrampolineBBArg); |
2672 | 1.98k | } |
2673 | | // Branch from pullback trampoline block to pullback block. |
2674 | 0 | pullbackTrampolineBBBuilder.createBranch(loc, pullbackBB, |
2675 | 2.35k | trampolineArguments); |
2676 | 2.35k | return pullbackTrampolineBB; |
2677 | 2.35k | } |
2678 | | |
2679 | 6.54k | void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { |
2680 | 6.54k | auto pbLoc = getPullback().getLocation(); |
2681 | | // Get the corresponding pullback basic block. |
2682 | 6.54k | auto *pbBB = getPullbackBlock(bb); |
2683 | 6.54k | builder.setInsertionPoint(pbBB); |
2684 | | |
2685 | 6.54k | LLVM_DEBUG({ |
2686 | 6.54k | auto &s = getADDebugStream() |
2687 | 6.54k | << "Original bb" + std::to_string(bb->getDebugID()) |
2688 | 6.54k | << ": To differentiate or not to differentiate?\n"; |
2689 | 6.54k | for (auto &inst : llvm::reverse(*bb)) { |
2690 | 6.54k | s << (getPullbackInfo().shouldDifferentiateInstruction(&inst) ? "[x] " |
2691 | 6.54k | : "[ ] ") |
2692 | 6.54k | << inst; |
2693 | 6.54k | } |
2694 | 6.54k | }); |
2695 | | |
2696 | | // Visit each instruction in reverse order. |
2697 | 79.2k | for (auto &inst : llvm::reverse(*bb)) { |
2698 | 79.2k | if (!getPullbackInfo().shouldDifferentiateInstruction(&inst)) |
2699 | 44.5k | continue; |
2700 | | // Differentiate instruction. |
2701 | 34.6k | builder.setCurrentDebugScope(remapScope(inst.getDebugScope())); |
2702 | 34.6k | visit(&inst); |
2703 | 34.6k | if (errorOccurred) |
2704 | 56 | return; |
2705 | 34.6k | } |
2706 | | |
2707 | | // Emit a branching terminator for the block. |
2708 | | // If the original block is the original entry, then the pullback block is |
2709 | | // the pullback exit. This is handled specially in |
2710 | | // `PullbackCloner::Implementation::run()`, so we leave the block |
2711 | | // non-terminated. |
2712 | 6.48k | if (bb->isEntry()) |
2713 | 4.73k | return; |
2714 | | |
2715 | | // Otherwise, add a `switch_enum` terminator for non-exit |
2716 | | // pullback blocks. |
2717 | | // 1. Get the pullback struct pullback block argument. |
2718 | | // 2. Extract the predecessor enum value from the pullback struct value. |
2719 | 1.75k | auto *predEnum = getPullbackInfo().getBranchingTraceDecl(bb); |
2720 | 1.75k | (void)predEnum; |
2721 | 1.75k | auto predEnumVal = getPullbackPredTupleElement(bb); |
2722 | | |
2723 | | // Propagate adjoint values from active basic block arguments to |
2724 | | // incoming values (predecessor terminator operands). |
2725 | 1.75k | for (auto *bbArg : bb->getArguments()) { |
2726 | 540 | if (!getActivityInfo().isActive(bbArg, getConfig())) |
2727 | 180 | continue; |
2728 | 360 | LLVM_DEBUG(getADDebugStream() << "Propagating adjoint value for active bb" |
2729 | 360 | << bb->getDebugID() << " argument: " |
2730 | 360 | << *bbArg); |
2731 | | |
2732 | | // Get predecessor terminator operands. |
2733 | 360 | SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues; |
2734 | 360 | if (bbArg->getSingleTerminatorOperands(incomingValues)) { |
2735 | | // Returns true if the given terminator instruction is a `switch_enum` on |
2736 | | // an `Optional`-typed value. `switch_enum` instructions require |
2737 | | // special-case adjoint value propagation for the operand. |
2738 | 360 | auto isSwitchEnumInstOnOptional = |
2739 | 620 | [&ctx = getASTContext()](TermInst *termInst) { |
2740 | 620 | if (!termInst) |
2741 | 468 | return false; |
2742 | 152 | if (auto *sei = dyn_cast<SwitchEnumInst>(termInst)) { |
2743 | 152 | auto operandTy = sei->getOperand()->getType(); |
2744 | 152 | return operandTy.getASTType()->isOptional(); |
2745 | 152 | } |
2746 | 0 | return false; |
2747 | 152 | }; |
2748 | | |
2749 | | // Check the tangent value category of the active basic block argument. |
2750 | 360 | switch (getTangentValueCategory(bbArg)) { |
2751 | | // If argument has a loadable tangent value category: materialize adjoint |
2752 | | // value of the argument, create a copy, and set the copy as the adjoint |
2753 | | // value of incoming values. |
2754 | 360 | case SILValueCategory::Object: { |
2755 | 360 | auto bbArgAdj = getAdjointValue(bb, bbArg); |
2756 | 360 | auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc); |
2757 | 360 | auto concreteBBArgAdjCopy = |
2758 | 360 | builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj); |
2759 | 620 | for (auto pair : incomingValues) { |
2760 | 620 | auto *predBB = std::get<0>(pair); |
2761 | 620 | auto incomingValue = std::get<1>(pair); |
2762 | | // Handle `switch_enum` on `Optional`. |
2763 | 620 | auto termInst = bbArg->getSingleTerminator(); |
2764 | 620 | if (isSwitchEnumInstOnOptional(termInst)) { |
2765 | 152 | setAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy); |
2766 | 468 | } else { |
2767 | 468 | blockTemporaries[getPullbackBlock(predBB)].insert( |
2768 | 468 | concreteBBArgAdjCopy); |
2769 | 468 | setAdjointValue(predBB, incomingValue, |
2770 | 468 | makeConcreteAdjointValue(concreteBBArgAdjCopy)); |
2771 | 468 | } |
2772 | 620 | } |
2773 | 360 | break; |
2774 | 0 | } |
2775 | | // If argument has an address tangent value category: get the adjoint |
2776 | | // buffer of the argument and accumulate it into the adjoint buffers of |
2777 | | // the incoming values. |
2778 | 0 | case SILValueCategory::Address: { |
2779 | 0 | auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg); |
2780 | 0 | for (auto pair : incomingValues) { |
2781 | 0 | auto incomingValue = std::get<1>(pair); |
2782 | | // Handle `switch_enum` on `Optional`. |
2783 | 0 | auto termInst = bbArg->getSingleTerminator(); |
2784 | 0 | if (isSwitchEnumInstOnOptional(termInst)) |
2785 | 0 | accumulateAdjointForOptionalBuffer(bb, incomingValue, bbArgAdjBuf); |
2786 | 0 | else |
2787 | 0 | addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc); |
2788 | 0 | } |
2789 | 0 | break; |
2790 | 0 | } |
2791 | 360 | } |
2792 | 360 | } else |
2793 | 0 | llvm::report_fatal_error("do not know how to handle this incoming bb argument"); |
2794 | 360 | } |
2795 | | |
2796 | | // 3. Build the pullback successor cases for the `switch_enum` |
2797 | | // instruction. The pullback successors correspond to the predecessors |
2798 | | // of the current original block. |
2799 | 1.75k | SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> |
2800 | 1.75k | pullbackSuccessorCases; |
2801 | | // A map from active values' adjoint values to the trampoline blocks that |
2802 | | // are using them. |
2803 | 1.75k | SmallDenseMap<SILValue, TrampolineBlockSet> pullbackTrampolineBlockMap; |
2804 | 1.75k | SmallDenseMap<SILBasicBlock *, SILBasicBlock *> origPredpullbackSuccBBMap; |
2805 | 2.35k | for (auto *predBB : bb->getPredecessorBlocks()) { |
2806 | 2.35k | auto *pullbackSuccBB = |
2807 | 2.35k | buildPullbackSuccessor(bb, predBB, pullbackTrampolineBlockMap); |
2808 | 2.35k | origPredpullbackSuccBBMap[predBB] = pullbackSuccBB; |
2809 | 2.35k | auto *enumEltDecl = |
2810 | 2.35k | getPullbackInfo().lookUpBranchingTraceEnumElement(predBB, bb); |
2811 | 2.35k | pullbackSuccessorCases.push_back({enumEltDecl, pullbackSuccBB}); |
2812 | 2.35k | } |
2813 | | // Values are trampolined by only a subset of pullback successor blocks. |
2814 | | // Other successor blocks should destroy the value. |
2815 | 4.35k | for (auto pair : pullbackTrampolineBlockMap) { |
2816 | 4.35k | auto value = pair.getFirst(); |
2817 | | // The set of trampoline BBs that are users of `value`. |
2818 | 4.35k | auto &userTrampolineBBSet = pair.getSecond(); |
2819 | | // For each pullback successor block that does not trampoline the value, |
2820 | | // release the value. |
2821 | 6.75k | for (auto origPredPbSuccPair : origPredpullbackSuccBBMap) { |
2822 | 6.75k | auto *origPred = origPredPbSuccPair.getFirst(); |
2823 | 6.75k | auto *pbSucc = origPredPbSuccPair.getSecond(); |
2824 | 6.75k | if (userTrampolineBBSet.count(pbSucc)) |
2825 | 5.48k | continue; |
2826 | 1.26k | TangentBuilder pullbackSuccBuilder(pbSucc->begin(), getContext()); |
2827 | 1.26k | pullbackSuccBuilder.setCurrentDebugScope( |
2828 | 1.26k | remapScope(origPred->getTerminator()->getDebugScope())); |
2829 | 1.26k | pullbackSuccBuilder.emitDestroyValueOperation(pbLoc, value); |
2830 | 1.26k | } |
2831 | 4.35k | } |
2832 | | // Emit cleanups for all block-local temporaries. |
2833 | 1.75k | cleanUpTemporariesForBlock(pbBB, pbLoc); |
2834 | | // Branch to pullback successor blocks. |
2835 | 1.75k | assert(pullbackSuccessorCases.size() == predEnum->getNumElements()); |
2836 | 0 | builder.createSwitchEnum(pbLoc, predEnumVal, /*DefaultBB*/ nullptr, |
2837 | 1.75k | pullbackSuccessorCases, llvm::None, ProfileCounter(), |
2838 | 1.75k | OwnershipKind::Owned); |
2839 | 1.75k | } |
2840 | | |
2841 | | //--------------------------------------------------------------------------// |
2842 | | // Member accessor pullback generation |
2843 | | //--------------------------------------------------------------------------// |
2844 | | |
2845 | 256 | bool PullbackCloner::Implementation::runForSemanticMemberAccessor() { |
2846 | 256 | auto &original = getOriginal(); |
2847 | 256 | auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl()); |
2848 | 256 | switch (accessor->getAccessorKind()) { |
2849 | 192 | case AccessorKind::Get: |
2850 | 192 | return runForSemanticMemberGetter(); |
2851 | 64 | case AccessorKind::Set: |
2852 | 64 | return runForSemanticMemberSetter(); |
2853 | | // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors. |
2854 | 0 | default: |
2855 | 0 | llvm_unreachable("Unsupported accessor kind; inconsistent with " |
2856 | 256 | "`isSemanticMemberAccessor`?"); |
2857 | 256 | } |
2858 | 256 | } |
2859 | | |
2860 | 192 | bool PullbackCloner::Implementation::runForSemanticMemberGetter() { |
2861 | 192 | auto &original = getOriginal(); |
2862 | 192 | auto &pullback = getPullback(); |
2863 | 192 | auto pbLoc = getPullback().getLocation(); |
2864 | | |
2865 | 192 | auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl()); |
2866 | 192 | assert(accessor->getAccessorKind() == AccessorKind::Get); |
2867 | | |
2868 | 0 | auto *origEntry = original.getEntryBlock(); |
2869 | 192 | auto *pbEntry = pullback.getEntryBlock(); |
2870 | 192 | builder.setCurrentDebugScope( |
2871 | 192 | remapScope(origEntry->getScopeOfFirstNonMetaInstruction())); |
2872 | 192 | builder.setInsertionPoint(pbEntry); |
2873 | | |
2874 | | // Get getter argument and result values. |
2875 | | // Getter type: $(Self) -> Result |
2876 | | // Pullback type: $(Result') -> Self' |
2877 | 192 | assert(original.getLoweredFunctionType()->getNumParameters() == 1); |
2878 | 0 | assert(pullback.getLoweredFunctionType()->getNumParameters() == 1); |
2879 | 0 | assert(pullback.getLoweredFunctionType()->getNumResults() == 1); |
2880 | 0 | SILValue origSelf = original.getArgumentsWithoutIndirectResults().front(); |
2881 | | |
2882 | 192 | SmallVector<SILValue, 8> origFormalResults; |
2883 | 192 | collectAllFormalResultsInTypeOrder(original, origFormalResults); |
2884 | 192 | assert(getConfig().resultIndices->getNumIndices() == 1 && |
2885 | 192 | "Getter should have one semantic result"); |
2886 | 0 | auto origResult = origFormalResults[*getConfig().resultIndices->begin()]; |
2887 | | |
2888 | 192 | auto tangentVectorSILTy = pullback.getConventions().getResults().front() |
2889 | 192 | .getSILStorageType(getModule(), |
2890 | 192 | pullback.getLoweredFunctionType(), |
2891 | 192 | TypeExpansionContext::minimal()); |
2892 | 192 | auto tangentVectorTy = tangentVectorSILTy.getASTType(); |
2893 | 192 | auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); |
2894 | | |
2895 | | // Look up the corresponding field in the tangent space. |
2896 | 192 | auto *origField = cast<VarDecl>(accessor->getStorage()); |
2897 | 192 | auto baseType = remapType(origSelf->getType()).getASTType(); |
2898 | 192 | auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, |
2899 | 192 | pbLoc, getInvoker()); |
2900 | 192 | if (!tanField) { |
2901 | 0 | errorOccurred = true; |
2902 | 0 | return true; |
2903 | 0 | } |
2904 | | |
2905 | | // Switch based on the base tangent struct's value category. |
2906 | 192 | switch (getTangentValueCategory(origSelf)) { |
2907 | 88 | case SILValueCategory::Object: { |
2908 | 88 | auto adjResult = getAdjointValue(origEntry, origResult); |
2909 | 88 | switch (adjResult.getKind()) { |
2910 | 0 | case AdjointValueKind::Zero: |
2911 | 0 | addAdjointValue(origEntry, origSelf, |
2912 | 0 | makeZeroAdjointValue(tangentVectorSILTy), pbLoc); |
2913 | 0 | break; |
2914 | 88 | case AdjointValueKind::Concrete: |
2915 | 88 | case AdjointValueKind::Aggregate: { |
2916 | 88 | SmallVector<AdjointValue, 8> eltVals; |
2917 | 152 | for (auto *field : tangentVectorDecl->getStoredProperties()) { |
2918 | 152 | if (field == tanField) { |
2919 | 88 | eltVals.push_back(adjResult); |
2920 | 88 | } else { |
2921 | 64 | auto substMap = tangentVectorTy->getMemberSubstitutionMap( |
2922 | 64 | field->getModuleContext(), field); |
2923 | 64 | auto fieldTy = field->getInterfaceType().subst(substMap); |
2924 | 64 | auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType(); |
2925 | 64 | assert(fieldSILTy.isObject()); |
2926 | 0 | eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); |
2927 | 64 | } |
2928 | 152 | } |
2929 | 88 | addAdjointValue(origEntry, origSelf, |
2930 | 88 | makeAggregateAdjointValue(tangentVectorSILTy, eltVals), |
2931 | 88 | pbLoc); |
2932 | | |
2933 | 88 | break; |
2934 | 88 | } |
2935 | 0 | case AdjointValueKind::AddElement: |
2936 | 0 | llvm_unreachable("Adjoint of an aggregate type's field cannot be of kind " |
2937 | 88 | "`AddElement`"); |
2938 | 88 | } |
2939 | 88 | break; |
2940 | 88 | } |
2941 | 104 | case SILValueCategory::Address: { |
2942 | 104 | assert(pullback.getIndirectResults().size() == 1); |
2943 | 0 | auto pbIndRes = pullback.getIndirectResults().front(); |
2944 | 104 | auto *adjSelf = createFunctionLocalAllocation( |
2945 | 104 | pbIndRes->getType().getObjectType(), pbLoc); |
2946 | 104 | setAdjointBuffer(origEntry, origSelf, adjSelf); |
2947 | 296 | for (auto *field : tangentVectorDecl->getStoredProperties()) { |
2948 | 296 | auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, field); |
2949 | | // Non-tangent fields get a zero. |
2950 | 296 | if (field != tanField) { |
2951 | 192 | builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization); |
2952 | 192 | continue; |
2953 | 192 | } |
2954 | | // Switch based on the property's value category. |
2955 | 104 | switch (getTangentValueCategory(origResult)) { |
2956 | 20 | case SILValueCategory::Object: { |
2957 | 20 | auto adjResult = getAdjointValue(origEntry, origResult); |
2958 | 20 | auto adjResultValue = materializeAdjointDirect(adjResult, pbLoc); |
2959 | 20 | auto adjResultValueCopy = |
2960 | 20 | builder.emitCopyValueOperation(pbLoc, adjResultValue); |
2961 | 20 | builder.emitStoreValueOperation(pbLoc, adjResultValueCopy, adjSelfElt, |
2962 | 20 | StoreOwnershipQualifier::Init); |
2963 | 20 | break; |
2964 | 0 | } |
2965 | 84 | case SILValueCategory::Address: { |
2966 | 84 | auto adjResult = getAdjointBuffer(origEntry, origResult); |
2967 | 84 | builder.createCopyAddr(pbLoc, adjResult, adjSelfElt, IsTake, |
2968 | 84 | IsInitialization); |
2969 | 84 | destroyedLocalAllocations.insert(adjResult); |
2970 | 84 | break; |
2971 | 0 | } |
2972 | 104 | } |
2973 | 104 | } |
2974 | 104 | break; |
2975 | 104 | } |
2976 | 192 | } |
2977 | 192 | return false; |
2978 | 192 | } |
2979 | | |
2980 | 64 | bool PullbackCloner::Implementation::runForSemanticMemberSetter() { |
2981 | 64 | auto &original = getOriginal(); |
2982 | 64 | auto &pullback = getPullback(); |
2983 | 64 | auto pbLoc = getPullback().getLocation(); |
2984 | | |
2985 | 64 | auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl()); |
2986 | 64 | assert(accessor->getAccessorKind() == AccessorKind::Set); |
2987 | | |
2988 | 0 | auto *origEntry = original.getEntryBlock(); |
2989 | 64 | auto *pbEntry = pullback.getEntryBlock(); |
2990 | 64 | builder.setCurrentDebugScope( |
2991 | 64 | remapScope(origEntry->getScopeOfFirstNonMetaInstruction())); |
2992 | 64 | builder.setInsertionPoint(pbEntry); |
2993 | | |
2994 | | // Get setter argument values. |
2995 | | // Setter type: $(Argument, inout Self) -> () |
2996 | | // Pullback type (wrt self): $(inout Self') -> () |
2997 | | // Pullback type (wrt both): $(inout Self') -> Argument' |
2998 | 64 | assert(original.getLoweredFunctionType()->getNumParameters() == 2); |
2999 | 0 | assert(pullback.getLoweredFunctionType()->getNumParameters() == 1); |
3000 | 0 | assert(pullback.getLoweredFunctionType()->getNumResults() == 0 || |
3001 | 64 | pullback.getLoweredFunctionType()->getNumResults() == 1); |
3002 | | |
3003 | 0 | SILValue origArg = original.getArgumentsWithoutIndirectResults()[0]; |
3004 | 64 | SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1]; |
3005 | | |
3006 | | // Look up the corresponding field in the tangent space. |
3007 | 64 | auto *origField = cast<VarDecl>(accessor->getStorage()); |
3008 | 64 | auto baseType = remapType(origSelf->getType()).getASTType(); |
3009 | 64 | auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, |
3010 | 64 | pbLoc, getInvoker()); |
3011 | 64 | if (!tanField) { |
3012 | 0 | errorOccurred = true; |
3013 | 0 | return true; |
3014 | 0 | } |
3015 | | |
3016 | 64 | auto adjSelf = getAdjointBuffer(origEntry, origSelf); |
3017 | 64 | auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField); |
3018 | | // Switch based on the property's value category. |
3019 | 64 | switch (getTangentValueCategory(origArg)) { |
3020 | 24 | case SILValueCategory::Object: { |
3021 | 24 | auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt, |
3022 | 24 | LoadOwnershipQualifier::Take); |
3023 | 24 | setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg)); |
3024 | 24 | blockTemporaries[pbEntry].insert(adjArg); |
3025 | 24 | break; |
3026 | 0 | } |
3027 | 40 | case SILValueCategory::Address: { |
3028 | 40 | addToAdjointBuffer(origEntry, origArg, adjSelfElt, pbLoc); |
3029 | 40 | builder.emitDestroyOperation(pbLoc, adjSelfElt); |
3030 | 40 | break; |
3031 | 0 | } |
3032 | 64 | } |
3033 | 64 | builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization); |
3034 | | |
3035 | 64 | return false; |
3036 | 64 | } |
3037 | | |
3038 | | //--------------------------------------------------------------------------// |
3039 | | // Adjoint buffer mapping |
3040 | | //--------------------------------------------------------------------------// |
3041 | | |
3042 | | SILValue PullbackCloner::Implementation::getAdjointProjection( |
3043 | 16.7k | SILBasicBlock *origBB, SILValue originalProjection) { |
3044 | | // Handle `struct_element_addr`. |
3045 | | // Adjoint projection: a `struct_element_addr` into the base adjoint buffer. |
3046 | 16.7k | if (auto *seai = dyn_cast<StructElementAddrInst>(originalProjection)) { |
3047 | 920 | assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() && |
3048 | 920 | "`@noDerivative` struct projections should never be active"); |
3049 | 0 | auto adjSource = getAdjointBuffer(origBB, seai->getOperand()); |
3050 | 920 | auto structType = remapType(seai->getOperand()->getType()).getASTType(); |
3051 | 920 | auto *tanField = |
3052 | 920 | getTangentStoredProperty(getContext(), seai, structType, getInvoker()); |
3053 | 920 | assert(tanField && "Invalid projections should have been diagnosed"); |
3054 | 0 | return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField); |
3055 | 920 | } |
3056 | | // Handle `tuple_element_addr`. |
3057 | | // Adjoint projection: a `tuple_element_addr` into the base adjoint buffer. |
3058 | 15.8k | if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) { |
3059 | 1.16k | auto source = teai->getOperand(); |
3060 | 1.16k | auto adjSource = getAdjointBuffer(origBB, source); |
3061 | 1.16k | if (!adjSource->getType().is<TupleType>()) |
3062 | 200 | return adjSource; |
3063 | 960 | auto origTupleTy = source->getType().castTo<TupleType>(); |
3064 | 960 | unsigned adjIndex = 0; |
3065 | 960 | for (unsigned i : range(teai->getFieldIndex())) { |
3066 | 384 | if (getTangentSpace( |
3067 | 384 | origTupleTy->getElement(i).getType()->getCanonicalType())) |
3068 | 328 | ++adjIndex; |
3069 | 384 | } |
3070 | 960 | return builder.createTupleElementAddr(teai->getLoc(), adjSource, adjIndex); |
3071 | 1.16k | } |
3072 | | // Handle `ref_element_addr`. |
3073 | | // Adjoint projection: a local allocation initialized with the corresponding |
3074 | | // field value from the class's base adjoint value. |
3075 | 14.6k | if (auto *reai = dyn_cast<RefElementAddrInst>(originalProjection)) { |
3076 | 164 | assert(!reai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() && |
3077 | 164 | "`@noDerivative` class projections should never be active"); |
3078 | 0 | auto loc = reai->getLoc(); |
3079 | | // Get the class operand, stripping `begin_borrow`. |
3080 | 164 | auto classOperand = stripBorrow(reai->getOperand()); |
3081 | 164 | auto classType = remapType(reai->getOperand()->getType()).getASTType(); |
3082 | 164 | auto *tanField = |
3083 | 164 | getTangentStoredProperty(getContext(), reai->getField(), classType, |
3084 | 164 | reai->getLoc(), getInvoker()); |
3085 | 164 | assert(tanField && "Invalid projections should have been diagnosed"); |
3086 | | // Create a local allocation for the element adjoint buffer. |
3087 | 0 | auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType(); |
3088 | 164 | auto eltTanSILType = |
3089 | 164 | remapType(SILType::getPrimitiveAddressType(eltTanType)); |
3090 | 164 | auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc); |
3091 | | // Check the class operand's `TangentVector` value category. |
3092 | 164 | switch (getTangentValueCategory(classOperand)) { |
3093 | 56 | case SILValueCategory::Object: { |
3094 | | // Get the class operand's adjoint value. Currently, it must be a |
3095 | | // `TangentVector` struct. |
3096 | 56 | auto adjClass = |
3097 | 56 | materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc); |
3098 | 56 | builder.emitScopedBorrowOperation( |
3099 | 56 | loc, adjClass, [&](SILValue borrowedAdjClass) { |
3100 | | // Initialize the element adjoint buffer with the base adjoint |
3101 | | // value. |
3102 | 56 | auto *adjElt = |
3103 | 56 | builder.createStructExtract(loc, borrowedAdjClass, tanField); |
3104 | 56 | auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt); |
3105 | 56 | builder.emitStoreValueOperation(loc, adjEltCopy, eltAdjBuffer, |
3106 | 56 | StoreOwnershipQualifier::Init); |
3107 | 56 | }); |
3108 | 56 | return eltAdjBuffer; |
3109 | 0 | } |
3110 | 108 | case SILValueCategory::Address: { |
3111 | | // Get the class operand's adjoint buffer. Currently, it must be a |
3112 | | // `TangentVector` struct. |
3113 | 108 | auto adjClass = getAdjointBuffer(origBB, classOperand); |
3114 | | // Initialize the element adjoint buffer with the base adjoint buffer. |
3115 | 108 | auto *adjElt = builder.createStructElementAddr(loc, adjClass, tanField); |
3116 | 108 | builder.createCopyAddr(loc, adjElt, eltAdjBuffer, IsNotTake, |
3117 | 108 | IsInitialization); |
3118 | 108 | return eltAdjBuffer; |
3119 | 0 | } |
3120 | 164 | } |
3121 | 164 | } |
3122 | | // Handle `begin_access`. |
3123 | | // Adjoint projection: the base adjoint buffer itself. |
3124 | 14.4k | if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) { |
3125 | 3.92k | auto adjBase = getAdjointBuffer(origBB, bai->getOperand()); |
3126 | 3.92k | if (errorOccurred) |
3127 | 0 | return (bufferMap[{origBB, originalProjection}] = SILValue()); |
3128 | | // Return the base buffer's adjoint buffer. |
3129 | 3.92k | return adjBase; |
3130 | 3.92k | } |
3131 | | // Handle `array.uninitialized_intrinsic` application element addresses. |
3132 | | // Adjoint projection: a local allocation initialized by applying |
3133 | | // `Array.TangentVector.subscript` to the base array's adjoint value. |
3134 | 10.5k | auto *ai = |
3135 | 10.5k | getAllocateUninitializedArrayIntrinsicElementAddress(originalProjection); |
3136 | 10.5k | auto *definingInst = dyn_cast_or_null<SingleValueInstruction>( |
3137 | 10.5k | originalProjection->getDefiningInstruction()); |
3138 | 10.5k | bool isAllocateUninitializedArrayIntrinsicElementAddress = |
3139 | 10.5k | ai && definingInst && |
3140 | 10.5k | (isa<PointerToAddressInst>(definingInst) || |
3141 | 488 | isa<IndexAddrInst>(definingInst)); |
3142 | 10.5k | if (isAllocateUninitializedArrayIntrinsicElementAddress) { |
3143 | | // Get the array element index of the result address. |
3144 | 488 | int eltIndex = 0; |
3145 | 488 | if (auto *iai = dyn_cast<IndexAddrInst>(definingInst)) { |
3146 | 124 | auto *ili = cast<IntegerLiteralInst>(iai->getIndex()); |
3147 | 124 | eltIndex = ili->getValue().getLimitedValue(); |
3148 | 124 | } |
3149 | | // Get the array adjoint value. |
3150 | 488 | SILValue arrayAdjoint; |
3151 | 488 | assert(ai && "Expected `array.uninitialized_intrinsic` application"); |
3152 | 488 | for (auto use : ai->getUses()) { |
3153 | 488 | auto *dti = dyn_cast<DestructureTupleInst>(use->getUser()); |
3154 | 488 | if (!dti) |
3155 | 0 | continue; |
3156 | 488 | assert(!arrayAdjoint && "Array adjoint already found"); |
3157 | | // The first `destructure_tuple` result is the `Array` value. |
3158 | 0 | auto arrayValue = dti->getResult(0); |
3159 | 488 | arrayAdjoint = materializeAdjointDirect( |
3160 | 488 | getAdjointValue(origBB, arrayValue), definingInst->getLoc()); |
3161 | 488 | } |
3162 | 488 | assert(arrayAdjoint && "Array does not have adjoint value"); |
3163 | | // Apply `Array.TangentVector.subscript` to get array element adjoint value. |
3164 | 0 | auto *eltAdjBuffer = |
3165 | 488 | getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc()); |
3166 | 488 | return eltAdjBuffer; |
3167 | 488 | } |
3168 | 10.0k | return SILValue(); |
3169 | 10.5k | } |
3170 | | |
3171 | | //----------------------------------------------------------------------------// |
3172 | | // Adjoint value accumulation |
3173 | | //----------------------------------------------------------------------------// |
3174 | | |
3175 | | AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect( |
3176 | 3.06k | AdjointValue lhs, AdjointValue rhs, SILLocation loc) { |
3177 | 3.06k | LLVM_DEBUG(getADDebugStream() << "Accumulating adjoint directly.\nLHS: " |
3178 | 3.06k | << lhs << "\nRHS: " << rhs << '\n'); |
3179 | 3.06k | switch (lhs.getKind()) { |
3180 | | // x |
3181 | 2.63k | case AdjointValueKind::Concrete: { |
3182 | 2.63k | auto lhsVal = lhs.getConcreteValue(); |
3183 | 2.63k | switch (rhs.getKind()) { |
3184 | | // x + y |
3185 | 2.34k | case AdjointValueKind::Concrete: { |
3186 | 2.34k | auto rhsVal = rhs.getConcreteValue(); |
3187 | 2.34k | auto sum = recordTemporary(builder.emitAdd(loc, lhsVal, rhsVal)); |
3188 | 2.34k | return makeConcreteAdjointValue(sum); |
3189 | 0 | } |
3190 | | // x + 0 => x |
3191 | 152 | case AdjointValueKind::Zero: |
3192 | 152 | return lhs; |
3193 | | // x + (y, z) => (x.0 + y, x.1 + z) |
3194 | 80 | case AdjointValueKind::Aggregate: { |
3195 | 80 | SmallVector<AdjointValue, 8> newElements; |
3196 | 80 | auto lhsTy = lhsVal->getType().getASTType(); |
3197 | 80 | auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal); |
3198 | 80 | if (lhsTy->is<TupleType>()) { |
3199 | 64 | auto elts = builder.createDestructureTuple(loc, lhsValCopy); |
3200 | 64 | llvm::for_each(elts->getResults(), |
3201 | 128 | [this](SILValue result) { recordTemporary(result); }); |
3202 | 128 | for (auto i : indices(elts->getResults())) { |
3203 | 128 | auto rhsElt = rhs.getAggregateElement(i); |
3204 | 128 | newElements.push_back(accumulateAdjointsDirect( |
3205 | 128 | makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc)); |
3206 | 128 | } |
3207 | 64 | } else if (lhsTy->getStructOrBoundGenericStruct()) { |
3208 | 16 | auto elts = |
3209 | 16 | builder.createDestructureStruct(lhsVal.getLoc(), lhsValCopy); |
3210 | 16 | llvm::for_each(elts->getResults(), |
3211 | 16 | [this](SILValue result) { recordTemporary(result); }); |
3212 | 16 | for (unsigned i : indices(elts->getResults())) { |
3213 | 16 | auto rhsElt = rhs.getAggregateElement(i); |
3214 | 16 | newElements.push_back(accumulateAdjointsDirect( |
3215 | 16 | makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc)); |
3216 | 16 | } |
3217 | 16 | } else { |
3218 | 0 | llvm_unreachable("Not an aggregate type"); |
3219 | 0 | } |
3220 | 80 | return makeAggregateAdjointValue(lhsVal->getType(), newElements); |
3221 | 0 | } |
3222 | | // x + (baseAdjoint, index, eltToAdd) => (x+baseAdjoint, index, eltToAdd) |
3223 | 56 | case AdjointValueKind::AddElement: { |
3224 | 56 | auto *addElementValue = rhs.getAddElementValue(); |
3225 | 56 | auto baseAdjoint = addElementValue->baseAdjoint; |
3226 | 56 | auto eltToAdd = addElementValue->eltToAdd; |
3227 | | |
3228 | 56 | auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc); |
3229 | 56 | return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd, |
3230 | 56 | addElementValue->fieldLocator); |
3231 | 0 | } |
3232 | 2.63k | } |
3233 | 2.63k | } |
3234 | | // 0 |
3235 | 192 | case AdjointValueKind::Zero: |
3236 | | // 0 + x => x |
3237 | 192 | return rhs; |
3238 | | // (x, y) |
3239 | 36 | case AdjointValueKind::Aggregate: { |
3240 | 36 | switch (rhs.getKind()) { |
3241 | | // (x, y) + z => (z.0 + x, z.1 + y) |
3242 | 0 | case AdjointValueKind::Concrete: |
3243 | 0 | return accumulateAdjointsDirect(rhs, lhs, loc); |
3244 | | // x + 0 => x |
3245 | 4 | case AdjointValueKind::Zero: |
3246 | 4 | return lhs; |
3247 | | // (x, y) + (z, w) => (x + z, y + w) |
3248 | 32 | case AdjointValueKind::Aggregate: { |
3249 | 32 | SmallVector<AdjointValue, 8> newElements; |
3250 | 32 | for (auto i : range(lhs.getNumAggregateElements())) |
3251 | 64 | newElements.push_back(accumulateAdjointsDirect( |
3252 | 64 | lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc)); |
3253 | 32 | return makeAggregateAdjointValue(lhs.getType(), newElements); |
3254 | 0 | } |
3255 | | // (x.0, ..., x.n) + (baseAdjoint, index, eltToAdd) => (x + baseAdjoint, |
3256 | | // index, eltToAdd) |
3257 | 0 | case AdjointValueKind::AddElement: { |
3258 | 0 | auto *addElementValue = rhs.getAddElementValue(); |
3259 | 0 | auto baseAdjoint = addElementValue->baseAdjoint; |
3260 | 0 | auto eltToAdd = addElementValue->eltToAdd; |
3261 | 0 | auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc); |
3262 | |
|
3263 | 0 | return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd, |
3264 | 0 | addElementValue->fieldLocator); |
3265 | 0 | } |
3266 | 36 | } |
3267 | 36 | } |
3268 | | // (baseAdjoint, index, eltToAdd) |
3269 | 196 | case AdjointValueKind::AddElement: { |
3270 | 196 | switch (rhs.getKind()) { |
3271 | 36 | case AdjointValueKind::Zero: |
3272 | 36 | return lhs; |
3273 | | // (baseAdjoint, index, eltToAdd) + x => (x + baseAdjoint, index, eltToAdd) |
3274 | 20 | case AdjointValueKind::Concrete: |
3275 | | // (baseAdjoint, index, eltToAdd) + (x.0, ..., x.n) => (x + baseAdjoint, |
3276 | | // index, eltToAdd) |
3277 | 20 | case AdjointValueKind::Aggregate: |
3278 | 20 | return accumulateAdjointsDirect(rhs, lhs, loc); |
3279 | | // (baseAdjoint1, index1, eltToAdd1) + (baseAdjoint2, index2, eltToAdd2) |
3280 | | // => ((baseAdjoint1 + baseAdjoint2, index1, eltToAdd1), index2, eltToAdd2) |
3281 | 140 | case AdjointValueKind::AddElement: { |
3282 | 140 | auto *addElementValueLhs = lhs.getAddElementValue(); |
3283 | 140 | auto baseAdjointLhs = addElementValueLhs->baseAdjoint; |
3284 | 140 | auto eltToAddLhs = addElementValueLhs->eltToAdd; |
3285 | | |
3286 | 140 | auto *addElementValueRhs = rhs.getAddElementValue(); |
3287 | 140 | auto baseAdjointRhs = addElementValueRhs->baseAdjoint; |
3288 | 140 | auto eltToAddRhs = addElementValueRhs->eltToAdd; |
3289 | | |
3290 | 140 | auto sumOfBaseAdjoints = |
3291 | 140 | accumulateAdjointsDirect(baseAdjointLhs, baseAdjointRhs, loc); |
3292 | 140 | auto newBaseAdjoint = makeAddElementAdjointValue( |
3293 | 140 | sumOfBaseAdjoints, eltToAddLhs, addElementValueLhs->fieldLocator); |
3294 | | |
3295 | 140 | return makeAddElementAdjointValue(newBaseAdjoint, eltToAddRhs, |
3296 | 140 | addElementValueRhs->fieldLocator); |
3297 | 20 | } |
3298 | 196 | } |
3299 | 196 | } |
3300 | 3.06k | } |
3301 | 0 | llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715 |
3302 | 0 | } |
3303 | | |
3304 | | //----------------------------------------------------------------------------// |
3305 | | // Array literal initialization differentiation |
3306 | | //----------------------------------------------------------------------------// |
3307 | | |
3308 | | void PullbackCloner::Implementation:: |
3309 | | accumulateArrayLiteralElementAddressAdjoints(SILBasicBlock *origBB, |
3310 | | SILValue originalValue, |
3311 | | AdjointValue arrayAdjointValue, |
3312 | 2.63k | SILLocation loc) { |
3313 | | // Return if the original value is not the `Array` result of an |
3314 | | // `array.uninitialized_intrinsic` application. |
3315 | 2.63k | auto *dti = dyn_cast_or_null<DestructureTupleInst>( |
3316 | 2.63k | originalValue->getDefiningInstruction()); |
3317 | 2.63k | if (!dti) |
3318 | 2.54k | return; |
3319 | 92 | if (!ArraySemanticsCall(dti->getOperand(), |
3320 | 92 | semantics::ARRAY_UNINITIALIZED_INTRINSIC)) |
3321 | 32 | return; |
3322 | 60 | if (originalValue != dti->getResult(0)) |
3323 | 0 | return; |
3324 | | // Accumulate the array's adjoint value into the adjoint buffers of its |
3325 | | // element addresses: `pointer_to_address` and `index_addr` instructions. |
3326 | 60 | LLVM_DEBUG(getADDebugStream() |
3327 | 60 | << "Accumulating adjoint value for array literal into element " |
3328 | 60 | "address adjoint buffers" |
3329 | 60 | << originalValue); |
3330 | 60 | auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc); |
3331 | 60 | builder.setCurrentDebugScope(remapScope(dti->getDebugScope())); |
3332 | 60 | builder.setInsertionPoint(arrayAdjoint->getParentBlock()); |
3333 | 60 | for (auto use : dti->getResult(1)->getUses()) { |
3334 | 60 | auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser()); |
3335 | 60 | auto adjBuf = getAdjointBuffer(origBB, ptai); |
3336 | 60 | auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc); |
3337 | 60 | builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf); |
3338 | 72 | for (auto use : ptai->getUses()) { |
3339 | 72 | if (auto *iai = dyn_cast<IndexAddrInst>(use->getUser())) { |
3340 | 12 | auto *ili = cast<IntegerLiteralInst>(iai->getIndex()); |
3341 | 12 | auto eltIndex = ili->getValue().getLimitedValue(); |
3342 | 12 | auto adjBuf = getAdjointBuffer(origBB, iai); |
3343 | 12 | auto *eltAdjBuf = |
3344 | 12 | getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc); |
3345 | 12 | builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf); |
3346 | 12 | } |
3347 | 72 | } |
3348 | 60 | } |
3349 | 60 | } |
3350 | | |
3351 | | AllocStackInst *PullbackCloner::Implementation::getArrayAdjointElementBuffer( |
3352 | 560 | SILValue arrayAdjoint, int eltIndex, SILLocation loc) { |
3353 | 560 | auto &ctx = builder.getASTContext(); |
3354 | 560 | auto arrayTanType = cast<StructType>(arrayAdjoint->getType().getASTType()); |
3355 | 560 | auto arrayType = arrayTanType->getParent()->castTo<BoundGenericStructType>(); |
3356 | 560 | auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType(); |
3357 | 560 | auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType)); |
3358 | | // Get `function_ref` and generic signature of |
3359 | | // `Array.TangentVector.subscript.getter`. |
3360 | 560 | auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct(); |
3361 | 560 | auto subscriptLookup = |
3362 | 560 | arrayTanStructDecl->lookupDirect(DeclBaseName::createSubscript()); |
3363 | 560 | SubscriptDecl *subscriptDecl = nullptr; |
3364 | 1.05k | for (auto *candidate : subscriptLookup) { |
3365 | 1.05k | auto candidateModule = candidate->getModuleContext(); |
3366 | 1.05k | if (candidateModule->getName() == ctx.Id_Differentiation || |
3367 | 1.05k | candidateModule->isStdlibModule()) { |
3368 | 560 | assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s"); |
3369 | 0 | subscriptDecl = cast<SubscriptDecl>(candidate); |
3370 | | #ifdef NDEBUG |
3371 | | break; |
3372 | | #endif |
3373 | 560 | } |
3374 | 1.05k | } |
3375 | 560 | assert(subscriptDecl && "No `Array.TangentVector.subscript`"); |
3376 | 0 | auto *subscriptGetterDecl = |
3377 | 560 | subscriptDecl->getOpaqueAccessor(AccessorKind::Get); |
3378 | 560 | assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter"); |
3379 | 0 | SILOptFunctionBuilder fb(getContext().getTransform()); |
3380 | 560 | auto *subscriptGetterFn = fb.getOrCreateFunction( |
3381 | 560 | loc, SILDeclRef(subscriptGetterDecl), NotForDefinition); |
3382 | | // %subscript_fn = function_ref @Array.TangentVector<T>.subscript.getter |
3383 | 560 | auto *subscriptFnRef = builder.createFunctionRef(loc, subscriptGetterFn); |
3384 | 560 | auto subscriptFnGenSig = |
3385 | 560 | subscriptGetterFn->getLoweredFunctionType()->getSubstGenericSignature(); |
3386 | | // Apply `Array.TangentVector.subscript.getter` to get array element adjoint |
3387 | | // buffer. |
3388 | | // %index_literal = integer_literal $Builtin.IntXX, <index> |
3389 | 560 | auto builtinIntType = |
3390 | 560 | SILType::getPrimitiveObjectType(ctx.getIntDecl() |
3391 | 560 | ->getStoredProperties() |
3392 | 560 | .front() |
3393 | 560 | ->getInterfaceType() |
3394 | 560 | ->getCanonicalType()); |
3395 | 560 | auto *eltIndexLiteral = |
3396 | 560 | builder.createIntegerLiteral(loc, builtinIntType, eltIndex); |
3397 | 560 | auto intType = SILType::getPrimitiveObjectType( |
3398 | 560 | ctx.getIntType()->getCanonicalType()); |
3399 | | // %index_int = struct $Int (%index_literal) |
3400 | 560 | auto *eltIndexInt = builder.createStruct(loc, intType, {eltIndexLiteral}); |
3401 | 560 | auto *swiftModule = getModule().getSwiftModule(); |
3402 | 560 | auto *diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable); |
3403 | 560 | auto diffConf = swiftModule->lookupConformance(eltTanType, diffProto); |
3404 | 560 | assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); |
3405 | 0 | auto *addArithProto = ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
3406 | 560 | auto addArithConf = swiftModule->lookupConformance(eltTanType, addArithProto); |
3407 | 560 | assert(!addArithConf.isInvalid() && |
3408 | 560 | "Missing conformance to `AdditiveArithmetic`"); |
3409 | 0 | auto subMap = SubstitutionMap::get(subscriptFnGenSig, {eltTanType}, |
3410 | 560 | {addArithConf, diffConf}); |
3411 | | // %elt_adj = alloc_stack $T.TangentVector |
3412 | | // Create and register a local allocation. |
3413 | 560 | auto *eltAdjBuffer = createFunctionLocalAllocation( |
3414 | 560 | eltTanSILType, loc, /*zeroInitialize*/ true); |
3415 | | // Immediately destroy the emitted zero value. |
3416 | | // NOTE: It is not efficient to emit a zero value then immediately destroy |
3417 | | // it. However, it was the easiest way to to avoid "lifetime mismatch in |
3418 | | // predecessors" memory lifetime verification errors for control flow |
3419 | | // differentiation. |
3420 | | // Perhaps we can avoid emitting a zero value if local allocations are created |
3421 | | // per pullback bb instead of all in the pullback entry: TF-1075. |
3422 | 560 | builder.emitDestroyOperation(loc, eltAdjBuffer); |
3423 | | // apply %subscript_fn<T.TangentVector>(%elt_adj, %index_int, %array_adj) |
3424 | 560 | builder.createApply(loc, subscriptFnRef, subMap, |
3425 | 560 | {eltAdjBuffer, eltIndexInt, arrayAdjoint}); |
3426 | 560 | return eltAdjBuffer; |
3427 | 560 | } |
3428 | | |
3429 | | } // end namespace autodiff |
3430 | | } // end namespace swift |