/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/VJPCloner.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- VJPCloner.cpp - VJP function generation --------------*- C++ -*---===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // This file defines a helper class for generating VJP functions for automatic |
14 | | // differentiation. |
15 | | // |
16 | | //===----------------------------------------------------------------------===// |
17 | | |
18 | | #define DEBUG_TYPE "differentiation" |
19 | | |
20 | | #include "swift/AST/Types.h" |
21 | | |
22 | | #include "swift/SILOptimizer/Differentiation/VJPCloner.h" |
23 | | #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" |
24 | | #include "swift/SILOptimizer/Differentiation/ADContext.h" |
25 | | #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" |
26 | | #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" |
27 | | #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" |
28 | | #include "swift/SILOptimizer/Differentiation/Thunk.h" |
29 | | |
30 | | #include "swift/SIL/TerminatorUtils.h" |
31 | | #include "swift/SIL/TypeSubstCloner.h" |
32 | | #include "swift/SILOptimizer/Analysis/LoopAnalysis.h" |
33 | | #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" |
34 | | #include "swift/SILOptimizer/Utils/CFGOptUtils.h" |
35 | | #include "swift/SILOptimizer/Utils/DifferentiationMangler.h" |
36 | | #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" |
37 | | #include "llvm/ADT/DenseMap.h" |
38 | | |
39 | | namespace swift { |
40 | | namespace autodiff { |
41 | | |
42 | | class VJPCloner::Implementation final |
43 | | : public TypeSubstCloner<VJPCloner::Implementation, SILOptFunctionBuilder> { |
44 | | friend class VJPCloner; |
45 | | friend class PullbackCloner; |
46 | | |
47 | | /// The parent VJP cloner. |
48 | | VJPCloner &cloner; |
49 | | |
50 | | /// The global context. |
51 | | ADContext &context; |
52 | | |
53 | | /// The original function. |
54 | | SILFunction *const original; |
55 | | |
56 | | /// The differentiability witness. |
57 | | SILDifferentiabilityWitness *const witness; |
58 | | |
59 | | /// The VJP function. |
60 | | SILFunction *const vjp; |
61 | | |
62 | | /// The pullback function. |
63 | | SILFunction *pullback; |
64 | | |
65 | | /// The differentiation invoker. |
66 | | DifferentiationInvoker invoker; |
67 | | |
68 | | /// Info from activity analysis on the original function. |
69 | | const DifferentiableActivityInfo &activityInfo; |
70 | | |
71 | | /// The loop info. |
72 | | SILLoopInfo *loopInfo; |
73 | | |
74 | | /// The linear map info. |
75 | | LinearMapInfo pullbackInfo; |
76 | | |
77 | | /// Caches basic blocks whose phi arguments have been remapped (adding a |
78 | | /// predecessor enum argument). |
79 | | SmallPtrSet<SILBasicBlock *, 4> remappedBasicBlocks; |
80 | | |
81 | | /// The `AutoDiffLinearMapContext` object. If null, no explicit context is |
82 | | /// needed (no loops). |
83 | | SILValue pullbackContextValue; |
84 | | /// The unique, borrowed context object. This is valid until the exit block. |
85 | | SILValue borrowedPullbackContextValue; |
86 | | |
87 | | /// The generic signature of the `Builtin.autoDiffAllocateSubcontext(_:_:)` |
88 | | /// declaration. It is used for creating a builtin call. |
89 | | GenericSignature builtinAutoDiffAllocateSubcontextGenericSignature; |
90 | | |
91 | | bool errorOccurred = false; |
92 | | |
93 | | /// Mapping from original blocks to pullback values. Used to build pullback |
94 | | /// struct instances. |
95 | | llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> pullbackValues; |
96 | | |
97 | 12.5k | ASTContext &getASTContext() const { return vjp->getASTContext(); } |
98 | 28.4k | SILModule &getModule() const { return vjp->getModule(); } |
99 | 179k | const AutoDiffConfig &getConfig() const { |
100 | 179k | return witness->getConfig(); |
101 | 179k | } |
102 | | |
103 | | Implementation(VJPCloner &parent, ADContext &context, |
104 | | SILDifferentiabilityWitness *witness, SILFunction *vjp, |
105 | | DifferentiationInvoker invoker); |
106 | | |
107 | | /// Creates an empty pullback function, to be filled in by `PullbackCloner`. |
108 | | SILFunction *createEmptyPullback(); |
109 | | |
110 | | /// Run VJP generation. Returns true on error. |
111 | | bool run(); |
112 | | |
113 | | /// Initializes a context object if needed. |
114 | 5.25k | void emitLinearMapContextInitializationIfNeeded() { |
115 | 5.25k | if (!pullbackInfo.hasHeapAllocatedContext()) |
116 | 5.14k | return; |
117 | | |
118 | | // Get the linear map tuple type and build its thick metatype. |
119 | 108 | auto *returnBB = &*original->findReturnBB(); |
120 | 108 | auto pullbackTupleType = |
121 | 108 | remapASTType(pullbackInfo.getLinearMapTupleType(returnBB)->getCanonicalType()); |
122 | 108 | Builder.setInsertionPoint(vjp->getEntryBlock()); |
123 | | |
124 | 108 | auto pbTupleMetatypeType = |
125 | 108 | CanMetatypeType::get(pullbackTupleType, MetatypeRepresentation::Thick); |
126 | 108 | auto pbTupleMetatypeSILType = |
127 | 108 | SILType::getPrimitiveObjectType(pbTupleMetatypeType); |
128 | 108 | auto pbTupleMetatype = |
129 | 108 | Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType); |
130 | | |
131 | | // Create a context. |
132 | 108 | pullbackContextValue = Builder.createBuiltin( |
133 | 108 | original->getLocation(), |
134 | 108 | getASTContext().getIdentifier(getBuiltinName( |
135 | 108 | BuiltinValueKind::AutoDiffCreateLinearMapContextWithType)), |
136 | 108 | SILType::getNativeObjectType(getASTContext()), SubstitutionMap(), |
137 | 108 | {pbTupleMetatype}); |
138 | 108 | borrowedPullbackContextValue = Builder.createBeginBorrow( |
139 | 108 | original->getLocation(), pullbackContextValue); |
140 | 108 | LLVM_DEBUG(getADDebugStream() |
141 | 108 | << "Context object initialized because there are loops\n" |
142 | 108 | << *vjp->getEntryBlock() << '\n' |
143 | 108 | << "pullback tuple type: " << pullbackTupleType << '\n'); |
144 | 108 | } |
145 | | |
146 | | /// Get the lowered SIL type of the given AST type. |
147 | 5.71k | SILType getLoweredType(Type type) { |
148 | 5.71k | auto vjpGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature(); |
149 | 5.71k | Lowering::AbstractionPattern pattern(vjpGenSig, |
150 | 5.71k | type->getReducedType(vjpGenSig)); |
151 | 5.71k | return vjp->getLoweredType(pattern, type); |
152 | 5.71k | } |
153 | | |
154 | 0 | GenericSignature getBuiltinAutoDiffAllocateSubcontextDecl() { |
155 | 0 | if (builtinAutoDiffAllocateSubcontextGenericSignature) |
156 | 0 | return builtinAutoDiffAllocateSubcontextGenericSignature; |
157 | 0 | auto &ctx = getASTContext(); |
158 | 0 | auto *decl = cast<FuncDecl>(getBuiltinValueDecl( |
159 | 0 | ctx, ctx.getIdentifier(getBuiltinName( |
160 | 0 | BuiltinValueKind::AutoDiffAllocateSubcontextWithType)))); |
161 | 0 | builtinAutoDiffAllocateSubcontextGenericSignature = |
162 | 0 | decl->getGenericSignature(); |
163 | 0 | assert(builtinAutoDiffAllocateSubcontextGenericSignature); |
164 | 0 | return builtinAutoDiffAllocateSubcontextGenericSignature; |
165 | 0 | } |
166 | | |
167 | | // Creates a trampoline block for the given original terminator instruction, |
168 | | // the pullback tuple value for its parent block, and a successor basic block. |
169 | | // |
170 | | // The trampoline block has the same arguments as and branches to the remapped |
171 | | // successor block, but drops the last predecessor enum argument. |
172 | | // |
173 | | // Used for cloning branching terminator instructions with specific |
174 | | // requirements on successor block arguments, where an additional predecessor |
175 | | // enum argument is not acceptable. |
176 | | SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst, |
177 | | TupleInst *pbTupleVal, |
178 | | SILBasicBlock *succBB); |
179 | | |
180 | | /// Build a pullback tuple value for the given original terminator |
181 | | /// instruction. |
182 | | TupleInst *buildPullbackValueTupleValue(TermInst *termInst); |
183 | | llvm::SmallVector<SILValue, 8> getPullbackValues(SILBasicBlock *origBB); |
184 | | |
185 | | /// Build a predecessor enum instance using the given builder for the given |
186 | | /// original predecessor/successor blocks and pullback tuple value. |
187 | | EnumInst *buildPredecessorEnumValue(SILBuilder &builder, |
188 | | SILBasicBlock *predBB, |
189 | | SILBasicBlock *succBB, |
190 | | SILValue pbTupleVal); |
191 | | |
192 | | public: |
193 | | /// Remap original basic blocks, adding predecessor enum arguments. |
194 | 2.63k | SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { |
195 | 2.63k | auto *vjpBB = BBMap[bb]; |
196 | | // If an error has occurred, or if the block has already been remapped, |
197 | | // return the remapped block. |
198 | 2.63k | if (errorOccurred || remappedBasicBlocks.count(bb)) |
199 | 648 | return vjpBB; |
200 | | // Add predecessor enum argument to the remapped block. |
201 | 1.98k | auto *predEnum = pullbackInfo.getBranchingTraceDecl(bb); |
202 | 1.98k | auto enumTy = |
203 | 1.98k | getOpASTType(predEnum->getDeclaredInterfaceType()->getCanonicalType()); |
204 | 1.98k | auto enumLoweredTy = context.getTypeConverter().getLoweredType( |
205 | 1.98k | enumTy, TypeExpansionContext::minimal()); |
206 | 1.98k | vjpBB->createPhiArgument(enumLoweredTy, OwnershipKind::Owned); |
207 | 1.98k | remappedBasicBlocks.insert(bb); |
208 | 1.98k | return vjpBB; |
209 | 2.63k | } |
210 | | |
211 | | /// General visitor for all instructions. If any error is emitted by previous |
212 | | /// visits, bail out. |
213 | 86.1k | void visit(SILInstruction *inst) { |
214 | 86.1k | if (errorOccurred) |
215 | 192 | return; |
216 | 85.9k | TypeSubstCloner::visit(inst); |
217 | 85.9k | } |
218 | | |
219 | 0 | void visitSILInstruction(SILInstruction *inst) { |
220 | 0 | context.emitNondifferentiabilityError( |
221 | 0 | inst, invoker, diag::autodiff_expression_not_differentiable_note); |
222 | 0 | errorOccurred = true; |
223 | 0 | } |
224 | | |
225 | 73.0k | void postProcess(SILInstruction *orig, SILInstruction *cloned) { |
226 | 73.0k | if (errorOccurred) |
227 | 0 | return; |
228 | 73.0k | SILClonerWithScopes::postProcess(orig, cloned); |
229 | 73.0k | } |
230 | | |
231 | 5.23k | void visitReturnInst(ReturnInst *ri) { |
232 | 5.23k | Builder.setCurrentDebugScope(getOpScope(ri->getDebugScope())); |
233 | 5.23k | auto loc = ri->getOperand().getLoc(); |
234 | | // Build pullback tuple value for original block. |
235 | 5.23k | auto *origExit = ri->getParent(); |
236 | | |
237 | | // Get the value in the VJP corresponding to the original result. |
238 | 5.23k | auto *origRetInst = cast<ReturnInst>(origExit->getTerminator()); |
239 | 5.23k | auto origResult = getOpValue(origRetInst->getOperand()); |
240 | 5.23k | SmallVector<SILValue, 8> origResults; |
241 | 5.23k | extractAllElements(origResult, Builder, origResults); |
242 | | |
243 | | // Get and partially apply the pullback. |
244 | 5.23k | auto vjpSubstMap = vjp->getForwardingSubstitutionMap(); |
245 | 5.23k | auto *pullbackRef = Builder.createFunctionRef(loc, pullback); |
246 | | |
247 | | // Prepare partial application arguments. |
248 | 5.23k | SILValue partialApplyArg; |
249 | 5.23k | PartialApplyInst *pullbackPartialApply; |
250 | 5.23k | if (borrowedPullbackContextValue) { |
251 | 104 | auto *pbTupleVal = buildPullbackValueTupleValue(ri); |
252 | | // Initialize the top-level subcontext buffer with the top-level pullback |
253 | | // tuple. |
254 | 104 | auto addr = emitProjectTopLevelSubcontext( |
255 | 104 | Builder, loc, borrowedPullbackContextValue, pbTupleVal->getType()); |
256 | 104 | Builder.createStore( |
257 | 104 | loc, pbTupleVal, addr, |
258 | 104 | pbTupleVal->getType().isTrivial(*pullback) ? |
259 | 92 | StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init); |
260 | | |
261 | 104 | Builder.createEndBorrow(loc, borrowedPullbackContextValue); |
262 | 104 | pullbackPartialApply = Builder.createPartialApply( |
263 | 104 | loc, pullbackRef, vjpSubstMap, {pullbackContextValue}, |
264 | 104 | ParameterConvention::Direct_Guaranteed); |
265 | 5.13k | } else { |
266 | 5.13k | pullbackPartialApply = Builder.createPartialApply( |
267 | 5.13k | loc, pullbackRef, vjpSubstMap, getPullbackValues(origExit), |
268 | 5.13k | ParameterConvention::Direct_Guaranteed); |
269 | 5.13k | } |
270 | | |
271 | 5.23k | auto pullbackType = vjp->mapTypeIntoContext( |
272 | 5.23k | vjp->getConventions().getSILType( |
273 | 5.23k | vjp->getLoweredFunctionType()->getResults().back(), |
274 | 5.23k | vjp->getTypeExpansionContext())); |
275 | 5.23k | auto pullbackFnType = pullbackType.castTo<SILFunctionType>(); |
276 | 5.23k | auto pullbackSubstType = |
277 | 5.23k | pullbackPartialApply->getType().castTo<SILFunctionType>(); |
278 | | |
279 | | // If necessary, convert the pullback value to the returned pullback |
280 | | // function type. |
281 | 5.23k | SILValue pullbackValue; |
282 | 5.23k | if (pullbackSubstType == pullbackFnType) { |
283 | 4.36k | pullbackValue = pullbackPartialApply; |
284 | 4.36k | } else if (pullbackSubstType->isABICompatibleWith(pullbackFnType, *vjp) |
285 | 872 | .isCompatible()) { |
286 | 872 | pullbackValue = |
287 | 872 | Builder.createConvertFunction(loc, pullbackPartialApply, pullbackType, |
288 | 872 | /*withoutActuallyEscaping*/ false); |
289 | 872 | } else { |
290 | 0 | llvm::report_fatal_error("Pullback value type is not ABI-compatible " |
291 | 0 | "with the returned pullback type"); |
292 | 0 | } |
293 | | |
294 | | // Return a tuple of the original result and pullback. |
295 | 5.23k | SmallVector<SILValue, 8> directResults; |
296 | 5.23k | directResults.append(origResults.begin(), origResults.end()); |
297 | 5.23k | directResults.push_back(pullbackValue); |
298 | 5.23k | Builder.createReturn(ri->getLoc(), |
299 | 5.23k | joinElements(directResults, Builder, loc)); |
300 | 5.23k | } |
301 | | |
302 | 1.19k | void visitBranchInst(BranchInst *bi) { |
303 | 1.19k | Builder.setCurrentDebugScope(getOpScope(bi->getDebugScope())); |
304 | | // Build pullback tuple value for original block. |
305 | | // Build predecessor enum value for destination block. |
306 | 1.19k | auto *origBB = bi->getParent(); |
307 | 1.19k | auto *pbTupleVal = buildPullbackValueTupleValue(bi); |
308 | 1.19k | auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB, |
309 | 1.19k | bi->getDestBB(), pbTupleVal); |
310 | | |
311 | | // Remap arguments, appending the new enum values. |
312 | 1.19k | SmallVector<SILValue, 8> args; |
313 | 1.19k | for (auto origArg : bi->getArgs()) |
314 | 528 | args.push_back(getOpValue(origArg)); |
315 | 1.19k | args.push_back(enumVal); |
316 | | |
317 | | // Create a new `br` instruction. |
318 | 1.19k | getBuilder().createBranch(bi->getLoc(), getOpBasicBlock(bi->getDestBB()), |
319 | 1.19k | args); |
320 | 1.19k | } |
321 | | |
322 | 224 | void visitCondBranchInst(CondBranchInst *cbi) { |
323 | 224 | Builder.setCurrentDebugScope(getOpScope(cbi->getDebugScope())); |
324 | | // Build pullback tuple value for original block. |
325 | 224 | auto *pbTupleVal = buildPullbackValueTupleValue(cbi); |
326 | | // Create a new `cond_br` instruction. |
327 | 224 | getBuilder().createCondBranch( |
328 | 224 | cbi->getLoc(), getOpValue(cbi->getCondition()), |
329 | 224 | createTrampolineBasicBlock(cbi, pbTupleVal, cbi->getTrueBB()), |
330 | 224 | createTrampolineBasicBlock(cbi, pbTupleVal, cbi->getFalseBB())); |
331 | 224 | } |
332 | | |
333 | 452 | void visitSwitchEnumTermInst(SwitchEnumTermInst inst) { |
334 | 452 | Builder.setCurrentDebugScope(getOpScope(inst->getDebugScope())); |
335 | | // Build pullback tuple value for original block. |
336 | 452 | auto *pbTupleVal = buildPullbackValueTupleValue(*inst); |
337 | | |
338 | | // Create trampoline successor basic blocks. |
339 | 452 | SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs; |
340 | 876 | for (unsigned i : range(inst.getNumCases())) { |
341 | 876 | auto caseBB = inst.getCase(i); |
342 | 876 | auto *trampolineBB = |
343 | 876 | createTrampolineBasicBlock(inst, pbTupleVal, caseBB.second); |
344 | 876 | caseBBs.push_back({caseBB.first, trampolineBB}); |
345 | 876 | } |
346 | | // Create trampoline default basic block. |
347 | 452 | SILBasicBlock *newDefaultBB = nullptr; |
348 | 452 | if (auto *defaultBB = inst.getDefaultBBOrNull().getPtrOrNull()) |
349 | 20 | newDefaultBB = createTrampolineBasicBlock(inst, pbTupleVal, defaultBB); |
350 | | |
351 | | // Create a new `switch_enum` instruction. |
352 | 452 | switch (inst->getKind()) { |
353 | 324 | case SILInstructionKind::SwitchEnumInst: |
354 | 324 | getBuilder().createSwitchEnum( |
355 | 324 | inst->getLoc(), getOpValue(inst.getOperand()), newDefaultBB, caseBBs); |
356 | 324 | break; |
357 | 128 | case SILInstructionKind::SwitchEnumAddrInst: |
358 | 128 | getBuilder().createSwitchEnumAddr( |
359 | 128 | inst->getLoc(), getOpValue(inst.getOperand()), newDefaultBB, caseBBs); |
360 | 128 | break; |
361 | 0 | default: |
362 | 0 | llvm_unreachable("Expected `switch_enum` or `switch_enum_addr`"); |
363 | 452 | } |
364 | 452 | } |
365 | | |
366 | 324 | void visitSwitchEnumInst(SwitchEnumInst *sei) { |
367 | 324 | visitSwitchEnumTermInst(sei); |
368 | 324 | } |
369 | | |
370 | 128 | void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) { |
371 | 128 | visitSwitchEnumTermInst(seai); |
372 | 128 | } |
373 | | |
374 | 4 | void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { |
375 | 4 | Builder.setCurrentDebugScope(getOpScope(ccbi->getDebugScope())); |
376 | | // Build pullback tuple value for original block. |
377 | 4 | auto *pbTupleVal = buildPullbackValueTupleValue(ccbi); |
378 | | // Create a new `checked_cast_branch` instruction. |
379 | 4 | getBuilder().createCheckedCastBranch( |
380 | 4 | ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()), |
381 | 4 | getOpASTType(ccbi->getSourceFormalType()), |
382 | 4 | getOpType(ccbi->getTargetLoweredType()), |
383 | 4 | getOpASTType(ccbi->getTargetFormalType()), |
384 | 4 | createTrampolineBasicBlock(ccbi, pbTupleVal, ccbi->getSuccessBB()), |
385 | 4 | createTrampolineBasicBlock(ccbi, pbTupleVal, ccbi->getFailureBB()), |
386 | 4 | ccbi->getTrueBBCount(), ccbi->getFalseBBCount()); |
387 | 4 | } |
388 | | |
389 | 8 | void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi) { |
390 | 8 | Builder.setCurrentDebugScope(getOpScope(ccabi->getDebugScope())); |
391 | | // Build pullback tuple value for original block. |
392 | 8 | auto *pbTupleVal = buildPullbackValueTupleValue(ccabi); |
393 | | // Create a new `checked_cast_addr_branch` instruction. |
394 | 8 | getBuilder().createCheckedCastAddrBranch( |
395 | 8 | ccabi->getLoc(), ccabi->getConsumptionKind(), |
396 | 8 | getOpValue(ccabi->getSrc()), getOpASTType(ccabi->getSourceFormalType()), |
397 | 8 | getOpValue(ccabi->getDest()), |
398 | 8 | getOpASTType(ccabi->getTargetFormalType()), |
399 | 8 | createTrampolineBasicBlock(ccabi, pbTupleVal, ccabi->getSuccessBB()), |
400 | 8 | createTrampolineBasicBlock(ccabi, pbTupleVal, ccabi->getFailureBB()), |
401 | 8 | ccabi->getTrueBBCount(), ccabi->getFalseBBCount()); |
402 | 8 | } |
403 | | |
404 | | // If an `apply` has active results or active inout arguments, replace it |
405 | | // with an `apply` of its VJP. |
406 | 9.62k | void visitApplyInst(ApplyInst *ai) { |
407 | | // If callee should not be differentiated, do standard cloning. |
408 | 9.62k | if (!pullbackInfo.shouldDifferentiateApplySite(ai)) { |
409 | 3.32k | LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); |
410 | 3.32k | TypeSubstCloner::visitApplyInst(ai); |
411 | 3.32k | return; |
412 | 3.32k | } |
413 | | // If callee is `array.uninitialized_intrinsic`, do standard cloning. |
414 | | // `array.uninitialized_intrinsic` differentiation is handled separately. |
415 | 6.30k | if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) { |
416 | 212 | LLVM_DEBUG(getADDebugStream() |
417 | 212 | << "Cloning `array.uninitialized_intrinsic` `apply`:\n" |
418 | 212 | << *ai << '\n'); |
419 | 212 | TypeSubstCloner::visitApplyInst(ai); |
420 | 212 | return; |
421 | 212 | } |
422 | | // If callee is `array.finalize_intrinsic`, do standard cloning. |
423 | | // `array.finalize_intrinsic` has special-case pullback generation. |
424 | 6.09k | if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) { |
425 | 212 | LLVM_DEBUG(getADDebugStream() |
426 | 212 | << "Cloning `array.finalize_intrinsic` `apply`:\n" |
427 | 212 | << *ai << '\n'); |
428 | 212 | TypeSubstCloner::visitApplyInst(ai); |
429 | 212 | return; |
430 | 212 | } |
431 | | // If the original function is a semantic member accessor, do standard |
432 | | // cloning. Semantic member accessors have special pullback generation |
433 | | // logic, so all `apply` instructions can be directly cloned to the VJP. |
434 | 5.88k | if (isSemanticMemberAccessor(original)) { |
435 | 152 | LLVM_DEBUG(getADDebugStream() |
436 | 152 | << "Cloning `apply` in semantic member accessor:\n" |
437 | 152 | << *ai << '\n'); |
438 | 152 | TypeSubstCloner::visitApplyInst(ai); |
439 | 152 | return; |
440 | 152 | } |
441 | | |
442 | 5.73k | Builder.setCurrentDebugScope(getOpScope(ai->getDebugScope())); |
443 | 5.73k | auto loc = ai->getLoc(); |
444 | 5.73k | auto &builder = getBuilder(); |
445 | 5.73k | auto origCallee = getOpValue(ai->getCallee()); |
446 | 5.73k | auto originalFnTy = origCallee->getType().castTo<SILFunctionType>(); |
447 | | |
448 | 5.73k | LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n'); |
449 | | |
450 | | // Get the minimal parameter and result indices required for differentiating |
451 | | // this `apply`. |
452 | 5.73k | SmallVector<SILValue, 4> allResults; |
453 | 5.73k | SmallVector<unsigned, 8> activeParamIndices; |
454 | 5.73k | SmallVector<unsigned, 8> activeResultIndices; |
455 | 5.73k | collectMinimalIndicesForFunctionCall(ai, getConfig(), activityInfo, |
456 | 5.73k | allResults, activeParamIndices, |
457 | 5.73k | activeResultIndices); |
458 | 5.73k | assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); |
459 | 0 | assert(!activeResultIndices.empty() && "Result indices cannot be empty"); |
460 | 5.73k | LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=("; |
461 | 5.73k | llvm::interleave( |
462 | 5.73k | activeParamIndices.begin(), activeParamIndices.end(), |
463 | 5.73k | [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
464 | 5.73k | s << "), results=("; llvm::interleave( |
465 | 5.73k | activeResultIndices.begin(), activeResultIndices.end(), |
466 | 5.73k | [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
467 | 5.73k | s << ")\n";); |
468 | | |
469 | | // Form expected indices. |
470 | 5.73k | AutoDiffConfig config( |
471 | 5.73k | IndexSubset::get(getASTContext(), |
472 | 5.73k | ai->getArgumentsWithoutIndirectResults().size(), |
473 | 5.73k | activeParamIndices), |
474 | 5.73k | IndexSubset::get(getASTContext(), |
475 | 5.73k | ai->getSubstCalleeType()->getNumAutoDiffSemanticResults(), |
476 | 5.73k | activeResultIndices)); |
477 | | |
478 | | // Emit the VJP. |
479 | 5.73k | SILValue vjpValue; |
480 | | // If the original callee is a `@differentiable` function, just extract its VJP. |
481 | 5.73k | if (originalFnTy->isDifferentiable()) { |
482 | 156 | auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices(); |
483 | 204 | for (auto i : config.parameterIndices->getIndices()) { |
484 | 204 | if (!paramIndices->contains(i)) { |
485 | 0 | context.emitNondifferentiabilityError( |
486 | 0 | origCallee, invoker, |
487 | 0 | diag:: |
488 | 0 | autodiff_function_noderivative_parameter_not_differentiable); |
489 | 0 | errorOccurred = true; |
490 | 0 | return; |
491 | 0 | } |
492 | 204 | } |
493 | 156 | builder.emitScopedBorrowOperation( |
494 | 156 | loc, origCallee, [&](SILValue borrowedDiffFunc) { |
495 | 156 | auto origFnType = origCallee->getType().castTo<SILFunctionType>(); |
496 | 156 | auto origFnUnsubstType = |
497 | 156 | origFnType->getUnsubstitutedType(getModule()); |
498 | 156 | if (origFnType != origFnUnsubstType) { |
499 | 20 | borrowedDiffFunc = builder.createConvertFunction( |
500 | 20 | loc, borrowedDiffFunc, |
501 | 20 | SILType::getPrimitiveObjectType(origFnUnsubstType), |
502 | 20 | /*withoutActuallyEscaping*/ false); |
503 | 20 | } |
504 | 156 | vjpValue = builder.createDifferentiableFunctionExtract( |
505 | 156 | loc, NormalDifferentiableFunctionTypeComponent::VJP, |
506 | 156 | borrowedDiffFunc); |
507 | 156 | vjpValue = builder.emitCopyValueOperation(loc, vjpValue); |
508 | 156 | }); |
509 | 156 | auto vjpFnType = vjpValue->getType().castTo<SILFunctionType>(); |
510 | 156 | auto vjpFnUnsubstType = vjpFnType->getUnsubstitutedType(getModule()); |
511 | 156 | if (vjpFnType != vjpFnUnsubstType) { |
512 | 0 | vjpValue = builder.createConvertFunction( |
513 | 0 | loc, vjpValue, SILType::getPrimitiveObjectType(vjpFnUnsubstType), |
514 | 0 | /*withoutActuallyEscaping*/ false); |
515 | 0 | } |
516 | 156 | } |
517 | | |
518 | | // Check and diagnose non-differentiable original function type. |
519 | 5.73k | auto diagnoseNondifferentiableOriginalFunctionType = |
520 | 7.82k | [&](CanSILFunctionType origFnTy) { |
521 | | // Check and diagnose non-differentiable arguments. |
522 | 11.9k | for (auto paramIndex : config.parameterIndices->getIndices()) { |
523 | 11.9k | if (!originalFnTy->getParameters()[paramIndex] |
524 | 11.9k | .getSILStorageInterfaceType() |
525 | 11.9k | .isDifferentiable(getModule())) { |
526 | 8 | auto arg = ai->getArgumentsWithoutIndirectResults()[paramIndex]; |
527 | | // FIXME: This shouldn't be necessary and might indicate a bug in |
528 | | // the transformation. |
529 | 8 | RegularLocation nonAutoGenLoc(arg.getLoc()); |
530 | 8 | nonAutoGenLoc.markNonAutoGenerated(); |
531 | 8 | auto startLoc = nonAutoGenLoc.getStartSourceLoc(); |
532 | 8 | auto endLoc = nonAutoGenLoc.getEndSourceLoc(); |
533 | 8 | context |
534 | 8 | .emitNondifferentiabilityError( |
535 | 8 | arg, invoker, diag::autodiff_nondifferentiable_argument) |
536 | 8 | .fixItInsert(startLoc, "withoutDerivative(at: ") |
537 | 8 | .fixItInsertAfter(endLoc, ")"); |
538 | 8 | errorOccurred = true; |
539 | 8 | return true; |
540 | 8 | } |
541 | 11.9k | } |
542 | | // Check and diagnose non-differentiable results. |
543 | 7.92k | for (auto resultIndex : config.resultIndices->getIndices()) { |
544 | 7.92k | SILType remappedResultType; |
545 | 7.92k | if (resultIndex >= originalFnTy->getNumResults()) { |
546 | 648 | auto semanticResultArgIdx = resultIndex - originalFnTy->getNumResults(); |
547 | 648 | auto semanticResultArg = |
548 | 648 | *std::next(ai->getAutoDiffSemanticResultArguments().begin(), |
549 | 648 | semanticResultArgIdx); |
550 | 648 | remappedResultType = semanticResultArg->getType(); |
551 | 7.27k | } else { |
552 | 7.27k | remappedResultType = originalFnTy->getResults()[resultIndex] |
553 | 7.27k | .getSILStorageInterfaceType(); |
554 | 7.27k | } |
555 | 7.92k | if (!remappedResultType.isDifferentiable(getModule())) { |
556 | 12 | auto startLoc = ai->getLoc().getStartSourceLoc(); |
557 | 12 | auto endLoc = ai->getLoc().getEndSourceLoc(); |
558 | 12 | context |
559 | 12 | .emitNondifferentiabilityError( |
560 | 12 | origCallee, invoker, |
561 | 12 | diag::autodiff_nondifferentiable_result) |
562 | 12 | .fixItInsert(startLoc, "withoutDerivative(at: ") |
563 | 12 | .fixItInsertAfter(endLoc, ")"); |
564 | 12 | errorOccurred = true; |
565 | 12 | return true; |
566 | 12 | } |
567 | 7.92k | } |
568 | 7.80k | return false; |
569 | 7.81k | }; |
570 | 5.73k | if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) |
571 | 12 | return; |
572 | | |
573 | | // If VJP has not yet been found, emit a `differentiable_function` |
574 | | // instruction on the remapped original function operand and |
575 | | // a `differentiable_function_extract` instruction to get the VJP. |
576 | | // The `differentiable_function` instruction will be canonicalized during |
577 | | // the transform main loop. |
578 | 5.72k | if (!vjpValue) { |
579 | | // FIXME: Handle indirect differentiation invokers. This may require some |
580 | | // redesign: currently, each original function + witness pair is mapped |
581 | | // only to one invoker. |
582 | | /* |
583 | | DifferentiationInvoker indirect(ai, attr); |
584 | | auto insertion = |
585 | | context.getInvokers().try_emplace({original, attr}, indirect); |
586 | | auto &invoker = insertion.first->getSecond(); |
587 | | invoker = indirect; |
588 | | */ |
589 | | |
590 | | // If the original `apply` instruction has a substitution map, then the |
591 | | // applied function is specialized. |
592 | | // In the VJP, specialization is also necessary for parity. The original |
593 | | // function operand is specialized with a remapped version of the same |
594 | | // substitution map using an argument-less `partial_apply`. |
595 | 5.56k | if (ai->getSubstitutionMap().empty()) { |
596 | 3.47k | origCallee = builder.emitCopyValueOperation(loc, origCallee); |
597 | 3.47k | } else { |
598 | 2.08k | auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); |
599 | 2.08k | auto vjpPartialApply = getBuilder().createPartialApply( |
600 | 2.08k | ai->getLoc(), origCallee, substMap, {}, |
601 | 2.08k | ParameterConvention::Direct_Guaranteed); |
602 | 2.08k | origCallee = vjpPartialApply; |
603 | 2.08k | originalFnTy = origCallee->getType().castTo<SILFunctionType>(); |
604 | | // Diagnose if new original function type is non-differentiable. |
605 | 2.08k | if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) |
606 | 8 | return; |
607 | 2.08k | } |
608 | | |
609 | 5.55k | auto *diffFuncInst = context.createDifferentiableFunction( |
610 | 5.55k | getBuilder(), loc, config.parameterIndices, config.resultIndices, |
611 | 5.55k | origCallee); |
612 | | |
613 | | // Record the `differentiable_function` instruction. |
614 | 5.55k | context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst); |
615 | | |
616 | 5.55k | builder.emitScopedBorrowOperation( |
617 | 5.55k | loc, diffFuncInst, [&](SILValue borrowedADFunc) { |
618 | 5.55k | auto extractedVJP = |
619 | 5.55k | getBuilder().createDifferentiableFunctionExtract( |
620 | 5.55k | loc, NormalDifferentiableFunctionTypeComponent::VJP, |
621 | 5.55k | borrowedADFunc); |
622 | 5.55k | vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); |
623 | 5.55k | }); |
624 | 5.55k | builder.emitDestroyValueOperation(loc, diffFuncInst); |
625 | 5.55k | } |
626 | | |
627 | | // Record desired/actual VJP indices. |
628 | | // Temporarily set original pullback type to `None`. |
629 | 5.71k | NestedApplyInfo info{config, /*originalPullbackType*/ llvm::None}; |
630 | 5.71k | auto insertion = context.getNestedApplyInfo().try_emplace(ai, info); |
631 | 5.71k | auto &nestedApplyInfo = insertion.first->getSecond(); |
632 | 5.71k | nestedApplyInfo = info; |
633 | | |
634 | | // Call the VJP using the original parameters. |
635 | 5.71k | SmallVector<SILValue, 8> vjpArgs; |
636 | 5.71k | auto vjpFnTy = getOpType(vjpValue->getType()).castTo<SILFunctionType>(); |
637 | 5.71k | auto numVJPArgs = |
638 | 5.71k | vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults(); |
639 | 5.71k | vjpArgs.reserve(numVJPArgs); |
640 | | // Collect substituted arguments. |
641 | 5.71k | for (auto origArg : ai->getArguments()) |
642 | 15.7k | vjpArgs.push_back(getOpValue(origArg)); |
643 | 5.71k | assert(vjpArgs.size() == numVJPArgs); |
644 | | // Apply the VJP. |
645 | | // The VJP should be specialized, so no substitution map is necessary. |
646 | 0 | auto *vjpCall = getBuilder().createApply(loc, vjpValue, SubstitutionMap(), |
647 | 5.71k | vjpArgs, ai->getApplyOptions()); |
648 | 5.71k | LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall); |
649 | 5.71k | builder.emitDestroyValueOperation(loc, vjpValue); |
650 | | |
651 | | // Get the VJP results (original results and pullback). |
652 | 5.71k | SmallVector<SILValue, 8> vjpDirectResults; |
653 | 5.71k | extractAllElements(vjpCall, getBuilder(), vjpDirectResults); |
654 | 5.71k | ArrayRef<SILValue> originalDirectResults = |
655 | 5.71k | ArrayRef<SILValue>(vjpDirectResults).drop_back(1); |
656 | 5.71k | SILValue originalDirectResult = |
657 | 5.71k | joinElements(originalDirectResults, getBuilder(), vjpCall->getLoc()); |
658 | 5.71k | SILValue pullback = vjpDirectResults.back(); |
659 | 5.71k | { |
660 | 5.71k | auto pullbackFnType = pullback->getType().castTo<SILFunctionType>(); |
661 | 5.71k | auto pullbackUnsubstFnType = |
662 | 5.71k | pullbackFnType->getUnsubstitutedType(getModule()); |
663 | 5.71k | if (pullbackFnType != pullbackUnsubstFnType) { |
664 | 508 | pullback = builder.createConvertFunction( |
665 | 508 | loc, pullback, |
666 | 508 | SILType::getPrimitiveObjectType(pullbackUnsubstFnType), |
667 | 508 | /*withoutActuallyEscaping*/ false); |
668 | 508 | } |
669 | 5.71k | } |
670 | | |
671 | | // Store the original result to the value map. |
672 | 5.71k | mapValue(ai, originalDirectResult); |
673 | | |
674 | | // Checkpoint the pullback. |
675 | 5.71k | auto pullbackType = pullbackInfo.lookUpLinearMapType(ai); |
676 | | |
677 | | // If actual pullback type does not match lowered pullback type, reabstract |
678 | | // the pullback using a thunk. |
679 | 5.71k | auto actualPullbackType = |
680 | 5.71k | getOpType(pullback->getType()).getAs<SILFunctionType>(); |
681 | 5.71k | auto loweredPullbackType = |
682 | 5.71k | getOpType(getLoweredType(pullbackType)).castTo<SILFunctionType>(); |
683 | 5.71k | if (!loweredPullbackType->isEqual(actualPullbackType)) { |
684 | | // Set non-reabstracted original pullback type in nested apply info. |
685 | 1.46k | nestedApplyInfo.originalPullbackType = actualPullbackType; |
686 | 1.46k | SILOptFunctionBuilder fb(context.getTransform()); |
687 | 1.46k | pullback = reabstractFunction( |
688 | 1.46k | getBuilder(), fb, ai->getLoc(), pullback, loweredPullbackType, |
689 | 1.46k | [this](SubstitutionMap subs) -> SubstitutionMap { |
690 | 1.46k | return this->getOpSubstitutionMap(subs); |
691 | 1.46k | }); |
692 | 1.46k | } |
693 | 5.71k | pullbackValues[ai->getParent()].push_back(pullback); |
694 | | |
695 | | // Some instructions that produce the callee may have been cloned. |
696 | | // If the original callee did not have any users beyond this `apply`, |
697 | | // recursively kill the cloned callee. |
698 | 5.71k | if (auto *origCallee = cast_or_null<SingleValueInstruction>( |
699 | 5.71k | ai->getCallee()->getDefiningInstruction())) |
700 | 5.62k | if (origCallee->hasOneUse()) |
701 | 5.55k | recursivelyDeleteTriviallyDeadInstructions( |
702 | 5.55k | getOpValue(origCallee)->getDefiningInstruction()); |
703 | 5.71k | } |
704 | | |
705 | 36 | void visitTryApplyInst(TryApplyInst *tai) { |
706 | 36 | Builder.setCurrentDebugScope(getOpScope(tai->getDebugScope())); |
707 | | // Build pullback struct value for original block. |
708 | 36 | auto *pbTupleVal = buildPullbackValueTupleValue(tai); |
709 | | // Create a new `try_apply` instruction. |
710 | 36 | auto args = getOpValueArray<8>(tai->getArguments()); |
711 | 36 | getBuilder().createTryApply( |
712 | 36 | tai->getLoc(), getOpValue(tai->getCallee()), |
713 | 36 | getOpSubstitutionMap(tai->getSubstitutionMap()), args, |
714 | 36 | createTrampolineBasicBlock(tai, pbTupleVal, tai->getNormalBB()), |
715 | 36 | createTrampolineBasicBlock(tai, pbTupleVal, tai->getErrorBB()), |
716 | 36 | tai->getApplyOptions()); |
717 | 36 | } |
718 | | |
719 | 96 | void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { |
720 | | // Clone `differentiable_function` from original to VJP, then add the cloned |
721 | | // instruction to the `differentiable_function` worklist. |
722 | 96 | TypeSubstCloner::visitDifferentiableFunctionInst(dfi); |
723 | 96 | auto *newDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi)); |
724 | 96 | context.getDifferentiableFunctionInstWorklist().push_back(newDFI); |
725 | 96 | } |
726 | | |
727 | 0 | void visitLinearFunctionInst(LinearFunctionInst *lfi) { |
728 | | // Clone `linear_function` from original to VJP, then add the cloned |
729 | | // instruction to the `linear_function` worklist. |
730 | 0 | TypeSubstCloner::visitLinearFunctionInst(lfi); |
731 | 0 | auto *newLFI = cast<LinearFunctionInst>(getOpValue(lfi)); |
732 | 0 | context.getLinearFunctionInstWorklist().push_back(newLFI); |
733 | 0 | } |
734 | | }; |
735 | | |
736 | | /// Initialization helper function. |
737 | | /// |
738 | | /// Returns the substitution map used for type remapping. |
739 | | static SubstitutionMap getSubstitutionMap(SILFunction *original, |
740 | 5.25k | SILFunction *vjp) { |
741 | 5.25k | auto substMap = original->getForwardingSubstitutionMap(); |
742 | 5.25k | if (auto *vjpGenEnv = vjp->getGenericEnvironment()) { |
743 | 940 | auto vjpSubstMap = vjpGenEnv->getForwardingSubstitutionMap(); |
744 | 940 | substMap = SubstitutionMap::get( |
745 | 940 | vjpGenEnv->getGenericSignature(), QuerySubstitutionMap{vjpSubstMap}, |
746 | 940 | LookUpConformanceInSubstitutionMap(vjpSubstMap)); |
747 | 940 | } |
748 | 5.25k | return substMap; |
749 | 5.25k | } |
750 | | |
751 | | /// Initialization helper function. |
752 | | /// |
753 | | /// Returns the activity info for the given original function, autodiff indices, |
754 | | /// and VJP generic signature. |
755 | | static const DifferentiableActivityInfo & |
756 | | getActivityInfoHelper(ADContext &context, SILFunction *original, |
757 | 5.25k | const AutoDiffConfig &config, SILFunction *vjp) { |
758 | | // Get activity info of the original function. |
759 | 5.25k | auto &passManager = context.getPassManager(); |
760 | 5.25k | auto *activityAnalysis = |
761 | 5.25k | passManager.getAnalysis<DifferentiableActivityAnalysis>(); |
762 | 5.25k | auto &activityCollection = *activityAnalysis->get(original); |
763 | 5.25k | auto &activityInfo = activityCollection.getActivityInfo( |
764 | 5.25k | vjp->getLoweredFunctionType()->getSubstGenericSignature(), |
765 | 5.25k | AutoDiffDerivativeFunctionKind::VJP); |
766 | 5.25k | LLVM_DEBUG(activityInfo.dump(config, getADDebugStream())); |
767 | 5.25k | return activityInfo; |
768 | 5.25k | } |
769 | | |
770 | | VJPCloner::Implementation::Implementation(VJPCloner &cloner, ADContext &context, |
771 | | SILDifferentiabilityWitness *witness, |
772 | | SILFunction *vjp, |
773 | | DifferentiationInvoker invoker) |
774 | | : TypeSubstCloner(*vjp, *witness->getOriginalFunction(), |
775 | | getSubstitutionMap(witness->getOriginalFunction(), vjp)), |
776 | | cloner(cloner), context(context), |
777 | | original(witness->getOriginalFunction()), witness(witness), |
778 | | vjp(vjp), invoker(invoker), |
779 | | activityInfo(getActivityInfoHelper( |
780 | | context, original, witness->getConfig(), vjp)), |
781 | | loopInfo(context.getPassManager().getAnalysis<SILLoopAnalysis>() |
782 | | ->get(original)), |
783 | | pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp, |
784 | 5.25k | witness->getConfig(), activityInfo, loopInfo) { |
785 | | // Create empty pullback function. |
786 | 5.25k | pullback = createEmptyPullback(); |
787 | 5.25k | context.recordGeneratedFunction(pullback); |
788 | 5.25k | } |
789 | | |
790 | | VJPCloner::VJPCloner(ADContext &context, |
791 | | SILDifferentiabilityWitness *witness, SILFunction *vjp, |
792 | | DifferentiationInvoker invoker) |
793 | 5.25k | : impl(*new Implementation(*this, context, witness, vjp, invoker)) {} |
794 | | |
795 | 5.25k | VJPCloner::~VJPCloner() { delete &impl; } |
796 | | |
797 | 185k | ADContext &VJPCloner::getContext() const { return impl.context; } |
798 | 0 | SILModule &VJPCloner::getModule() const { return impl.getModule(); } |
799 | 180k | SILFunction &VJPCloner::getOriginal() const { return *impl.original; } |
800 | 5.10k | SILFunction &VJPCloner::getVJP() const { return *impl.vjp; } |
801 | 615k | SILFunction &VJPCloner::getPullback() const { return *impl.pullback; } |
802 | 144k | SILDifferentiabilityWitness *VJPCloner::getWitness() const { |
803 | 144k | return impl.witness; |
804 | 144k | } |
805 | 173k | const AutoDiffConfig &VJPCloner::getConfig() const { |
806 | 173k | return impl.getConfig(); |
807 | 173k | } |
808 | 3.81k | DifferentiationInvoker VJPCloner::getInvoker() const { return impl.invoker; } |
809 | 134k | LinearMapInfo &VJPCloner::getPullbackInfo() const { return impl.pullbackInfo; } |
810 | 2.35k | SILLoopInfo *VJPCloner::getLoopInfo() const { return impl.loopInfo; } |
811 | 152k | const DifferentiableActivityInfo &VJPCloner::getActivityInfo() const { |
812 | 152k | return impl.activityInfo; |
813 | 152k | } |
814 | | |
815 | 5.25k | SILFunction *VJPCloner::Implementation::createEmptyPullback() { |
816 | 5.25k | auto &module = context.getModule(); |
817 | 5.25k | auto origTy = original->getLoweredFunctionType(); |
818 | | // Get witness generic signature for remapping types. |
819 | | // Witness generic signature may have more requirements than VJP generic |
820 | | // signature: when witness generic signature has same-type requirements |
821 | | // binding all generic parameters to concrete types, VJP function type uses |
822 | | // all the concrete types and VJP generic signature is null. |
823 | 5.25k | auto witnessCanGenSig = witness->getDerivativeGenericSignature().getCanonicalSignature(); |
824 | 5.25k | auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule()); |
825 | | |
826 | | // Given a type, returns its formal SIL parameter info. |
827 | 5.25k | auto getTangentParameterInfoForOriginalResult = |
828 | 5.25k | [&](CanType tanType, ResultConvention origResConv) -> SILParameterInfo { |
829 | 4.98k | tanType = tanType->getReducedType(witnessCanGenSig); |
830 | 4.98k | Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType); |
831 | 4.98k | auto &tl = context.getTypeConverter().getTypeLowering( |
832 | 4.98k | pattern, tanType, TypeExpansionContext::minimal()); |
833 | 4.98k | ParameterConvention conv; |
834 | 4.98k | switch (origResConv) { |
835 | 3.28k | case ResultConvention::Unowned: |
836 | 3.28k | case ResultConvention::UnownedInnerPointer: |
837 | 3.76k | case ResultConvention::Owned: |
838 | 3.76k | case ResultConvention::Autoreleased: |
839 | 3.76k | if (tl.isAddressOnly()) { |
840 | 92 | conv = ParameterConvention::Indirect_In_Guaranteed; |
841 | 3.66k | } else { |
842 | 3.66k | conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned |
843 | 3.66k | : ParameterConvention::Direct_Guaranteed; |
844 | 3.66k | } |
845 | 3.76k | break; |
846 | 1.22k | case ResultConvention::Indirect: |
847 | 1.22k | conv = ParameterConvention::Indirect_In_Guaranteed; |
848 | 1.22k | break; |
849 | 0 | case ResultConvention::Pack: |
850 | 0 | conv = ParameterConvention::Pack_Guaranteed; |
851 | 0 | break; |
852 | 4.98k | } |
853 | 4.98k | return {tanType, conv}; |
854 | 4.98k | }; |
855 | | |
856 | | // Given a type, returns its formal SIL result info. |
857 | 5.25k | auto getTangentResultInfoForOriginalParameter = |
858 | 6.93k | [&](CanType tanType, ParameterConvention origParamConv) -> SILResultInfo { |
859 | 6.93k | tanType = tanType->getReducedType(witnessCanGenSig); |
860 | 6.93k | Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType); |
861 | 6.93k | auto &tl = context.getTypeConverter().getTypeLowering( |
862 | 6.93k | pattern, tanType, TypeExpansionContext::minimal()); |
863 | 6.93k | ResultConvention conv; |
864 | 6.93k | switch (origParamConv) { |
865 | 48 | case ParameterConvention::Direct_Owned: |
866 | 620 | case ParameterConvention::Direct_Guaranteed: |
867 | 5.20k | case ParameterConvention::Direct_Unowned: |
868 | 5.20k | if (tl.isAddressOnly()) { |
869 | 112 | conv = ResultConvention::Indirect; |
870 | 5.09k | } else { |
871 | 5.09k | conv = tl.isTrivial() ? ResultConvention::Unowned |
872 | 5.09k | : ResultConvention::Owned; |
873 | 5.09k | } |
874 | 5.20k | break; |
875 | 204 | case ParameterConvention::Indirect_In: |
876 | 204 | case ParameterConvention::Indirect_Inout: |
877 | 1.73k | case ParameterConvention::Indirect_In_Guaranteed: |
878 | 1.73k | case ParameterConvention::Indirect_InoutAliasable: |
879 | 1.73k | conv = ResultConvention::Indirect; |
880 | 1.73k | break; |
881 | 0 | case ParameterConvention::Pack_Guaranteed: |
882 | 0 | case ParameterConvention::Pack_Owned: |
883 | 0 | case ParameterConvention::Pack_Inout: |
884 | 0 | conv = ResultConvention::Pack; |
885 | 0 | break; |
886 | 6.93k | } |
887 | 6.93k | return {tanType, conv}; |
888 | 6.93k | }; |
889 | | |
890 | | // Parameters of the pullback are: |
891 | | // - the tangent vectors of the original results, and |
892 | | // - a pullback struct. |
893 | | // Results of the pullback are in the tangent space of the original |
894 | | // parameters. |
895 | 5.25k | SmallVector<SILParameterInfo, 8> pbParams; |
896 | 5.25k | SmallVector<SILResultInfo, 8> adjResults; |
897 | 5.25k | auto origParams = origTy->getParameters(); |
898 | 5.25k | auto config = witness->getConfig(); |
899 | | |
900 | | // Add pullback parameters based on original result indices. |
901 | 5.25k | SmallVector<unsigned, 4> semanticResultParamIndices; |
902 | 8.56k | for (auto i : range(origTy->getNumParameters())) { |
903 | 8.56k | auto origParam = origParams[i]; |
904 | 8.56k | if (!origParam.isAutoDiffSemanticResult()) |
905 | 8.18k | continue; |
906 | 388 | semanticResultParamIndices.push_back(i); |
907 | 388 | } |
908 | | |
909 | 5.36k | for (auto resultIndex : config.resultIndices->getIndices()) { |
910 | | // Handle formal result. |
911 | 5.36k | if (resultIndex < origTy->getNumResults()) { |
912 | 4.98k | auto origResult = origTy->getResults()[resultIndex]; |
913 | 4.98k | origResult = origResult.getWithInterfaceType( |
914 | 4.98k | origResult.getInterfaceType()->getReducedType(witnessCanGenSig)); |
915 | 4.98k | auto paramInfo = getTangentParameterInfoForOriginalResult( |
916 | 4.98k | origResult.getInterfaceType() |
917 | 4.98k | ->getAutoDiffTangentSpace(lookupConformance) |
918 | 4.98k | ->getType() |
919 | 4.98k | ->getReducedType(witnessCanGenSig), |
920 | 4.98k | origResult.getConvention()); |
921 | 4.98k | pbParams.push_back(paramInfo); |
922 | 4.98k | continue; |
923 | 4.98k | } |
924 | | |
925 | | // Handle semantic result parameter. |
926 | 384 | unsigned paramIndex = 0; |
927 | 384 | unsigned resultParamIndex = 0; |
928 | 592 | for (auto i : range(origTy->getNumParameters())) { |
929 | 592 | auto origParam = origTy->getParameters()[i]; |
930 | 592 | if (!origParam.isAutoDiffSemanticResult()) { |
931 | 168 | ++paramIndex; |
932 | 168 | continue; |
933 | 168 | } |
934 | 424 | if (resultParamIndex == resultIndex - origTy->getNumResults()) |
935 | 384 | break; |
936 | 40 | ++paramIndex; |
937 | 40 | ++resultParamIndex; |
938 | 40 | } |
939 | 384 | auto resultParam = origParams[paramIndex]; |
940 | 384 | auto origResult = resultParam.getWithInterfaceType( |
941 | 384 | resultParam.getInterfaceType()->getReducedType(witnessCanGenSig)); |
942 | | |
943 | 384 | auto resultParamTanConvention = resultParam.getConvention(); |
944 | 384 | if (!config.isWrtParameter(paramIndex)) |
945 | 0 | resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed; |
946 | | |
947 | 384 | pbParams.emplace_back(origResult.getInterfaceType() |
948 | 384 | ->getAutoDiffTangentSpace(lookupConformance) |
949 | 384 | ->getType() |
950 | 384 | ->getReducedType(witnessCanGenSig), |
951 | 384 | resultParamTanConvention); |
952 | 384 | } |
953 | | |
954 | 5.25k | if (pullbackInfo.hasHeapAllocatedContext()) { |
955 | | // Accept a `AutoDiffLinarMapContext` heap object if there are loops. |
956 | 108 | pbParams.push_back({ |
957 | 108 | getASTContext().TheNativeObjectType, |
958 | 108 | ParameterConvention::Direct_Guaranteed |
959 | 108 | }); |
960 | 5.14k | } else { |
961 | | // Accept a pullback struct in the pullback parameter list. This is the |
962 | | // returned pullback's closure context. |
963 | 5.14k | auto *origExit = &*original->findReturnBB(); |
964 | 5.14k | auto pbTupleType = |
965 | 5.14k | pullbackInfo.getLinearMapTupleLoweredType(origExit).getAs<TupleType>(); |
966 | 5.14k | for (Type eltTy : pbTupleType->getElementTypes()) |
967 | 5.48k | pbParams.emplace_back(CanType(eltTy), ParameterConvention::Direct_Owned); |
968 | 5.14k | } |
969 | | |
970 | | // Add pullback results for the requested wrt parameters. |
971 | 7.32k | for (auto i : config.parameterIndices->getIndices()) { |
972 | 7.32k | auto origParam = origParams[i]; |
973 | 7.32k | if (origParam.isAutoDiffSemanticResult()) |
974 | 384 | continue; |
975 | 6.93k | origParam = origParam.getWithInterfaceType( |
976 | 6.93k | origParam.getInterfaceType()->getReducedType(witnessCanGenSig)); |
977 | 6.93k | adjResults.push_back(getTangentResultInfoForOriginalParameter( |
978 | 6.93k | origParam.getInterfaceType() |
979 | 6.93k | ->getAutoDiffTangentSpace(lookupConformance) |
980 | 6.93k | ->getType() |
981 | 6.93k | ->getReducedType(witnessCanGenSig), |
982 | 6.93k | origParam.getConvention())); |
983 | 6.93k | } |
984 | | |
985 | 5.25k | Mangle::DifferentiationMangler mangler; |
986 | 5.25k | auto pbName = mangler.mangleLinearMap( |
987 | 5.25k | original->getName(), AutoDiffLinearMapKind::Pullback, config); |
988 | | // Set pullback generic signature equal to VJP generic signature. |
989 | | // Do not use witness generic signature, which may have same-type requirements |
990 | | // binding all generic parameters to concrete types. |
991 | 5.25k | auto pbGenericSig = vjp->getLoweredFunctionType()->getSubstGenericSignature(); |
992 | 5.25k | auto *pbGenericEnv = pbGenericSig.getGenericEnvironment(); |
993 | 5.25k | auto pbType = SILFunctionType::get( |
994 | 5.25k | pbGenericSig, SILExtInfo::getThin(), origTy->getCoroutineKind(), |
995 | 5.25k | origTy->getCalleeConvention(), pbParams, {}, adjResults, llvm::None, |
996 | 5.25k | origTy->getPatternSubstitutions(), origTy->getInvocationSubstitutions(), |
997 | 5.25k | original->getASTContext()); |
998 | | |
999 | 5.25k | SILOptFunctionBuilder fb(context.getTransform()); |
1000 | 5.25k | auto linkage = vjp->isSerialized() ? SILLinkage::Public : SILLinkage::Private; |
1001 | 5.25k | auto *pullback = fb.createFunction( |
1002 | 5.25k | linkage, context.getASTContext().getIdentifier(pbName).str(), pbType, |
1003 | 5.25k | pbGenericEnv, original->getLocation(), original->isBare(), |
1004 | 5.25k | IsNotTransparent, vjp->isSerialized(), |
1005 | 5.25k | original->isDynamicallyReplaceable(), original->isDistributed(), |
1006 | 5.25k | original->isRuntimeAccessible()); |
1007 | 5.25k | pullback->setDebugScope(new (module) |
1008 | 5.25k | SILDebugScope(original->getLocation(), pullback)); |
1009 | | |
1010 | 5.25k | return pullback; |
1011 | 5.25k | } |
1012 | | |
1013 | | SILBasicBlock *VJPCloner::Implementation::createTrampolineBasicBlock( |
1014 | 1.44k | TermInst *termInst, TupleInst *pbTupleVal, SILBasicBlock *succBB) { |
1015 | 1.44k | assert(llvm::find(termInst->getSuccessorBlocks(), succBB) != |
1016 | 1.44k | termInst->getSuccessorBlocks().end() && |
1017 | 1.44k | "Basic block is not a successor of terminator instruction"); |
1018 | | // Create the trampoline block. |
1019 | 0 | auto *vjpSuccBB = getOpBasicBlock(succBB); |
1020 | 1.44k | auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); |
1021 | 1.44k | for (auto *arg : vjpSuccBB->getArguments().drop_back()) |
1022 | 452 | trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind()); |
1023 | | // In the trampoline block, build predecessor enum value for VJP successor |
1024 | | // block and branch to it. |
1025 | 1.44k | SILBuilder trampolineBuilder(trampolineBB); |
1026 | 1.44k | trampolineBuilder.setCurrentDebugScope(getOpScope(termInst->getDebugScope())); |
1027 | 1.44k | auto *origBB = termInst->getParent(); |
1028 | 1.44k | auto *succEnumVal = |
1029 | 1.44k | buildPredecessorEnumValue(trampolineBuilder, origBB, succBB, pbTupleVal); |
1030 | 1.44k | SmallVector<SILValue, 4> forwardedArguments( |
1031 | 1.44k | trampolineBB->getArguments().begin(), trampolineBB->getArguments().end()); |
1032 | 1.44k | forwardedArguments.push_back(succEnumVal); |
1033 | 1.44k | trampolineBuilder.createBranch(termInst->getLoc(), vjpSuccBB, |
1034 | 1.44k | forwardedArguments); |
1035 | 1.44k | return trampolineBB; |
1036 | 1.44k | } |
1037 | | |
1038 | | llvm::SmallVector<SILValue, 8> |
1039 | 7.15k | VJPCloner::Implementation::getPullbackValues(SILBasicBlock *origBB) { |
1040 | 7.15k | auto *vjpBB = BBMap[origBB]; |
1041 | 7.15k | auto bbPullbackValues = pullbackValues[origBB]; |
1042 | 7.15k | if (!origBB->isEntry()) { |
1043 | 1.91k | auto *predEnumArg = vjpBB->getArguments().back(); |
1044 | 1.91k | bbPullbackValues.insert(bbPullbackValues.begin(), predEnumArg); |
1045 | 1.91k | } |
1046 | | |
1047 | 7.15k | return bbPullbackValues; |
1048 | 7.15k | } |
1049 | | |
1050 | | TupleInst * |
1051 | 2.02k | VJPCloner::Implementation::buildPullbackValueTupleValue(TermInst *termInst) { |
1052 | 2.02k | assert(termInst->getFunction() == original); |
1053 | 0 | auto loc = RegularLocation::getAutoGeneratedLocation(); |
1054 | 2.02k | auto origBB = termInst->getParent(); |
1055 | 2.02k | auto tupleLoweredTy = |
1056 | 2.02k | remapType(pullbackInfo.getLinearMapTupleLoweredType(origBB)); |
1057 | 2.02k | auto bbPullbackValues = getPullbackValues(origBB); |
1058 | 2.02k | return getBuilder().createTuple(loc, tupleLoweredTy, bbPullbackValues); |
1059 | 2.02k | } |
1060 | | |
1061 | | EnumInst *VJPCloner::Implementation::buildPredecessorEnumValue( |
1062 | | SILBuilder &builder, SILBasicBlock *predBB, SILBasicBlock *succBB, |
1063 | 2.63k | SILValue pbTupleVal) { |
1064 | 2.63k | auto loc = RegularLocation::getAutoGeneratedLocation(); |
1065 | 2.63k | auto enumLoweredTy = |
1066 | 2.63k | remapType(pullbackInfo.getBranchingTraceEnumLoweredType(succBB)); |
1067 | 2.63k | auto *enumEltDecl = |
1068 | 2.63k | pullbackInfo.lookUpBranchingTraceEnumElement(predBB, succBB); |
1069 | 2.63k | auto enumEltType = getOpType(enumLoweredTy.getEnumElementType( |
1070 | 2.63k | enumEltDecl, getModule(), TypeExpansionContext::minimal())); |
1071 | | // If the predecessor block is in a loop, its predecessor enum payload is a |
1072 | | // `Builtin.RawPointer`. |
1073 | 2.63k | if (loopInfo->getLoopFor(predBB)) { |
1074 | 396 | auto rawPtrType = SILType::getRawPointerType(getASTContext()); |
1075 | 396 | assert(enumEltType == rawPtrType); |
1076 | 0 | auto pbTupleType = |
1077 | 396 | remapASTType(pullbackInfo.getLinearMapTupleType(predBB)->getCanonicalType()); |
1078 | | |
1079 | 396 | auto pbTupleMetatypeType = |
1080 | 396 | CanMetatypeType::get(pbTupleType, MetatypeRepresentation::Thick); |
1081 | 396 | auto pbTupleMetatypeSILType = |
1082 | 396 | SILType::getPrimitiveObjectType(pbTupleMetatypeType); |
1083 | 396 | auto pbTupleMetatype = |
1084 | 396 | Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType); |
1085 | | |
1086 | 396 | auto rawBufferValue = builder.createBuiltin( |
1087 | 396 | loc, |
1088 | 396 | getASTContext().getIdentifier(getBuiltinName( |
1089 | 396 | BuiltinValueKind::AutoDiffAllocateSubcontextWithType)), |
1090 | 396 | rawPtrType, SubstitutionMap(), |
1091 | 396 | {borrowedPullbackContextValue, pbTupleMetatype}); |
1092 | | |
1093 | 396 | auto typedBufferValue = |
1094 | 396 | builder.createPointerToAddress( |
1095 | 396 | loc, rawBufferValue, pbTupleVal->getType().getAddressType(), |
1096 | 396 | /*isStrict*/ true); |
1097 | 396 | builder.createStore( |
1098 | 396 | loc, pbTupleVal, typedBufferValue, |
1099 | 396 | pbTupleVal->getType().isTrivial(*pullback) ? |
1100 | 296 | StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init); |
1101 | 396 | return builder.createEnum(loc, rawBufferValue, enumEltDecl, enumLoweredTy); |
1102 | 396 | } |
1103 | 2.23k | return builder.createEnum(loc, pbTupleVal, enumEltDecl, enumLoweredTy); |
1104 | 2.63k | } |
1105 | | |
1106 | 5.25k | bool VJPCloner::Implementation::run() { |
1107 | 5.25k | PrettyStackTraceSILFunction trace("generating VJP for", original); |
1108 | 5.25k | LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName() |
1109 | 5.25k | << " to vjp @" << vjp->getName() << '\n'); |
1110 | | |
1111 | | // Create entry BB and arguments. |
1112 | 5.25k | auto *entry = vjp->createBasicBlock(); |
1113 | 5.25k | createEntryArguments(vjp); |
1114 | | |
1115 | 5.25k | emitLinearMapContextInitializationIfNeeded(); |
1116 | | |
1117 | | // Clone. |
1118 | 5.25k | SmallVector<SILValue, 4> entryArgs(entry->getArguments().begin(), |
1119 | 5.25k | entry->getArguments().end()); |
1120 | 5.25k | cloneFunctionBody(original, entry, entryArgs); |
1121 | | // If errors occurred, back out. |
1122 | 5.25k | if (errorOccurred) |
1123 | 20 | return true; |
1124 | | |
1125 | | // Merge VJP basic blocks. This is significant for control flow |
1126 | | // differentiation: trampoline destination bbs are merged into trampoline bbs. |
1127 | | // NOTE(TF-990): Merging basic blocks ensures that `@guaranteed` trampoline |
1128 | | // bb arguments have a lifetime-ending `end_borrow` use, and is robust when |
1129 | | // `-enable-strip-ownership-after-serialization` is true. |
1130 | 5.23k | mergeBasicBlocks(vjp); |
1131 | | |
1132 | 5.23k | LLVM_DEBUG(getADDebugStream() |
1133 | 5.23k | << "Generated VJP for " << original->getName() << ":\n" |
1134 | 5.23k | << *vjp); |
1135 | | |
1136 | | // Generate pullback code. |
1137 | 5.23k | PullbackCloner PullbackCloner(cloner); |
1138 | 5.23k | if (PullbackCloner.run()) { |
1139 | 132 | errorOccurred = true; |
1140 | 132 | return true; |
1141 | 132 | } |
1142 | 5.10k | return errorOccurred; |
1143 | 5.23k | } |
1144 | | |
1145 | 5.25k | bool VJPCloner::run() { |
1146 | 5.25k | bool foundError = impl.run(); |
1147 | 5.25k | #ifndef NDEBUG |
1148 | 5.25k | if (!foundError) |
1149 | 5.10k | getVJP().verify(); |
1150 | 5.25k | #endif |
1151 | 5.25k | return foundError; |
1152 | 5.25k | } |
1153 | | |
1154 | | } // end namespace autodiff |
1155 | | } // end namespace swift |