/Volumes/compiler/apple/swift/lib/SILOptimizer/Mandatory/Differentiation.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- Differentiation.cpp - SIL Automatic Differentiation --*- C++ -*---===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2018 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // This file implements automatic differentiation. |
14 | | // |
15 | | //===----------------------------------------------------------------------===// |
16 | | |
17 | | #define DEBUG_TYPE "differentiation" |
18 | | |
19 | | #include "swift/AST/ASTMangler.h" |
20 | | #include "swift/AST/ASTPrinter.h" |
21 | | #include "swift/AST/AnyFunctionRef.h" |
22 | | #include "swift/AST/AutoDiff.h" |
23 | | #include "swift/AST/Builtins.h" |
24 | | #include "swift/AST/DeclContext.h" |
25 | | #include "swift/AST/DiagnosticsSIL.h" |
26 | | #include "swift/AST/Expr.h" |
27 | | #include "swift/AST/GenericEnvironment.h" |
28 | | #include "swift/AST/LazyResolver.h" |
29 | | #include "swift/AST/ParameterList.h" |
30 | | #include "swift/AST/SourceFile.h" |
31 | | #include "swift/AST/SubstitutionMap.h" |
32 | | #include "swift/AST/TypeCheckRequests.h" |
33 | | #include "swift/SIL/FormalLinkage.h" |
34 | | #include "swift/SIL/PrettyStackTrace.h" |
35 | | #include "swift/SIL/SILBuilder.h" |
36 | | #include "swift/SIL/TypeSubstCloner.h" |
37 | | #include "swift/SILOptimizer/Analysis/DominanceAnalysis.h" |
38 | | #include "swift/SILOptimizer/Differentiation/ADContext.h" |
39 | | #include "swift/SILOptimizer/Differentiation/JVPCloner.h" |
40 | | #include "swift/SILOptimizer/Differentiation/Thunk.h" |
41 | | #include "swift/SILOptimizer/Differentiation/VJPCloner.h" |
42 | | #include "swift/SILOptimizer/PassManager/Passes.h" |
43 | | #include "swift/SILOptimizer/PassManager/Transforms.h" |
44 | | #include "swift/SILOptimizer/Utils/DifferentiationMangler.h" |
45 | | #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" |
46 | | #include "llvm/ADT/APSInt.h" |
47 | | #include "llvm/ADT/BreadthFirstIterator.h" |
48 | | #include "llvm/ADT/DenseSet.h" |
49 | | #include "llvm/ADT/SmallSet.h" |
50 | | #include "llvm/Support/CommandLine.h" |
51 | | |
52 | | using namespace swift; |
53 | | using namespace swift::autodiff; |
54 | | using llvm::DenseMap; |
55 | | using llvm::SmallDenseMap; |
56 | | using llvm::SmallDenseSet; |
57 | | using llvm::SmallMapVector; |
58 | | using llvm::SmallSet; |
59 | | |
60 | | /// This flag enables experimental `@differentiable(_linear)` function |
61 | | /// transposition. |
62 | | static llvm::cl::opt<bool> EnableExperimentalLinearMapTransposition( |
63 | | "enable-experimental-linear-map-transposition", llvm::cl::init(false)); |
64 | | |
65 | | //===----------------------------------------------------------------------===// |
66 | | // Helpers |
67 | | //===----------------------------------------------------------------------===// |
68 | | |
69 | | /// Given a dumpable value, dumps it to `llvm::dbgs()`. |
70 | 24.3k | template <typename T> static inline void debugDump(T &v) { |
71 | 24.3k | LLVM_DEBUG(llvm::dbgs() << "\n==== BEGIN DEBUG DUMP ====\n" |
72 | 24.3k | << v << "\n==== END DEBUG DUMP ====\n"); |
73 | 24.3k | } |
74 | | |
75 | | namespace { |
76 | | |
77 | | class DifferentiationTransformer { |
78 | | private: |
79 | | /// Reference to the main transform. |
80 | | SILModuleTransform &transform; |
81 | | |
82 | | /// Context necessary for performing the transformations. |
83 | | ADContext context; |
84 | | |
85 | | /// Promotes the given `differentiable_function` instruction to a valid |
86 | | /// `@differentiable` function-typed value. |
87 | | SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst, |
88 | | SILBuilder &builder, SILLocation loc, |
89 | | DifferentiationInvoker invoker); |
90 | | |
91 | | /// Given a `linear_function` instruction that is missing a transpose operand, |
92 | | /// return a new `linear_function` instruction with the transpose filled in. |
93 | | SILValue promoteToLinearFunction(LinearFunctionInst *inst, |
94 | | SILBuilder &builder, SILLocation loc, |
95 | | DifferentiationInvoker invoker); |
96 | | |
97 | | public: |
98 | | /// Construct an `DifferentiationTransformer` for the given module. |
99 | | explicit DifferentiationTransformer(SILModuleTransform &transform) |
100 | 24.3k | : transform(transform), context(transform) {} |
101 | | |
102 | 4 | SILModuleTransform &getTransform() { return transform; } |
103 | | |
104 | 59.2k | ADContext &getContext() { return context; } |
105 | | |
106 | | /// Canonicalize the given witness, filling in derivative functions if |
107 | | /// missing. |
108 | | /// |
109 | | /// Generated derivative functions have the same linkage as the witness. |
110 | | /// |
111 | | /// \param serializeFunctions specifies whether generated functions should be |
112 | | /// serialized. |
113 | | bool canonicalizeDifferentiabilityWitness( |
114 | | SILDifferentiabilityWitness *witness, DifferentiationInvoker invoker, |
115 | | IsSerialized_t serializeFunctions); |
116 | | |
117 | | /// Process the given `differentiable_function` instruction, filling in |
118 | | /// missing derivative functions if necessary. |
119 | | bool processDifferentiableFunctionInst(DifferentiableFunctionInst *dfi); |
120 | | |
121 | | /// Process the given `linear_function` instruction, filling in the missing |
122 | | /// transpose function if necessary. |
123 | | bool processLinearFunctionInst(LinearFunctionInst *lfi); |
124 | | }; |
125 | | |
126 | | } // end anonymous namespace |
127 | | |
128 | | /// If the original function doesn't have a return, it cannot be differentiated. |
129 | | /// Returns true if error is emitted. |
130 | | static bool diagnoseNoReturn(ADContext &context, SILFunction *original, |
131 | 6.62k | DifferentiationInvoker invoker) { |
132 | 6.62k | if (original->findReturnBB() != original->end()) |
133 | 6.62k | return false; |
134 | 4 | context.emitNondifferentiabilityError( |
135 | 4 | original->getLocation().getEndSourceLoc(), invoker, |
136 | 4 | diag::autodiff_missing_return); |
137 | 4 | return true; |
138 | 6.62k | } |
139 | | |
140 | | /// If the original function contains unsupported control flow, emit a "control |
141 | | /// flow unsupported" error at appropriate source locations. Returns true if |
142 | | /// error is emitted. |
143 | | /// |
144 | | /// Update as control flow support is added. |
145 | | static bool diagnoseUnsupportedControlFlow(ADContext &context, |
146 | | SILFunction *original, |
147 | 6.62k | DifferentiationInvoker invoker) { |
148 | 6.62k | if (original->size() <= 1) |
149 | 6.10k | return false; |
150 | | // Diagnose unsupported branching terminators. |
151 | 2.54k | for (auto &bb : *original) { |
152 | 2.54k | auto *term = bb.getTerminator(); |
153 | | // Check supported branching terminators. |
154 | 2.54k | if (isa<BranchInst>(term) || isa<CondBranchInst>(term) || |
155 | 2.54k | isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term) || |
156 | 2.54k | isa<CheckedCastBranchInst>(term) || |
157 | 2.54k | isa<CheckedCastAddrBranchInst>(term) || isa<TryApplyInst>(term)) |
158 | 1.95k | continue; |
159 | | // If terminator is an unsupported branching terminator, emit an error. |
160 | 588 | if (term->isBranch()) { |
161 | 0 | context.emitNondifferentiabilityError( |
162 | 0 | term, invoker, diag::autodiff_control_flow_not_supported); |
163 | 0 | return true; |
164 | 0 | } |
165 | 588 | } |
166 | 520 | return false; |
167 | 520 | } |
168 | | |
169 | | /// Check whether the given requirements are satisfied, with the given |
170 | | /// derivative generic signature (containing requirements), and substitution |
171 | | /// map. Returns true if error is emitted. |
172 | | static bool diagnoseUnsatisfiedRequirements(ADContext &context, |
173 | | CanSILFunctionType origFnTy, |
174 | | GenericSignature derivativeGenSig, |
175 | | SubstitutionMap substMap, |
176 | | DifferentiationInvoker invoker, |
177 | 22.5k | SourceLoc loc) { |
178 | | // If the original function is polymorphic and its generic signature is the |
179 | | // same as the derivative generic signature, then the requirements are |
180 | | // satisfied. This check is necessary because the subsequent logic does not |
181 | | // correctly handle polymorphic original functions. |
182 | | // TODO(TF-1055): Can be removed after we have a robust solution for TF-1055. |
183 | 22.5k | if (origFnTy->getInvocationGenericSignature() && derivativeGenSig && |
184 | 22.5k | origFnTy->getInvocationGenericSignature()->isEqual(derivativeGenSig)) |
185 | 192 | return false; |
186 | | |
187 | | // If there are no derivative requirements, return false. |
188 | 22.3k | auto requirements = derivativeGenSig.getRequirements(); |
189 | 22.3k | if (requirements.empty()) |
190 | 16.4k | return false; |
191 | | // Iterate through all requirements and check whether they are satisfied. |
192 | 5.93k | auto *swiftModule = context.getModule().getSwiftModule(); |
193 | 5.93k | SmallVector<Requirement, 2> unsatisfiedRequirements; |
194 | 13.8k | for (auto req : requirements) { |
195 | 13.8k | auto firstType = req.getFirstType(); |
196 | 13.8k | Type secondType; |
197 | | // Substitute first and second types using the given substitution map, |
198 | | // looking up conformances in the current module, if possible. |
199 | 13.8k | if (auto substFirstType = |
200 | 13.8k | firstType.subst(QuerySubstitutionMap{substMap}, |
201 | 13.8k | LookUpConformanceInModule(swiftModule))) { |
202 | 13.8k | firstType = substFirstType; |
203 | 13.8k | } |
204 | 13.8k | if (req.getKind() != RequirementKind::Layout) { |
205 | 13.7k | secondType = req.getSecondType(); |
206 | 13.7k | if (auto substSecondType = |
207 | 13.7k | secondType.subst(QuerySubstitutionMap{substMap}, |
208 | 13.7k | LookUpConformanceInModule(swiftModule))) { |
209 | 13.7k | secondType = substSecondType; |
210 | 13.7k | } |
211 | 13.7k | } |
212 | 13.8k | switch (req.getKind()) { |
213 | 0 | case RequirementKind::SameShape: |
214 | 0 | llvm_unreachable("Same-shape requirement not supported here"); |
215 | | |
216 | | // Check layout requirements. |
217 | 16 | case RequirementKind::Layout: { |
218 | 16 | auto layout = req.getLayoutConstraint(); |
219 | 16 | switch (layout->getKind()) { |
220 | 16 | case LayoutConstraintKind::Class: |
221 | 16 | if (!firstType->satisfiesClassConstraint()) |
222 | 0 | unsatisfiedRequirements.push_back(req); |
223 | 16 | continue; |
224 | 0 | default: |
225 | | // TODO: Check other layout requirements. Note that `@differentiable` |
226 | | // attribute type-checking does not yet support layout requirements in |
227 | | // where clauses; layout requirements in derivative generic signatures |
228 | | // can be formed only from `differentiable_function` instructions whose |
229 | | // original function operand is generic with layout requirements. |
230 | 0 | break; |
231 | 16 | } |
232 | 0 | continue; |
233 | 16 | } |
234 | | // Check same type requirements. |
235 | 3.48k | case RequirementKind::SameType: |
236 | | // If the first type does not equal the second type, then record the |
237 | | // unsatisfied requirement. |
238 | 3.48k | if (!firstType->isEqual(secondType)) |
239 | 0 | unsatisfiedRequirements.push_back(req); |
240 | 3.48k | continue; |
241 | | // Check superclass requirements. |
242 | 48 | case RequirementKind::Superclass: { |
243 | | // If the second type is not an exact superclass of second type, then |
244 | | // record the unsatisfied requirement. |
245 | 48 | if (!secondType->isExactSuperclassOf(firstType)) |
246 | 0 | unsatisfiedRequirements.push_back(req); |
247 | 48 | continue; |
248 | 16 | } |
249 | | // Check conformance requirements. |
250 | 10.2k | case RequirementKind::Conformance: { |
251 | 10.2k | auto *protocol = req.getProtocolDecl(); |
252 | 10.2k | assert(protocol && "Expected protocol in generic signature requirement"); |
253 | | // If the first type does not conform to the second type in the current |
254 | | // module, then record the unsatisfied requirement. |
255 | 10.2k | if (!swiftModule->lookupConformance(firstType, protocol)) |
256 | 8 | unsatisfiedRequirements.push_back(req); |
257 | 10.2k | continue; |
258 | 16 | } |
259 | 13.8k | } |
260 | 13.8k | } |
261 | 5.93k | if (unsatisfiedRequirements.empty()) |
262 | 5.92k | return false; |
263 | | // Diagnose unsatisfied requirements. |
264 | 4 | std::string reqText; |
265 | 4 | llvm::raw_string_ostream stream(reqText); |
266 | 4 | interleave( |
267 | 4 | unsatisfiedRequirements, |
268 | 8 | [&](Requirement req) { req.print(stream, PrintOptions()); }, |
269 | 4 | [&] { stream << ", "; }); |
270 | 4 | context.emitNondifferentiabilityError( |
271 | 4 | loc, invoker, diag::autodiff_function_assoc_func_unmet_requirements, |
272 | 4 | stream.str()); |
273 | 4 | return true; |
274 | 5.93k | } |
275 | | |
276 | | //===----------------------------------------------------------------------===// |
277 | | // Code emission utilities |
278 | | //===----------------------------------------------------------------------===// |
279 | | |
280 | | /// Given an apply site, emit copies of all parameters and place them in |
281 | | /// `copiedArgs`. Any buffers that need to be destroyed will be added to |
282 | | /// `newArgsToDestroy`. Any new buffers that need to be deallocated will be |
283 | | /// added to `newBuffersToDealloc`. This helper is used for duplicating an |
284 | | /// apply site. |
285 | | static void copyParameterArgumentsForApply( |
286 | | ApplySite applySite, SmallVectorImpl<SILValue> &copiedArgs, |
287 | | SmallVectorImpl<SILValue> &newArgsToDestroy, |
288 | 6.97k | SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) { |
289 | 6.97k | LLVM_DEBUG({ |
290 | 6.97k | auto &s = getADDebugStream() << "Copying arguments from apply site: "; |
291 | 6.97k | applySite.getInstruction()->print(s); |
292 | 6.97k | }); |
293 | 6.97k | auto loc = applySite.getLoc(); |
294 | 6.97k | copiedArgs.reserve(applySite.getNumArguments()); |
295 | 6.97k | SILBuilderWithScope copyBuilder(applySite.getInstruction()); |
296 | 6.97k | for (auto &argOperand : applySite.getArgumentOperands()) { |
297 | 756 | auto arg = argOperand.get(); |
298 | 756 | auto argConv = applySite.getArgumentConvention(argOperand); |
299 | 756 | auto collectNewArg = [&](SILValue newArg) { |
300 | 756 | copiedArgs.push_back(newArg); |
301 | 756 | if (argConv.isGuaranteedConvention() && |
302 | 756 | argConv != SILArgumentConvention::Indirect_InoutAliasable) |
303 | 624 | newArgsToDestroy.push_back(newArg); |
304 | 756 | }; |
305 | | // Copy the argument if it's to be owned by the newly created closure. |
306 | | // Objects are to be retained. |
307 | 756 | if (arg->getType().isObject()) { |
308 | 660 | auto newArg = arg; |
309 | 660 | if (newArg->getOwnershipKind() != OwnershipKind::None) |
310 | 520 | newArg = copyBuilder.emitCopyValueOperation(loc, arg); |
311 | 660 | collectNewArg(newArg); |
312 | 660 | continue; |
313 | 660 | } |
314 | | // Addresses depend on argument conventions. |
315 | | // If the argument is an aliasable inout reference, do not copy the |
316 | | // argument since it's a `@noescape` capture. |
317 | 96 | if (argConv == SILArgumentConvention::Indirect_InoutAliasable) { |
318 | 0 | collectNewArg(arg); |
319 | 0 | continue; |
320 | 0 | } |
321 | | // Otherwise, it must be address-only. Create a new buffer and perform |
322 | | // `copy_addr`. |
323 | 96 | auto *argCopy = copyBuilder.createAllocStack(loc, arg->getType()); |
324 | 96 | newBuffersToDealloc.push_back(argCopy); |
325 | 96 | copyBuilder.createCopyAddr(loc, arg, argCopy, IsNotTake, IsInitialization); |
326 | 96 | collectNewArg(argCopy); |
327 | 96 | } |
328 | 6.97k | } |
329 | | |
330 | | /// When a function value is used in an instruction (usually `apply`), there may |
331 | | /// be conversion instructions in between, e.g. `thin_to_thick_function`. Given |
332 | | /// a new function value and an old function value, this helper function |
333 | | /// recursively converts the new function just like how the old function is |
334 | | /// converted. |
335 | | /// |
336 | | /// If the new function's generic signature is specified, it is used |
337 | | /// to create substitution maps for reapplied `partial_apply` instructions. |
338 | | static SILValue reapplyFunctionConversion( |
339 | | ADContext &context, SILValue newFunc, SILValue oldFunc, |
340 | | SILValue oldConvertedFunc, SILBuilder &builder, SILLocation loc, |
341 | | SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc, |
342 | | IndexSubset *parameterIndices, IndexSubset *resultIndices, |
343 | 37.3k | GenericSignature newFuncGenSig = GenericSignature()) { |
344 | | // If the old func is the new func, then there's no conversion. |
345 | 37.3k | if (oldFunc == oldConvertedFunc) |
346 | 23.1k | return newFunc; |
347 | | // Handle a few instruction cases. |
348 | | // copy_value |
349 | 14.1k | if (auto *cvi = dyn_cast<CopyValueInst>(oldConvertedFunc)) { |
350 | | // Note: no `copy_value` is needed for the re-converted function because the |
351 | | // caller of `reapplyFunctionConversion` should consume the re-converted |
352 | | // function. |
353 | 704 | return reapplyFunctionConversion( |
354 | 704 | context, newFunc, oldFunc, cvi->getOperand(), builder, loc, |
355 | 704 | newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig); |
356 | 704 | } |
357 | | // begin_borrow |
358 | 13.4k | if (auto *bbi = dyn_cast<BeginBorrowInst>(oldConvertedFunc)) { |
359 | | // Note: no `begin_borrow` is needed for the re-converted function because |
360 | | // the caller of `reapplyFunctionConversion` should consume the re-converted |
361 | | // function. |
362 | 0 | return reapplyFunctionConversion( |
363 | 0 | context, newFunc, oldFunc, bbi->getOperand(), builder, loc, |
364 | 0 | newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig); |
365 | 0 | } |
366 | | // convert_function |
367 | 13.4k | if (auto *cfi = dyn_cast<ConvertFunctionInst>(oldConvertedFunc)) { |
368 | 208 | return reapplyFunctionConversion( |
369 | 208 | context, newFunc, oldFunc, cfi->getOperand(), builder, loc, |
370 | 208 | newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig); |
371 | 208 | } |
372 | | // thin_to_thick_function |
373 | 13.2k | if (auto *tttfi = dyn_cast<ThinToThickFunctionInst>(oldConvertedFunc)) { |
374 | 6.23k | auto innerNewFunc = reapplyFunctionConversion( |
375 | 6.23k | context, newFunc, oldFunc, tttfi->getOperand(), builder, loc, |
376 | 6.23k | newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig); |
377 | 6.23k | auto operandFnTy = innerNewFunc->getType().castTo<SILFunctionType>(); |
378 | 6.23k | auto thickTy = operandFnTy->getWithRepresentation( |
379 | 6.23k | SILFunctionTypeRepresentation::Thick); |
380 | 6.23k | auto silTy = SILType::getPrimitiveObjectType(thickTy); |
381 | 6.23k | return builder.createThinToThickFunction(loc, innerNewFunc, silTy); |
382 | 6.23k | } |
383 | | // partial_apply |
384 | 6.96k | if (auto *pai = dyn_cast<PartialApplyInst>(oldConvertedFunc)) { |
385 | 6.96k | SmallVector<SILValue, 8> newArgs; |
386 | 6.96k | newArgs.reserve(pai->getNumArguments()); |
387 | 6.96k | SmallVector<SILValue, 1> newArgsToDestroy; |
388 | 6.96k | copyParameterArgumentsForApply(pai, newArgs, newArgsToDestroy, |
389 | 6.96k | newBuffersToDealloc); |
390 | 6.96k | auto innerNewFunc = reapplyFunctionConversion( |
391 | 6.96k | context, newFunc, oldFunc, pai->getCallee(), builder, loc, |
392 | 6.96k | newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig); |
393 | | // Reabstraction thunk `partial_apply` reapplications require special |
394 | | // support. Reabstraction thunk JVP/VJP expects a `@differentiable` |
395 | | // function-typed argument to avoid opaque function non-differentiability |
396 | | // errors. Thus, `partial_apply` reapplications must first form a |
397 | | // `differentiable_function` of the function-typed thunk argument. |
398 | 6.96k | auto isReabstractionThunkCallee = [&]() -> bool { |
399 | 6.96k | auto *fri = dyn_cast<FunctionRefInst>(oldFunc); |
400 | 6.96k | return fri && fri->getReferencedFunction()->isThunk() == |
401 | 6.57k | IsReabstractionThunk; |
402 | 6.96k | }; |
403 | 6.96k | if (isReabstractionThunkCallee()) { |
404 | 352 | assert(newArgs.size() == 1 && |
405 | 352 | "Expected reabstraction thunk to be partially applied with only " |
406 | 352 | "one argument"); |
407 | 0 | auto *dfi = context.createDifferentiableFunction( |
408 | 352 | builder, loc, parameterIndices, resultIndices, newArgs.back()); |
409 | 352 | context.getDifferentiableFunctionInstWorklist().push_back(dfi); |
410 | 352 | newArgs.back() = dfi; |
411 | 352 | } |
412 | | // Compute substitution map for reapplying `partial_apply`. |
413 | | // - If reapplied function is not polymorphic, use empty substitution map |
414 | | // regardless of the original `partial_apply`'s substitution map. |
415 | | // - This case is triggered for reapplying `partial_apply` where `newFunc` |
416 | | // is a `differentiability_witness_function` where the witness generic |
417 | | // signature has all concrete parameters while the original function's |
418 | | // generic signature does not. In this case, the original function type |
419 | | // is polymorphic while derivative function types are not (specialized |
420 | | // with concrete types from same-type requirements). |
421 | | // - Otherwise, if `newFuncGenSig` is not specified, use the original |
422 | | // `partial_apply`'s substitution map. |
423 | | // - Otherwise, if `newFuncGenSig` is specified, combine it with the |
424 | | // original `partial_apply`'s substitution map. |
425 | 0 | SubstitutionMap substMap; |
426 | 6.96k | if (innerNewFunc->getType().castTo<SILFunctionType>()->isPolymorphic()) { |
427 | 6.35k | if (!newFuncGenSig) { |
428 | 384 | substMap = pai->getSubstitutionMap(); |
429 | 5.96k | } else { |
430 | 5.96k | substMap = SubstitutionMap::get( |
431 | 5.96k | newFuncGenSig, QuerySubstitutionMap{pai->getSubstitutionMap()}, |
432 | 5.96k | LookUpConformanceInModule(builder.getModule().getSwiftModule())); |
433 | 5.96k | } |
434 | 6.35k | } |
435 | 6.96k | return builder.createPartialApply(loc, innerNewFunc, substMap, newArgs, |
436 | 6.96k | ParameterConvention::Direct_Guaranteed); |
437 | 6.96k | } |
438 | 0 | llvm_unreachable("Unhandled function conversion instruction"); |
439 | 0 | } |
440 | | |
441 | | /// Emits a reference to a derivative function of `original`, differentiated |
442 | | /// with respect to a superset of `desiredIndices`. Returns the `SILValue` for |
443 | | /// the derivative function and the actual indices that the derivative function |
444 | | /// is with respect to. |
445 | | /// |
446 | | /// Returns `None` on failure, signifying that a diagnostic has been emitted |
447 | | /// using `invoker`. |
448 | | static llvm::Optional<std::pair<SILValue, AutoDiffConfig>> |
449 | | emitDerivativeFunctionReference( |
450 | | DifferentiationTransformer &transformer, SILBuilder &builder, |
451 | | const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind, |
452 | | SILValue original, DifferentiationInvoker invoker, |
453 | 23.2k | SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) { |
454 | 23.2k | ADContext &context = transformer.getContext(); |
455 | | |
456 | | // If `original` is itself an `DifferentiableFunctionExtractInst` whose kind |
457 | | // matches the given kind and desired differentiation parameter indices, |
458 | | // simply extract the derivative function of its function operand, retain the |
459 | | // derivative function, and return it. |
460 | 23.2k | if (auto *inst = original->getDefiningInstruction()) |
461 | 23.2k | if (auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(inst)) |
462 | 8 | if (dfei->getExtractee() == |
463 | 8 | NormalDifferentiableFunctionTypeComponent::Original) |
464 | 8 | original = dfei->getOperand(); |
465 | | |
466 | | // If `original` is a `@differentiable` function, just extract the |
467 | | // derivative function. |
468 | 23.2k | if (auto diffableFnType = original->getType().castTo<SILFunctionType>()) { |
469 | 23.2k | if (diffableFnType->isDifferentiable()) { |
470 | 8 | auto paramIndices = |
471 | 8 | diffableFnType->getDifferentiabilityParameterIndices(); |
472 | 8 | for (auto i : desiredConfig.parameterIndices->getIndices()) { |
473 | 8 | if (!paramIndices->contains(i)) { |
474 | 0 | context.emitNondifferentiabilityError( |
475 | 0 | original, invoker, |
476 | 0 | diag:: |
477 | 0 | autodiff_function_noderivative_parameter_not_differentiable); |
478 | 0 | return llvm::None; |
479 | 0 | } |
480 | 8 | } |
481 | 8 | auto borrowedDiffFunc = |
482 | 8 | builder.emitBeginBorrowOperation(original.getLoc(), original); |
483 | 8 | SILValue derivativeFn = builder.createDifferentiableFunctionExtract( |
484 | 8 | borrowedDiffFunc.getLoc(), kind, borrowedDiffFunc); |
485 | 8 | if (derivativeFn->getOwnershipKind() != OwnershipKind::None) |
486 | 0 | derivativeFn = |
487 | 0 | builder.emitCopyValueOperation(original.getLoc(), derivativeFn); |
488 | 8 | builder.emitEndBorrowOperation(original.getLoc(), borrowedDiffFunc); |
489 | 8 | return std::make_pair(derivativeFn, desiredConfig); |
490 | 8 | } |
491 | 23.2k | } |
492 | | |
493 | | // Handle `function_ref` original function. |
494 | 23.2k | if (auto *originalFRI = |
495 | 23.2k | peerThroughFunctionConversions<FunctionRefInst>(original)) { |
496 | 22.5k | auto loc = originalFRI->getLoc(); |
497 | 22.5k | auto *originalFn = originalFRI->getReferencedFunction(); |
498 | 22.5k | auto originalFnTy = originalFn->getLoweredFunctionType(); |
499 | 22.5k | auto *desiredParameterIndices = desiredConfig.parameterIndices; |
500 | 22.5k | auto *desiredResultIndices = desiredConfig.resultIndices; |
501 | | // NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has |
502 | | // parameters corresponding to captured variables. |
503 | | // TODO: If possible, change `autodiff::getLoweredParameterIndices` to |
504 | | // take `CaptureInfo` into account. |
505 | 22.5k | if (originalFnTy->getNumParameters() > |
506 | 22.5k | desiredParameterIndices->getCapacity()) { |
507 | 712 | desiredParameterIndices = desiredParameterIndices->extendingCapacity( |
508 | 712 | context.getASTContext(), originalFnTy->getNumParameters()); |
509 | 712 | } |
510 | | // Look up a differentiability witness with the exact configuration. |
511 | 22.5k | auto *minimalWitness = getExactDifferentiabilityWitness( |
512 | 22.5k | context.getModule(), originalFn, desiredParameterIndices, |
513 | 22.5k | desiredResultIndices); |
514 | | // Otherwise, look up a differentiability witness with a minimal superset |
515 | | // configuration. |
516 | 22.5k | if (!minimalWitness) |
517 | 6.34k | minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness( |
518 | 6.34k | context.getModule(), originalFn, DifferentiabilityKind::Reverse, |
519 | 6.34k | desiredParameterIndices, desiredResultIndices); |
520 | | // If no minimal witness exists, check non-differentiable cases before |
521 | | // creating a new private differentiability witness. |
522 | 22.5k | if (!minimalWitness) { |
523 | | // If the function is intentionally marked as being opaque to |
524 | | // differentiation, then we should not create a task for it. |
525 | 3.40k | if (originalFn->hasSemanticsAttr("autodiff.opaque")) { |
526 | 0 | context.emitNondifferentiabilityError( |
527 | 0 | original, invoker, |
528 | 0 | diag::autodiff_opaque_function_not_differentiable); |
529 | 0 | return llvm::None; |
530 | 0 | } |
531 | | // Check and diagnose non-differentiable arguments. |
532 | 3.40k | auto originalFnTy = originalFn->getLoweredFunctionType(); |
533 | 5.63k | for (unsigned paramIndex : range(originalFnTy->getNumParameters())) { |
534 | 5.63k | if (desiredConfig.isWrtParameter(paramIndex) && |
535 | 5.63k | !originalFnTy->getParameters()[paramIndex] |
536 | 4.90k | .getSILStorageInterfaceType() |
537 | 4.90k | .isDifferentiable(context.getModule())) { |
538 | 0 | auto diag = context.emitNondifferentiabilityError( |
539 | 0 | original, invoker, diag::autodiff_nondifferentiable_argument); |
540 | 0 | return llvm::None; |
541 | 0 | } |
542 | 5.63k | } |
543 | | // Check and diagnose non-differentiable results. |
544 | 3.46k | for (auto resultIndex : desiredResultIndices->getIndices()) { |
545 | 3.46k | SILType resultType; |
546 | 3.46k | if (resultIndex >= originalFnTy->getNumResults()) { |
547 | 200 | auto semanticResultParamIdx = resultIndex - originalFnTy->getNumResults(); |
548 | 200 | auto semanticResultParam = |
549 | 200 | *std::next(originalFnTy->getAutoDiffSemanticResultsParameters().begin(), |
550 | 200 | semanticResultParamIdx); |
551 | 200 | resultType = semanticResultParam.getSILStorageInterfaceType(); |
552 | 3.26k | } else { |
553 | 3.26k | resultType = originalFnTy->getResults()[resultIndex] |
554 | 3.26k | .getSILStorageInterfaceType(); |
555 | 3.26k | } |
556 | 3.46k | if (!resultType.isDifferentiable(context.getModule())) { |
557 | 0 | context.emitNondifferentiabilityError( |
558 | 0 | original, invoker, diag::autodiff_nondifferentiable_result); |
559 | 0 | return llvm::None; |
560 | 0 | } |
561 | 3.46k | } |
562 | | // Check and diagnose external declarations. |
563 | 3.40k | if (originalFn->isExternalDeclaration()) { |
564 | 0 | context.emitNondifferentiabilityError( |
565 | 0 | original, invoker, |
566 | 0 | diag::autodiff_external_nondifferentiable_function); |
567 | 0 | return llvm::None; |
568 | 0 | } |
569 | | // Sanity check passed. Create a new differentiability witness and |
570 | | // canonicalize it. |
571 | 3.40k | GenericSignature contextualDerivativeGenSig = GenericSignature(); |
572 | 3.40k | if (invoker.getKind() == |
573 | 3.40k | DifferentiationInvoker::Kind::IndirectDifferentiation) |
574 | 0 | contextualDerivativeGenSig = |
575 | 0 | invoker.getIndirectDifferentiation() |
576 | 0 | .second->getDerivativeGenericSignature(); |
577 | 3.40k | auto derivativeConstrainedGenSig = |
578 | 3.40k | autodiff::getConstrainedDerivativeGenericSignature( |
579 | 3.40k | originalFn->getLoweredFunctionType(), |
580 | 3.40k | desiredParameterIndices, desiredResultIndices, |
581 | 3.40k | contextualDerivativeGenSig, |
582 | 3.40k | LookUpConformanceInModule(context.getModule().getSwiftModule())); |
583 | 3.40k | minimalWitness = SILDifferentiabilityWitness::createDefinition( |
584 | 3.40k | context.getModule(), SILLinkage::Private, originalFn, |
585 | 3.40k | DifferentiabilityKind::Reverse, desiredParameterIndices, |
586 | 3.40k | desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr, |
587 | 3.40k | /*vjp*/ nullptr, /*isSerialized*/ false); |
588 | 3.40k | if (transformer.canonicalizeDifferentiabilityWitness( |
589 | 3.40k | minimalWitness, invoker, IsNotSerialized)) |
590 | 12 | return llvm::None; |
591 | 3.40k | } |
592 | 22.5k | assert(minimalWitness); |
593 | 22.5k | if (original->getFunction()->isSerialized() && |
594 | 22.5k | !hasPublicVisibility(minimalWitness->getLinkage())) { |
595 | 16 | enum { Inlinable = 0, DefaultArgument = 1 }; |
596 | 16 | unsigned fragileKind = Inlinable; |
597 | | // FIXME: This is not a very robust way of determining if the function is |
598 | | // a default argument. Also, we have not exhaustively listed all the kinds |
599 | | // of fragility. |
600 | 16 | if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI) |
601 | 8 | fragileKind = DefaultArgument; |
602 | 16 | context.emitNondifferentiabilityError( |
603 | 16 | original, invoker, diag::autodiff_private_derivative_from_fragile, |
604 | 16 | fragileKind, |
605 | 16 | isa_and_nonnull<AbstractClosureExpr>( |
606 | 16 | originalFRI->getLoc().getAsASTNode<Expr>())); |
607 | 16 | return llvm::None; |
608 | 16 | } |
609 | | // TODO(TF-482): Move generic requirement checking logic to |
610 | | // `getExactDifferentiabilityWitness` and |
611 | | // `getOrCreateMinimalASTDifferentiabilityWitness`. |
612 | | // Get the substitution map for checking unmet generic requirements. |
613 | | // By default, use the forwarding substitution map of the original function. |
614 | | // If the original callee is a `partial_apply` or `apply` instruction, use |
615 | | // its substitution map instead. |
616 | 22.5k | auto substMap = original->getFunction()->getForwardingSubstitutionMap(); |
617 | 22.5k | if (auto *pai = |
618 | 22.5k | peerThroughFunctionConversions<PartialApplyInst>(original)) { |
619 | 6.58k | substMap = pai->getSubstitutionMap(); |
620 | 15.9k | } else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original)) { |
621 | 0 | substMap = ai->getSubstitutionMap(); |
622 | 0 | } |
623 | 22.5k | if (diagnoseUnsatisfiedRequirements( |
624 | 22.5k | context, original->getType().castTo<SILFunctionType>(), |
625 | 22.5k | minimalWitness->getDerivativeGenericSignature(), substMap, invoker, |
626 | 22.5k | original.getLoc().getSourceLoc())) |
627 | 4 | return llvm::None; |
628 | 22.5k | DifferentiabilityWitnessFunctionKind witnessKind; |
629 | 22.5k | switch (kind) { |
630 | 11.2k | case AutoDiffDerivativeFunctionKind::JVP: |
631 | 11.2k | witnessKind = DifferentiabilityWitnessFunctionKind::JVP; |
632 | 11.2k | break; |
633 | 11.2k | case AutoDiffDerivativeFunctionKind::VJP: |
634 | 11.2k | witnessKind = DifferentiabilityWitnessFunctionKind::VJP; |
635 | 11.2k | break; |
636 | 22.5k | } |
637 | 22.5k | auto *derivativeFnRef = builder.createDifferentiabilityWitnessFunction( |
638 | 22.5k | loc, witnessKind, minimalWitness); |
639 | 22.5k | auto convertedRef = reapplyFunctionConversion( |
640 | 22.5k | context, derivativeFnRef, originalFRI, original, builder, loc, |
641 | 22.5k | newBuffersToDealloc, desiredConfig.parameterIndices, |
642 | 22.5k | desiredConfig.resultIndices, |
643 | 22.5k | derivativeFnRef->getType() |
644 | 22.5k | .getASTType() |
645 | 22.5k | ->castTo<SILFunctionType>() |
646 | 22.5k | ->getSubstGenericSignature()); |
647 | 22.5k | return std::make_pair(convertedRef, minimalWitness->getConfig()); |
648 | 22.5k | } |
649 | | |
650 | | // Handle `witness_method`. |
651 | 684 | if (auto *witnessMethod = |
652 | 684 | peerThroughFunctionConversions<WitnessMethodInst>(original)) { |
653 | 372 | auto loc = witnessMethod->getLoc(); |
654 | 372 | auto requirementDeclRef = witnessMethod->getMember(); |
655 | 372 | auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl(); |
656 | | // If requirement declaration does not have any derivative function |
657 | | // configurations, produce an error. |
658 | 372 | if (requirementDecl->getDerivativeFunctionConfigurations().empty()) { |
659 | 4 | context.emitNondifferentiabilityError( |
660 | 4 | original, invoker, diag::autodiff_protocol_member_not_differentiable); |
661 | 4 | return llvm::None; |
662 | 4 | } |
663 | | // Find the minimal derivative configuration: minimal parameter indices and |
664 | | // corresponding derivative generic signature. If it does not exist, produce |
665 | | // an error. |
666 | 368 | IndexSubset *minimalASTParamIndices = nullptr; |
667 | 368 | auto minimalConfig = findMinimalDerivativeConfiguration( |
668 | 368 | requirementDecl, desiredConfig.parameterIndices, |
669 | 368 | minimalASTParamIndices); |
670 | 368 | if (!minimalConfig) { |
671 | 0 | context.emitNondifferentiabilityError( |
672 | 0 | original, invoker, |
673 | 0 | diag::autodiff_member_subset_indices_not_differentiable); |
674 | 0 | return llvm::None; |
675 | 0 | } |
676 | | // Emit a `witness_method` instruction for the derivative function. |
677 | 368 | auto originalType = witnessMethod->getType().castTo<SILFunctionType>(); |
678 | 368 | auto assocType = originalType->getAutoDiffDerivativeFunctionType( |
679 | 368 | minimalConfig->parameterIndices, minimalConfig->resultIndices, kind, |
680 | 368 | context.getTypeConverter(), |
681 | 368 | LookUpConformanceInModule(builder.getModule().getSwiftModule())); |
682 | 368 | auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get( |
683 | 368 | kind, minimalASTParamIndices, minimalConfig->derivativeGenericSignature, |
684 | 368 | context.getASTContext()); |
685 | 368 | auto *ref = builder.createWitnessMethod( |
686 | 368 | loc, witnessMethod->getLookupType(), witnessMethod->getConformance(), |
687 | 368 | requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId), |
688 | 368 | SILType::getPrimitiveObjectType(assocType)); |
689 | 368 | auto convertedRef = reapplyFunctionConversion( |
690 | 368 | context, ref, witnessMethod, original, builder, loc, |
691 | 368 | newBuffersToDealloc, desiredConfig.parameterIndices, |
692 | 368 | desiredConfig.resultIndices); |
693 | 368 | return std::make_pair(convertedRef, *minimalConfig); |
694 | 368 | } |
695 | | |
696 | | // Handle `class_method`. |
697 | 312 | if (auto *classMethod = |
698 | 312 | peerThroughFunctionConversions<ClassMethodInst>(original)) { |
699 | 312 | auto loc = classMethod->getLoc(); |
700 | 312 | auto methodDeclRef = classMethod->getMember(); |
701 | 312 | auto *methodDecl = methodDeclRef.getAbstractFunctionDecl(); |
702 | | // If method declaration does not have any derivative function |
703 | | // configurations, produce an error. |
704 | 312 | if (methodDecl->getDerivativeFunctionConfigurations().empty()) { |
705 | 8 | context.emitNondifferentiabilityError( |
706 | 8 | original, invoker, diag::autodiff_class_member_not_differentiable); |
707 | 8 | return llvm::None; |
708 | 8 | } |
709 | | // Find the minimal derivative configuration: minimal parameter indices and |
710 | | // corresponding derivative generic signature. If it does not exist, produce |
711 | | // an error. |
712 | 304 | IndexSubset *minimalASTParamIndices = nullptr; |
713 | 304 | auto minimalConfig = findMinimalDerivativeConfiguration( |
714 | 304 | methodDecl, desiredConfig.parameterIndices, minimalASTParamIndices); |
715 | 304 | if (!minimalConfig) { |
716 | 0 | context.emitNondifferentiabilityError( |
717 | 0 | original, invoker, |
718 | 0 | diag::autodiff_member_subset_indices_not_differentiable); |
719 | 0 | return llvm::None; |
720 | 0 | } |
721 | | // Emit a `class_method` instruction for the derivative function. |
722 | 304 | auto originalType = classMethod->getType().castTo<SILFunctionType>(); |
723 | 304 | auto assocType = originalType->getAutoDiffDerivativeFunctionType( |
724 | 304 | minimalConfig->parameterIndices, minimalConfig->resultIndices, kind, |
725 | 304 | context.getTypeConverter(), |
726 | 304 | LookUpConformanceInModule(builder.getModule().getSwiftModule())); |
727 | 304 | auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get( |
728 | 304 | kind, minimalASTParamIndices, minimalConfig->derivativeGenericSignature, |
729 | 304 | context.getASTContext()); |
730 | 304 | auto *ref = builder.createClassMethod( |
731 | 304 | loc, classMethod->getOperand(), |
732 | 304 | methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId), |
733 | 304 | SILType::getPrimitiveObjectType(assocType)); |
734 | 304 | auto convertedRef = reapplyFunctionConversion( |
735 | 304 | context, ref, classMethod, original, builder, loc, newBuffersToDealloc, |
736 | 304 | desiredConfig.parameterIndices, desiredConfig.resultIndices); |
737 | 304 | return std::make_pair(convertedRef, *minimalConfig); |
738 | 304 | } |
739 | | |
740 | | // Emit the general opaque function error. |
741 | 0 | context.emitNondifferentiabilityError( |
742 | 0 | original, invoker, diag::autodiff_opaque_function_not_differentiable); |
743 | 0 | return llvm::None; |
744 | 312 | } |
745 | | |
746 | | //===----------------------------------------------------------------------===// |
747 | | // `SILDifferentiabilityWitness` processing |
748 | | //===----------------------------------------------------------------------===// |
749 | | |
750 | | static SILFunction *createEmptyVJP(ADContext &context, |
751 | | SILDifferentiabilityWitness *witness, |
752 | 5.25k | IsSerialized_t isSerialized) { |
753 | 5.25k | auto original = witness->getOriginalFunction(); |
754 | 5.25k | auto config = witness->getConfig(); |
755 | 5.25k | LLVM_DEBUG({ |
756 | 5.25k | auto &s = getADDebugStream(); |
757 | 5.25k | s << "Creating VJP for " << original->getName() << ":\n\t"; |
758 | 5.25k | s << "Original type: " << original->getLoweredFunctionType() << "\n\t"; |
759 | 5.25k | s << "Config: " << config << "\n\t"; |
760 | 5.25k | }); |
761 | | |
762 | 5.25k | auto &module = context.getModule(); |
763 | 5.25k | auto originalTy = original->getLoweredFunctionType(); |
764 | | |
765 | | // === Create an empty VJP. === |
766 | 5.25k | Mangle::DifferentiationMangler mangler; |
767 | 5.25k | auto vjpName = mangler.mangleDerivativeFunction( |
768 | 5.25k | original->getName(), AutoDiffDerivativeFunctionKind::VJP, config); |
769 | 5.25k | auto vjpCanGenSig = witness->getDerivativeGenericSignature().getCanonicalSignature(); |
770 | 5.25k | GenericEnvironment *vjpGenericEnv = nullptr; |
771 | 5.25k | if (vjpCanGenSig && !vjpCanGenSig->areAllParamsConcrete()) |
772 | 940 | vjpGenericEnv = vjpCanGenSig.getGenericEnvironment(); |
773 | 5.25k | auto vjpType = originalTy->getAutoDiffDerivativeFunctionType( |
774 | 5.25k | config.parameterIndices, config.resultIndices, |
775 | 5.25k | AutoDiffDerivativeFunctionKind::VJP, |
776 | 5.25k | module.Types, LookUpConformanceInModule(module.getSwiftModule()), |
777 | 5.25k | vjpCanGenSig, |
778 | 5.25k | /*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk); |
779 | | |
780 | 5.25k | SILOptFunctionBuilder fb(context.getTransform()); |
781 | 5.25k | auto *vjp = fb.createFunction( |
782 | 5.25k | witness->getLinkage(), |
783 | 5.25k | context.getASTContext().getIdentifier(vjpName).str(), vjpType, |
784 | 5.25k | vjpGenericEnv, original->getLocation(), original->isBare(), |
785 | 5.25k | IsNotTransparent, isSerialized, original->isDynamicallyReplaceable(), |
786 | 5.25k | original->isDistributed(), |
787 | 5.25k | original->isRuntimeAccessible()); |
788 | 5.25k | vjp->setDebugScope(new (module) SILDebugScope(original->getLocation(), vjp)); |
789 | | |
790 | 5.25k | LLVM_DEBUG(llvm::dbgs() << "VJP type: " << vjp->getLoweredFunctionType() |
791 | 5.25k | << "\n"); |
792 | 5.25k | return vjp; |
793 | 5.25k | } |
794 | | |
795 | | static SILFunction *createEmptyJVP(ADContext &context, |
796 | | SILDifferentiabilityWitness *witness, |
797 | 5.50k | IsSerialized_t isSerialized) { |
798 | 5.50k | auto original = witness->getOriginalFunction(); |
799 | 5.50k | auto config = witness->getConfig(); |
800 | 5.50k | LLVM_DEBUG({ |
801 | 5.50k | auto &s = getADDebugStream(); |
802 | 5.50k | s << "Creating JVP for " << original->getName() << ":\n\t"; |
803 | 5.50k | s << "Original type: " << original->getLoweredFunctionType() << "\n\t"; |
804 | 5.50k | s << "Config: " << config << "\n\t"; |
805 | 5.50k | }); |
806 | | |
807 | 5.50k | auto &module = context.getModule(); |
808 | 5.50k | auto originalTy = original->getLoweredFunctionType(); |
809 | | |
810 | 5.50k | Mangle::DifferentiationMangler mangler; |
811 | 5.50k | auto jvpName = mangler.mangleDerivativeFunction( |
812 | 5.50k | original->getName(), AutoDiffDerivativeFunctionKind::JVP, config); |
813 | 5.50k | auto jvpCanGenSig = witness->getDerivativeGenericSignature().getCanonicalSignature(); |
814 | 5.50k | GenericEnvironment *jvpGenericEnv = nullptr; |
815 | 5.50k | if (jvpCanGenSig && !jvpCanGenSig->areAllParamsConcrete()) |
816 | 1.00k | jvpGenericEnv = jvpCanGenSig.getGenericEnvironment(); |
817 | 5.50k | auto jvpType = originalTy->getAutoDiffDerivativeFunctionType( |
818 | 5.50k | config.parameterIndices, config.resultIndices, |
819 | 5.50k | AutoDiffDerivativeFunctionKind::JVP, |
820 | 5.50k | module.Types, LookUpConformanceInModule(module.getSwiftModule()), |
821 | 5.50k | jvpCanGenSig, |
822 | 5.50k | /*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk); |
823 | | |
824 | 5.50k | SILOptFunctionBuilder fb(context.getTransform()); |
825 | 5.50k | auto *jvp = fb.createFunction( |
826 | 5.50k | witness->getLinkage(), |
827 | 5.50k | context.getASTContext().getIdentifier(jvpName).str(), jvpType, |
828 | 5.50k | jvpGenericEnv, original->getLocation(), original->isBare(), |
829 | 5.50k | IsNotTransparent, isSerialized, original->isDynamicallyReplaceable(), |
830 | 5.50k | original->isDistributed(), |
831 | 5.50k | original->isRuntimeAccessible()); |
832 | 5.50k | jvp->setDebugScope(new (module) SILDebugScope(original->getLocation(), jvp)); |
833 | | |
834 | 5.50k | LLVM_DEBUG(llvm::dbgs() << "JVP type: " << jvp->getLoweredFunctionType() |
835 | 5.50k | << "\n"); |
836 | 5.50k | return jvp; |
837 | 5.50k | } |
838 | | |
839 | | /// Apply the fatal error function with the given name of type |
840 | | /// `@convention(thin) () -> Never` in `f`. |
841 | | static void emitFatalError(ADContext &context, SILFunction *f, |
842 | 4.14k | StringRef fatalErrorFuncName) { |
843 | 4.14k | auto *entry = f->createBasicBlock(); |
844 | 4.14k | createEntryArguments(f); |
845 | 4.14k | SILBuilder builder(entry); |
846 | 4.14k | auto loc = f->getLocation(); |
847 | | // Destroy all owned arguments to pass ownership verification. |
848 | 4.14k | for (auto *arg : entry->getArguments()) |
849 | 8.05k | if (arg->getOwnershipKind() == OwnershipKind::Owned) |
850 | 96 | builder.emitDestroyOperation(loc, arg); |
851 | | // Fatal error with a nice message. |
852 | 4.14k | auto neverTy = |
853 | 4.14k | context.getModule().getASTContext().getNeverType()->getCanonicalType(); |
854 | 4.14k | auto neverResultInfo = SILResultInfo(neverTy, ResultConvention::Unowned); |
855 | | // Fatal error function must have type `@convention(thin) () -> Never`. |
856 | 4.14k | auto fatalErrorFnType = SILFunctionType::get( |
857 | 4.14k | /*genericSig*/ nullptr, SILFunctionType::ExtInfo::getThin(), |
858 | 4.14k | SILCoroutineKind::None, ParameterConvention::Direct_Unowned, {}, |
859 | 4.14k | /*interfaceYields*/ {}, neverResultInfo, |
860 | 4.14k | /*interfaceErrorResults*/ llvm::None, {}, {}, context.getASTContext()); |
861 | 4.14k | auto fnBuilder = SILOptFunctionBuilder(context.getTransform()); |
862 | 4.14k | auto *fatalErrorFn = fnBuilder.getOrCreateFunction( |
863 | 4.14k | loc, fatalErrorFuncName, SILLinkage::PublicExternal, fatalErrorFnType, |
864 | 4.14k | IsNotBare, IsNotTransparent, IsNotSerialized, IsNotDynamic, |
865 | 4.14k | IsNotDistributed, IsNotRuntimeAccessible, ProfileCounter(), IsNotThunk); |
866 | 4.14k | auto *fatalErrorFnRef = builder.createFunctionRef(loc, fatalErrorFn); |
867 | 4.14k | builder.createApply(loc, fatalErrorFnRef, SubstitutionMap(), {}); |
868 | 4.14k | builder.createUnreachable(loc); |
869 | 4.14k | } |
870 | | |
871 | | /// Returns true on error. |
872 | | bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( |
873 | | SILDifferentiabilityWitness *witness, DifferentiationInvoker invoker, |
874 | 5.66k | IsSerialized_t serializeFunctions) { |
875 | 5.66k | std::string traceMessage; |
876 | 5.66k | llvm::raw_string_ostream OS(traceMessage); |
877 | 5.66k | OS << "processing "; |
878 | 5.66k | witness->print(OS); |
879 | 5.66k | OS << " on"; |
880 | 5.66k | OS.flush(); |
881 | 5.66k | PrettyStackTraceSILFunction trace( |
882 | 5.66k | traceMessage.c_str(), witness->getOriginalFunction()); |
883 | | |
884 | 5.66k | assert(witness->isDefinition()); |
885 | | |
886 | | // If the JVP doesn't exist, need to synthesize it. |
887 | 5.66k | if (!witness->getJVP()) { |
888 | | // Diagnose: |
889 | | // - Functions with no return. |
890 | | // - Functions with unsupported control flow. |
891 | 5.50k | if (context.getASTContext() |
892 | 5.50k | .LangOpts.hasFeature(Feature::ForwardModeDifferentiation) && |
893 | 5.50k | (diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) || |
894 | 1.36k | diagnoseUnsupportedControlFlow( |
895 | 1.36k | context, witness->getOriginalFunction(), invoker))) |
896 | 0 | return true; |
897 | | |
898 | | // Create empty JVP. |
899 | 5.50k | auto *jvp = createEmptyJVP(context, witness, serializeFunctions); |
900 | 5.50k | witness->setJVP(jvp); |
901 | 5.50k | context.recordGeneratedFunction(jvp); |
902 | | |
903 | | // For now, only do JVP generation if the flag is enabled and if custom VJP |
904 | | // does not exist. If custom VJP exists but custom JVP does not, skip JVP |
905 | | // generation because generated JVP may not match semantics of custom VJP. |
906 | | // Instead, create an empty JVP. |
907 | 5.50k | if (context.getASTContext() |
908 | 5.50k | .LangOpts.hasFeature(Feature::ForwardModeDifferentiation) && |
909 | 5.50k | !witness->getVJP()) { |
910 | | // JVP and differential generation do not currently support functions with |
911 | | // multiple basic blocks. |
912 | 1.36k | if (witness->getOriginalFunction()->size() > 1) { |
913 | 8 | context.emitNondifferentiabilityError( |
914 | 8 | witness->getOriginalFunction()->getLocation().getSourceLoc(), |
915 | 8 | invoker, diag::autodiff_jvp_control_flow_not_supported); |
916 | 8 | return true; |
917 | 8 | } |
918 | | // Emit JVP function. |
919 | 1.35k | JVPCloner cloner(context, witness, jvp, invoker); |
920 | 1.35k | if (cloner.run()) |
921 | 20 | return true; |
922 | 4.14k | } else { |
923 | | // If JVP generation is disabled or a user-defined custom VJP function |
924 | | // exists, fatal error with a nice message. |
925 | 4.14k | emitFatalError(context, jvp, |
926 | 4.14k | "_fatalErrorForwardModeDifferentiationDisabled"); |
927 | 4.14k | LLVM_DEBUG(getADDebugStream() |
928 | 4.14k | << "Generated empty JVP for " |
929 | 4.14k | << witness->getOriginalFunction()->getName() << ":\n" |
930 | 4.14k | << *jvp); |
931 | 4.14k | } |
932 | 5.50k | } |
933 | | |
934 | | // If the VJP doesn't exist, need to synthesize it. |
935 | 5.63k | if (!witness->getVJP()) { |
936 | | // Diagnose: |
937 | | // - Functions with no return. |
938 | | // - Functions with unsupported control flow. |
939 | 5.26k | if (diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) || |
940 | 5.26k | diagnoseUnsupportedControlFlow( |
941 | 5.25k | context, witness->getOriginalFunction(), invoker)) |
942 | 4 | return true; |
943 | | |
944 | | // Create empty VJP. |
945 | 5.25k | auto *vjp = createEmptyVJP(context, witness, serializeFunctions); |
946 | 5.25k | witness->setVJP(vjp); |
947 | 5.25k | context.recordGeneratedFunction(vjp); |
948 | | // Emit VJP function. |
949 | 5.25k | VJPCloner cloner(context, witness, vjp, invoker); |
950 | 5.25k | return cloner.run(); |
951 | 5.26k | } |
952 | 376 | return false; |
953 | 5.63k | } |
954 | | |
955 | | //===----------------------------------------------------------------------===// |
956 | | // Differentiation pass implementation |
957 | | //===----------------------------------------------------------------------===// |
958 | | |
959 | | /// The automatic differentiation pass. |
960 | | namespace { |
961 | | class Differentiation : public SILModuleTransform { |
962 | | public: |
963 | 24.3k | Differentiation() : SILModuleTransform() {} |
964 | | void run() override; |
965 | | }; |
966 | | } // end anonymous namespace |
967 | | |
968 | | /// Given a curry thunk application, clone the thunk to return a |
969 | | /// `@differentiable` function-typed value and apply the cloned thunk. |
970 | | /// |
971 | | /// Curry thunk type: `(Self) -> (T, ...) -> U`. |
972 | | /// Cloned thunk type: `(Self) -> @differentiable (T, ...) -> U`. |
973 | | static SILValue promoteCurryThunkApplicationToDifferentiableFunction( |
974 | | DifferentiationTransformer &dt, DifferentiableFunctionInst *dfi, |
975 | 11.6k | SILBuilder &builder, SILLocation loc, DifferentiationInvoker invoker) { |
976 | 11.6k | auto origFnOperand = dfi->getOriginalFunction(); |
977 | 11.6k | auto *parameterIndices = dfi->getParameterIndices(); |
978 | 11.6k | auto *resultIndices = dfi->getResultIndices(); |
979 | 11.6k | auto &context = dt.getContext(); |
980 | | |
981 | | // Check for curry thunk application: |
982 | | // - The original function operand must be an `apply` instruction. |
983 | | // - The `apply` callee must be a `function_ref` instruction. |
984 | | // - The callee must return a function-typed value. |
985 | 11.6k | auto *ai = dyn_cast<ApplyInst>(origFnOperand); |
986 | 11.6k | if (!ai) |
987 | 11.6k | return nullptr; |
988 | 4 | auto *thunkRef = dyn_cast<FunctionRefInst>(ai->getCallee()); |
989 | 4 | if (!thunkRef) |
990 | 0 | return nullptr; |
991 | 4 | auto *thunk = thunkRef->getReferencedFunction(); |
992 | 4 | auto thunkTy = thunk->getLoweredFunctionType(); |
993 | 4 | auto thunkResult = thunkTy->getSingleResult(); |
994 | 4 | auto resultFnTy = thunkResult.getInterfaceType()->getAs<SILFunctionType>(); |
995 | 4 | if (!resultFnTy) |
996 | 0 | return nullptr; |
997 | | |
998 | | // Create a new curry thunk. |
999 | 4 | AutoDiffConfig desiredConfig(parameterIndices, resultIndices); |
1000 | | // TODO(TF-685): Use more principled mangling for thunks. |
1001 | 4 | auto newThunkName = "AD__" + thunk->getName().str() + |
1002 | 4 | "__differentiable_curry_thunk_" + desiredConfig.mangle(); |
1003 | | |
1004 | | // Construct new curry thunk type with `@differentiable` function |
1005 | | // result. |
1006 | 4 | auto diffResultFnTy = resultFnTy->getWithExtInfo( |
1007 | 4 | resultFnTy->getExtInfo() |
1008 | 4 | .intoBuilder() |
1009 | 4 | .withDifferentiabilityKind(DifferentiabilityKind::Reverse) |
1010 | 4 | .build()); |
1011 | 4 | auto newThunkResult = thunkResult.getWithInterfaceType(diffResultFnTy); |
1012 | 4 | auto thunkType = SILFunctionType::get( |
1013 | 4 | thunkTy->getSubstGenericSignature(), thunkTy->getExtInfo(), |
1014 | 4 | thunkTy->getCoroutineKind(), thunkTy->getCalleeConvention(), |
1015 | 4 | thunkTy->getParameters(), {}, {newThunkResult}, {}, |
1016 | 4 | thunkTy->getPatternSubstitutions(), thunkTy->getInvocationSubstitutions(), |
1017 | 4 | thunkTy->getASTContext()); |
1018 | | |
1019 | | // Construct new curry thunk, returning a `@differentiable` function. |
1020 | 4 | SILOptFunctionBuilder fb(dt.getTransform()); |
1021 | 4 | auto *newThunk = fb.getOrCreateFunction( |
1022 | 4 | loc, newThunkName, getSpecializedLinkage(thunk, thunk->getLinkage()), |
1023 | 4 | thunkType, thunk->isBare(), thunk->isTransparent(), thunk->isSerialized(), |
1024 | 4 | thunk->isDynamicallyReplaceable(), thunk->isDistributed(), |
1025 | 4 | thunk->isRuntimeAccessible(), |
1026 | 4 | ProfileCounter(), thunk->isThunk()); |
1027 | | // If new thunk is newly created: clone the old thunk body, wrap the |
1028 | | // returned function value with an `differentiable_function` |
1029 | | // instruction, and process the `differentiable_function` instruction. |
1030 | 4 | if (newThunk->empty()) { |
1031 | 4 | newThunk->setGenericEnvironment(thunkType->getSubstGenericSignature().getGenericEnvironment()); |
1032 | | |
1033 | 4 | BasicTypeSubstCloner cloner(thunk, newThunk); |
1034 | 4 | cloner.cloneFunction(); |
1035 | 4 | auto *retInst = cast<ReturnInst>(newThunk->findReturnBB()->getTerminator()); |
1036 | 4 | auto returnValue = retInst->getOperand(); |
1037 | | // Create `differentiable_function` instruction directly after the |
1038 | | // defining instruction (e.g. `partial_apply`) of the returned value. |
1039 | | // Note: `differentiable_function` is not created at the end of the |
1040 | | // new thunk to avoid `alloc_stack`/`dealloc_stack` ordering issues. |
1041 | 4 | SILBuilderWithScope dfiBuilder( |
1042 | 4 | std::next(returnValue->getDefiningInstruction()->getIterator())); |
1043 | 4 | auto *dfi = context.createDifferentiableFunction( |
1044 | 4 | dfiBuilder, loc, parameterIndices, resultIndices, returnValue); |
1045 | 4 | dfiBuilder.setInsertionPoint(newThunk->findReturnBB()); |
1046 | 4 | dfiBuilder.createReturn(loc, dfi); |
1047 | 4 | retInst->eraseFromParent(); |
1048 | | |
1049 | 4 | context.recordGeneratedFunction(newThunk); |
1050 | 4 | context.getDifferentiableFunctionInstWorklist().push_back(dfi); |
1051 | 4 | if (dt.processDifferentiableFunctionInst(dfi)) |
1052 | 0 | return nullptr; |
1053 | 4 | } |
1054 | | |
1055 | | // Apply the new curry thunk. |
1056 | 4 | auto *newThunkRef = builder.createFunctionRef(loc, newThunk); |
1057 | 4 | context.recordGeneratedFunctionReference(newThunkRef); |
1058 | 4 | SmallVector<SILValue, 8> newArgs; |
1059 | 4 | SmallVector<SILValue, 8> newArgsToDestroy; |
1060 | 4 | SmallVector<AllocStackInst *, 1> newBuffersToDealloc; |
1061 | 4 | copyParameterArgumentsForApply(ai, newArgs, newArgsToDestroy, |
1062 | 4 | newBuffersToDealloc); |
1063 | 4 | auto *newApply = builder.createApply( |
1064 | 4 | loc, newThunkRef, ai->getSubstitutionMap(), newArgs, |
1065 | 4 | ai->getApplyOptions()); |
1066 | 4 | for (auto arg : newArgsToDestroy) |
1067 | 0 | builder.emitDestroyOperation(loc, arg); |
1068 | 4 | for (auto *alloc : newBuffersToDealloc) |
1069 | 0 | builder.createDeallocStack(loc, alloc); |
1070 | 4 | return newApply; |
1071 | 4 | } |
1072 | | |
1073 | | SILValue DifferentiationTransformer::promoteToDifferentiableFunction( |
1074 | | DifferentiableFunctionInst *dfi, SILBuilder &builder, SILLocation loc, |
1075 | 11.6k | DifferentiationInvoker invoker) { |
1076 | 11.6k | auto &astCtx = context.getASTContext(); |
1077 | 11.6k | auto origFnOperand = dfi->getOriginalFunction(); |
1078 | 11.6k | auto origFnTy = origFnOperand->getType().castTo<SILFunctionType>(); |
1079 | 11.6k | auto *parameterIndices = dfi->getParameterIndices(); |
1080 | 11.6k | auto *resultIndices = dfi->getResultIndices(); |
1081 | | |
1082 | 11.6k | if (auto diffFn = promoteCurryThunkApplicationToDifferentiableFunction( |
1083 | 11.6k | *this, dfi, builder, loc, invoker)) |
1084 | 4 | return diffFn; |
1085 | | |
1086 | 11.6k | AutoDiffConfig desiredConfig(parameterIndices, resultIndices); |
1087 | 11.6k | SmallVector<SILValue, 2> derivativeFns; |
1088 | 11.6k | SmallVector<AllocStackInst *, 2> newBuffersToDealloc; |
1089 | 11.6k | for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP, |
1090 | 23.2k | AutoDiffDerivativeFunctionKind::VJP}) { |
1091 | 23.2k | auto derivativeFnAndIndices = emitDerivativeFunctionReference( |
1092 | 23.2k | *this, builder, desiredConfig, derivativeFnKind, origFnOperand, |
1093 | 23.2k | invoker, newBuffersToDealloc); |
1094 | | // Show an error at the operator, highlight the argument, and show a note |
1095 | | // at the definition site of the argument. |
1096 | 23.2k | if (!derivativeFnAndIndices) |
1097 | 44 | return nullptr; |
1098 | | |
1099 | 23.2k | auto derivativeFn = derivativeFnAndIndices->first; |
1100 | 23.2k | context.recordGeneratedFunctionReference(derivativeFn); |
1101 | | |
1102 | | // If desired indices are a subset of actual indices, create a "subset |
1103 | | // indices thunk" and destroy the emitted derivative function reference. |
1104 | | // - For JVPs: the thunked JVP returns a differential taking fewer |
1105 | | // parameters (using `.zero` for the dropped parameters). |
1106 | | // - For VJPs: the thunked VJP returns a pullback that drops the unused |
1107 | | // tangent values. |
1108 | 23.2k | auto actualConfig = derivativeFnAndIndices->second; |
1109 | | // NOTE: `desiredIndices` may come from a partially-applied function and |
1110 | | // have smaller capacity than `actualIndices`. We expect this logic to go |
1111 | | // away when we support `@differentiable` partial apply. |
1112 | | // if (actualIndices != desiredIndices) { // TODO: Re-enable. |
1113 | 23.2k | auto extendedDesiredParameterIndices = |
1114 | 23.2k | desiredConfig.parameterIndices->extendingCapacity( |
1115 | 23.2k | astCtx, actualConfig.parameterIndices->getCapacity()); |
1116 | 23.2k | if (!actualConfig.parameterIndices->equals(extendedDesiredParameterIndices) |
1117 | 23.2k | || !actualConfig.resultIndices->equals(desiredConfig.resultIndices)) { |
1118 | | // Destroy the already emitted derivative function reference because it |
1119 | | // is no longer used. |
1120 | 1.81k | builder.emitDestroyValueOperation(loc, derivativeFn); |
1121 | | // Check if underlying original function reference has been partially |
1122 | | // applied with arguments. If so, produce an error: parameter subset |
1123 | | // thunks do not yet support this case because partially applied arguments |
1124 | | // cannot be propagated to parameter subset thunks. |
1125 | 1.81k | auto didPartiallyApplyArguments = [](SILValue original) { |
1126 | 2.40k | while (auto *pai = |
1127 | 1.81k | peerThroughFunctionConversions<PartialApplyInst>(original)) { |
1128 | 584 | if (pai->getNumArguments() > 0) |
1129 | 0 | return true; |
1130 | 584 | original = pai->getCallee(); |
1131 | 584 | } |
1132 | 1.81k | return false; |
1133 | 1.81k | }; |
1134 | 1.81k | if (didPartiallyApplyArguments(origFnOperand)) { |
1135 | 0 | context.emitNondifferentiabilityError( |
1136 | 0 | origFnOperand, invoker, |
1137 | 0 | diag::autodiff_cannot_param_subset_thunk_partially_applied_orig_fn); |
1138 | 0 | return nullptr; |
1139 | 0 | } |
1140 | | // Create the parameter subset thunk. |
1141 | 1.81k | assert(actualConfig.parameterIndices->isSupersetOf( |
1142 | 1.81k | extendedDesiredParameterIndices)); |
1143 | 0 | SILFunction *thunk; |
1144 | 1.81k | SubstitutionMap interfaceSubs; |
1145 | 1.81k | SILOptFunctionBuilder fb(transform); |
1146 | 1.81k | std::tie(thunk, interfaceSubs) = |
1147 | 1.81k | getOrCreateSubsetParametersThunkForDerivativeFunction( |
1148 | 1.81k | fb, origFnOperand, derivativeFn, derivativeFnKind, desiredConfig, |
1149 | 1.81k | actualConfig, context); |
1150 | 1.81k | auto *thunkFRI = builder.createFunctionRef(loc, thunk); |
1151 | 1.81k | if (auto genSig = |
1152 | 1.81k | thunk->getLoweredFunctionType()->getSubstGenericSignature()) { |
1153 | 96 | derivativeFn = |
1154 | 96 | builder.createPartialApply(loc, thunkFRI, interfaceSubs, {}, |
1155 | 96 | ParameterConvention::Direct_Guaranteed); |
1156 | 1.72k | } else { |
1157 | 1.72k | derivativeFn = thunkFRI; |
1158 | 1.72k | } |
1159 | 1.81k | } |
1160 | 23.2k | auto expectedDerivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType( |
1161 | 23.2k | parameterIndices, resultIndices, derivativeFnKind, |
1162 | 23.2k | context.getTypeConverter(), |
1163 | 23.2k | LookUpConformanceInModule(context.getModule().getSwiftModule())); |
1164 | | // If `derivativeFn` is `@convention(thin)` but is expected to be |
1165 | | // `@convention(thick)`, emit a `thin_to_thick` instruction. |
1166 | 23.2k | if (expectedDerivativeFnTy->getRepresentation() == |
1167 | 23.2k | SILFunctionTypeRepresentation::Thick && |
1168 | 23.2k | derivativeFn->getType() |
1169 | 13.2k | .castTo<SILFunctionType>() |
1170 | 13.2k | ->getRepresentation() == SILFunctionTypeRepresentation::Thin) { |
1171 | 488 | derivativeFn = builder.createThinToThickFunction( |
1172 | 488 | loc, derivativeFn, |
1173 | 488 | SILType::getPrimitiveObjectType(expectedDerivativeFnTy)); |
1174 | 488 | } |
1175 | | // If derivative function value's type is not ABI-compatible with the |
1176 | | // expected derivative function type (i.e. parameter and result conventions |
1177 | | // do not match), perform reabstraction. |
1178 | 23.2k | auto abiCompatibility = expectedDerivativeFnTy->isABICompatibleWith( |
1179 | 23.2k | derivativeFn->getType().castTo<SILFunctionType>(), *dfi->getFunction()); |
1180 | 23.2k | if (!abiCompatibility.isCompatible()) { |
1181 | 96 | SILOptFunctionBuilder fb(context.getTransform()); |
1182 | 96 | auto newDerivativeFn = reabstractFunction( |
1183 | 96 | builder, fb, loc, derivativeFn, expectedDerivativeFnTy, |
1184 | 96 | [](SubstitutionMap substMap) { return substMap; }); |
1185 | 96 | derivativeFn = newDerivativeFn; |
1186 | 96 | assert(expectedDerivativeFnTy |
1187 | 96 | ->isABICompatibleWith( |
1188 | 96 | derivativeFn->getType().castTo<SILFunctionType>(), |
1189 | 96 | *dfi->getFunction()) |
1190 | 96 | .isCompatible()); |
1191 | 96 | } |
1192 | | |
1193 | 0 | derivativeFns.push_back(derivativeFn); |
1194 | 23.2k | } |
1195 | | // Deallocate temporary buffers used for creating derivative functions. |
1196 | 11.6k | for (auto *buf : llvm::reverse(newBuffersToDealloc)) |
1197 | 96 | builder.createDeallocStack(loc, buf); |
1198 | | |
1199 | | // If our original copy does not have none ownership, copy it. |
1200 | 11.6k | if (origFnOperand->getOwnershipKind() != OwnershipKind::None) |
1201 | 3.47k | origFnOperand = builder.emitCopyValueOperation(loc, origFnOperand); |
1202 | 11.6k | auto *newDiffFn = context.createDifferentiableFunction( |
1203 | 11.6k | builder, loc, parameterIndices, resultIndices, origFnOperand, |
1204 | 11.6k | std::make_pair(derivativeFns[0], derivativeFns[1])); |
1205 | 11.6k | context.getDifferentiableFunctionInstWorklist().push_back(dfi); |
1206 | 11.6k | return newDiffFn; |
1207 | 11.6k | } |
1208 | | |
1209 | | SILValue DifferentiationTransformer::promoteToLinearFunction( |
1210 | | LinearFunctionInst *lfi, SILBuilder &builder, SILLocation loc, |
1211 | 12 | DifferentiationInvoker invoker) { |
1212 | | // Note: for now, this function creates a new `linear_function` instruction |
1213 | | // with an undef transpose function operand. Eventually, a legitimate |
1214 | | // transpose function operand should be created and used. |
1215 | 12 | auto origFnOperand = lfi->getOriginalFunction(); |
1216 | 12 | if (origFnOperand->getOwnershipKind() != OwnershipKind::None) |
1217 | 0 | origFnOperand = builder.emitCopyValueOperation(loc, origFnOperand); |
1218 | 12 | auto *parameterIndices = lfi->getParameterIndices(); |
1219 | 12 | auto originalType = origFnOperand->getType().castTo<SILFunctionType>(); |
1220 | 12 | auto transposeFnType = originalType->getAutoDiffTransposeFunctionType( |
1221 | 12 | parameterIndices, context.getTypeConverter(), |
1222 | 12 | LookUpConformanceInModule(builder.getModule().getSwiftModule())); |
1223 | 12 | auto transposeType = SILType::getPrimitiveObjectType(transposeFnType); |
1224 | 12 | auto transposeFn = SILUndef::get(transposeType, builder.getFunction()); |
1225 | 12 | auto *newLinearFn = context.createLinearFunction( |
1226 | 12 | builder, loc, parameterIndices, origFnOperand, SILValue(transposeFn)); |
1227 | 12 | context.getLinearFunctionInstWorklist().push_back(lfi); |
1228 | 12 | return newLinearFn; |
1229 | 12 | } |
1230 | | |
1231 | | bool DifferentiationTransformer::processDifferentiableFunctionInst( |
1232 | 14.7k | DifferentiableFunctionInst *dfi) { |
1233 | 14.7k | PrettyStackTraceSILNode dfiTrace("canonicalizing `differentiable_function`", |
1234 | 14.7k | dfi); |
1235 | 14.7k | PrettyStackTraceSILFunction fnTrace("...in", dfi->getFunction()); |
1236 | 14.7k | LLVM_DEBUG({ |
1237 | 14.7k | auto &s = getADDebugStream() << "Processing DifferentiableFunctionInst:\n"; |
1238 | 14.7k | dfi->printInContext(s); |
1239 | 14.7k | }); |
1240 | | |
1241 | | // If `dfi` already has derivative functions, do not process. |
1242 | 14.7k | if (dfi->hasDerivativeFunctions()) |
1243 | 3.09k | return false; |
1244 | | |
1245 | 11.6k | SILFunction *parent = dfi->getFunction(); |
1246 | 11.6k | auto loc = dfi->getLoc(); |
1247 | 11.6k | SILBuilderWithScope builder(dfi); |
1248 | 11.6k | auto differentiableFnValue = |
1249 | 11.6k | promoteToDifferentiableFunction(dfi, builder, loc, dfi); |
1250 | | // Mark `dfi` as processed so that it won't be reprocessed after deletion. |
1251 | 11.6k | context.markDifferentiableFunctionInstAsProcessed(dfi); |
1252 | 11.6k | if (!differentiableFnValue) |
1253 | 44 | return true; |
1254 | | // Replace all uses of `dfi`. |
1255 | 11.6k | dfi->replaceAllUsesWith(differentiableFnValue); |
1256 | | // Destroy the original operand. |
1257 | 11.6k | builder.emitDestroyValueOperation(loc, dfi->getOriginalFunction()); |
1258 | 11.6k | dfi->eraseFromParent(); |
1259 | 11.6k | transform.invalidateAnalysis(parent, |
1260 | 11.6k | SILAnalysis::InvalidationKind::FunctionBody); |
1261 | 11.6k | return false; |
1262 | 11.6k | } |
1263 | | |
1264 | | bool DifferentiationTransformer::processLinearFunctionInst( |
1265 | 12 | LinearFunctionInst *lfi) { |
1266 | 12 | PrettyStackTraceSILNode dfiTrace("canonicalizing `linear_function`", lfi); |
1267 | 12 | PrettyStackTraceSILFunction fnTrace("...in", lfi->getFunction()); |
1268 | 12 | LLVM_DEBUG({ |
1269 | 12 | auto &s = getADDebugStream() << "Processing LinearFunctionInst:\n"; |
1270 | 12 | lfi->printInContext(s); |
1271 | 12 | }); |
1272 | | |
1273 | | // If `lfi` already has a transpose function, do not process. |
1274 | 12 | if (lfi->hasTransposeFunction()) |
1275 | 0 | return false; |
1276 | | |
1277 | 12 | SILFunction *parent = lfi->getFunction(); |
1278 | 12 | auto loc = lfi->getLoc(); |
1279 | 12 | SILBuilderWithScope builder(lfi); |
1280 | 12 | auto linearFnValue = promoteToLinearFunction(lfi, builder, loc, lfi); |
1281 | | // Mark `lfi` as processed so that it won't be reprocessed after deletion. |
1282 | 12 | context.markLinearFunctionInstAsProcessed(lfi); |
1283 | 12 | if (!linearFnValue) |
1284 | 0 | return true; |
1285 | | // Replace all uses of `lfi`. |
1286 | 12 | lfi->replaceAllUsesWith(linearFnValue); |
1287 | | // Destroy the original operand. |
1288 | 12 | builder.emitDestroyValueOperation(loc, lfi->getOriginalFunction()); |
1289 | 12 | lfi->eraseFromParent(); |
1290 | | |
1291 | 12 | transform.invalidateAnalysis(parent, |
1292 | 12 | SILAnalysis::InvalidationKind::FunctionBody); |
1293 | 12 | return false; |
1294 | 12 | } |
1295 | | |
1296 | | /// Automatic differentiation transform entry. |
1297 | 24.3k | void Differentiation::run() { |
1298 | 24.3k | auto &module = *getModule(); |
1299 | 24.3k | auto &astCtx = module.getASTContext(); |
1300 | 24.3k | debugDump(module); |
1301 | | |
1302 | | // A transformation helper. |
1303 | 24.3k | DifferentiationTransformer transformer(*this); |
1304 | 24.3k | ADContext &context = transformer.getContext(); |
1305 | | |
1306 | 24.3k | bool errorOccurred = false; |
1307 | | |
1308 | | // Register all the SIL differentiability witnesses in the module that trigger |
1309 | | // differentiation. |
1310 | 24.3k | for (auto &witness : module.getDifferentiabilityWitnesses()) { |
1311 | 2.29k | if (witness.isDeclaration()) |
1312 | 36 | continue; |
1313 | 2.26k | context.addInvoker(&witness); |
1314 | 2.26k | } |
1315 | | |
1316 | | // Register all the `differentiable_function` and `linear_function` |
1317 | | // instructions in the module that trigger differentiation. |
1318 | 1.55M | for (SILFunction &f : module) { |
1319 | 1.82M | for (SILBasicBlock &bb : f) { |
1320 | 19.9M | for (SILInstruction &i : bb) { |
1321 | 19.9M | if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&i)) { |
1322 | 7.10k | context.getDifferentiableFunctionInstWorklist().push_back(dfi); |
1323 | 19.9M | } else if (auto *lfi = dyn_cast<LinearFunctionInst>(&i)) { |
1324 | | // If linear map transposition is not enabled and an uncanonical |
1325 | | // `linear_function` instruction is encountered, emit a diagnostic. |
1326 | | // FIXME(https://github.com/apple/swift/issues/54256): Finish support for linear map transposition. |
1327 | 12 | if (!EnableExperimentalLinearMapTransposition) { |
1328 | 4 | if (!lfi->hasTransposeFunction()) { |
1329 | 4 | astCtx.Diags.diagnose( |
1330 | 4 | lfi->getLoc().getSourceLoc(), |
1331 | 4 | diag::autodiff_conversion_to_linear_function_not_supported); |
1332 | 4 | errorOccurred = true; |
1333 | 4 | } |
1334 | 4 | } |
1335 | 12 | context.getLinearFunctionInstWorklist().push_back(lfi); |
1336 | 12 | } |
1337 | 19.9M | } |
1338 | 1.82M | } |
1339 | 1.55M | } |
1340 | | |
1341 | | // If nothing has triggered differentiation, there's nothing to do. |
1342 | 24.3k | if (context.getInvokers().empty() && |
1343 | 24.3k | context.getDifferentiableFunctionInstWorklist().empty() && |
1344 | 24.3k | context.getLinearFunctionInstWorklist().empty()) |
1345 | 24.0k | return; |
1346 | | |
1347 | | // Differentiation relies on the stdlib (the Swift module). |
1348 | | // If it's not imported, it's an internal error. |
1349 | 372 | if (!astCtx.getStdlibModule()) { |
1350 | 0 | astCtx.Diags.diagnose(SourceLoc(), |
1351 | 0 | diag::autodiff_internal_swift_not_imported); |
1352 | 0 | return; |
1353 | 0 | } |
1354 | 372 | if (!astCtx.getLoadedModule(astCtx.Id_Differentiation)) { |
1355 | 0 | SourceLoc loc; |
1356 | 0 | if (!context.getInvokers().empty()) { |
1357 | 0 | loc = context.getInvokers().front().second.getLocation(); |
1358 | 0 | } else { |
1359 | 0 | assert(!context.getDifferentiableFunctionInstWorklist().empty()); |
1360 | 0 | loc = context.getDifferentiableFunctionInstWorklist() |
1361 | 0 | .pop_back_val() |
1362 | 0 | ->getLoc() |
1363 | 0 | .getSourceLoc(); |
1364 | 0 | } |
1365 | 0 | astCtx.Diags.diagnose(loc, |
1366 | 0 | diag::autodiff_differentiation_module_not_imported); |
1367 | 0 | return; |
1368 | 0 | } |
1369 | | |
1370 | | // Process all invokers. |
1371 | 2.26k | for (auto invokerPair : context.getInvokers()) { |
1372 | 2.26k | auto *witness = invokerPair.first; |
1373 | 2.26k | auto invoker = invokerPair.second; |
1374 | 2.26k | if (transformer.canonicalizeDifferentiabilityWitness( |
1375 | 2.26k | witness, invoker, witness->getOriginalFunction()->isSerialized())) |
1376 | 172 | errorOccurred = true; |
1377 | 2.26k | } |
1378 | | |
1379 | | // Iteratively process `differentiable_function` instruction worklist. |
1380 | 26.7k | while (!context.getDifferentiableFunctionInstWorklist().empty()) { |
1381 | 26.3k | auto *dfi = context.getDifferentiableFunctionInstWorklist().pop_back_val(); |
1382 | | // Skip instructions that have been already been processed. |
1383 | 26.3k | if (context.isDifferentiableFunctionInstProcessed(dfi)) |
1384 | 11.6k | continue; |
1385 | 14.7k | errorOccurred |= transformer.processDifferentiableFunctionInst(dfi); |
1386 | 14.7k | } |
1387 | | |
1388 | | // Iteratively process `linear_function` instruction worklist. |
1389 | 396 | while (!context.getLinearFunctionInstWorklist().empty()) { |
1390 | 24 | auto *lfi = context.getLinearFunctionInstWorklist().pop_back_val(); |
1391 | | // Skip instructions that have been already been processed. |
1392 | 24 | if (context.isLinearFunctionInstProcessed(lfi)) |
1393 | 12 | continue; |
1394 | 12 | errorOccurred |= transformer.processLinearFunctionInst(lfi); |
1395 | 12 | } |
1396 | | |
1397 | | // If any error occurred while processing witnesses or |
1398 | | // `differentiable_function` instructions, clean up. |
1399 | 372 | if (errorOccurred) { |
1400 | 24 | context.cleanUp(); |
1401 | 24 | return; |
1402 | 24 | } |
1403 | | |
1404 | 348 | LLVM_DEBUG(getADDebugStream() << "All differentiation finished\n"); |
1405 | 348 | } |
1406 | | |
1407 | | //===----------------------------------------------------------------------===// |
1408 | | // Pass creation |
1409 | | //===----------------------------------------------------------------------===// |
1410 | | |
1411 | 24.3k | SILTransform *swift::createDifferentiation() { return new Differentiation; } |