/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/Common.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- Common.cpp - Automatic differentiation common utils --*- C++ -*---===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // Automatic differentiation common utilities. |
14 | | // |
15 | | //===----------------------------------------------------------------------===// |
16 | | |
17 | | #include "swift/Basic/STLExtras.h" |
18 | | #define DEBUG_TYPE "differentiation" |
19 | | |
20 | | #include "swift/SILOptimizer/Differentiation/Common.h" |
21 | | #include "swift/AST/TypeCheckRequests.h" |
22 | | #include "swift/SILOptimizer/Differentiation/ADContext.h" |
23 | | |
24 | | namespace swift { |
25 | | namespace autodiff { |
26 | | |
27 | 25.9k | raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; } |
28 | | |
29 | | //===----------------------------------------------------------------------===// |
30 | | // Helpers |
31 | | //===----------------------------------------------------------------------===// |
32 | | |
33 | 10.5k | ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) { |
34 | | // Find the `pointer_to_address` result, peering through `index_addr`. |
35 | 10.5k | auto *ptai = dyn_cast<PointerToAddressInst>(v); |
36 | 10.5k | if (auto *iai = dyn_cast<IndexAddrInst>(v)) |
37 | 124 | ptai = dyn_cast<PointerToAddressInst>(iai->getOperand(0)); |
38 | 10.5k | if (!ptai) |
39 | 10.0k | return nullptr; |
40 | | // Return the `array.uninitialized_intrinsic` application, if it exists. |
41 | 488 | if (auto *dti = dyn_cast<DestructureTupleInst>( |
42 | 488 | ptai->getOperand()->getDefiningInstruction())) |
43 | 488 | return ArraySemanticsCall(dti->getOperand(), |
44 | 488 | semantics::ARRAY_UNINITIALIZED_INTRINSIC); |
45 | 0 | return nullptr; |
46 | 488 | } |
47 | | |
48 | 30.9k | DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) { |
49 | 30.9k | bool foundDestructureTupleUser = false; |
50 | 30.9k | if (!value->getType().is<TupleType>()) |
51 | 0 | return nullptr; |
52 | 30.9k | DestructureTupleInst *result = nullptr; |
53 | 30.9k | for (auto *use : value->getUses()) { |
54 | 1.63k | if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser())) { |
55 | 1.63k | assert(!foundDestructureTupleUser && |
56 | 1.63k | "There should only be one `destructure_tuple` user of a tuple"); |
57 | 0 | foundDestructureTupleUser = true; |
58 | 1.63k | result = dti; |
59 | 1.63k | } |
60 | 1.63k | } |
61 | 30.9k | return result; |
62 | 30.9k | } |
63 | | |
64 | 47.0k | bool isSemanticMemberAccessor(SILFunction *original) { |
65 | 47.0k | auto *dc = original->getDeclContext(); |
66 | 47.0k | if (!dc) |
67 | 744 | return false; |
68 | 46.2k | auto *decl = dc->getAsDecl(); |
69 | 46.2k | if (!decl) |
70 | 10.3k | return false; |
71 | 35.9k | auto *accessor = dyn_cast<AccessorDecl>(decl); |
72 | 35.9k | if (!accessor) |
73 | 33.2k | return false; |
74 | | // Currently, only getters and setters are supported. |
75 | | // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors. |
76 | 2.63k | if (accessor->getAccessorKind() != AccessorKind::Get && |
77 | 2.63k | accessor->getAccessorKind() != AccessorKind::Set) |
78 | 0 | return false; |
79 | | // Accessor must come from a `var` declaration. |
80 | 2.63k | auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage()); |
81 | 2.63k | if (!varDecl) |
82 | 68 | return false; |
83 | | // Return true for stored property accessors. |
84 | 2.56k | if (varDecl->hasStorage() && varDecl->isInstanceMember()) |
85 | 716 | return true; |
86 | | // Return true for properties that have attached property wrappers. |
87 | 1.85k | if (varDecl->hasAttachedPropertyWrapper()) |
88 | 1.28k | return true; |
89 | | // Otherwise, return false. |
90 | | // User-defined accessors can never be supported because they may use custom |
91 | | // logic that does not semantically perform a member access. |
92 | 564 | return false; |
93 | 1.85k | } |
94 | | |
95 | 0 | bool hasSemanticMemberAccessorCallee(ApplySite applySite) { |
96 | 0 | if (auto *FRI = dyn_cast<FunctionRefBaseInst>(applySite.getCallee())) |
97 | 0 | if (auto *F = FRI->getReferencedFunctionOrNull()) |
98 | 0 | return isSemanticMemberAccessor(F); |
99 | 0 | return false; |
100 | 0 | } |
101 | | |
102 | | void forEachApplyDirectResult( |
103 | | FullApplySite applySite, |
104 | 77.1k | llvm::function_ref<void(SILValue)> resultCallback) { |
105 | 77.1k | switch (applySite.getKind()) { |
106 | 77.0k | case FullApplySiteKind::ApplyInst: { |
107 | 77.0k | auto *ai = cast<ApplyInst>(applySite.getInstruction()); |
108 | 77.0k | if (!ai->getType().is<TupleType>()) { |
109 | 46.0k | resultCallback(ai); |
110 | 46.0k | return; |
111 | 46.0k | } |
112 | 30.9k | if (auto *dti = getSingleDestructureTupleUser(ai)) |
113 | 1.63k | for (auto directResult : dti->getResults()) |
114 | 3.26k | resultCallback(directResult); |
115 | 30.9k | break; |
116 | 77.0k | } |
117 | 96 | case FullApplySiteKind::BeginApplyInst: { |
118 | 96 | auto *bai = cast<BeginApplyInst>(applySite.getInstruction()); |
119 | 96 | for (auto directResult : bai->getResults()) |
120 | 192 | resultCallback(directResult); |
121 | 96 | break; |
122 | 77.0k | } |
123 | 68 | case FullApplySiteKind::TryApplyInst: { |
124 | 68 | auto *tai = cast<TryApplyInst>(applySite.getInstruction()); |
125 | 68 | for (auto *succBB : tai->getSuccessorBlocks()) |
126 | 136 | for (auto *arg : succBB->getArguments()) |
127 | 136 | resultCallback(arg); |
128 | 68 | break; |
129 | 77.0k | } |
130 | 77.1k | } |
131 | 77.1k | } |
132 | | |
133 | | void collectAllFormalResultsInTypeOrder(SILFunction &function, |
134 | 11.9k | SmallVectorImpl<SILValue> &results) { |
135 | 11.9k | SILFunctionConventions convs(function.getLoweredFunctionType(), |
136 | 11.9k | function.getModule()); |
137 | 11.9k | auto indResults = function.getIndirectResults(); |
138 | 11.9k | auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator()); |
139 | 11.9k | auto retVal = retInst->getOperand(); |
140 | 11.9k | SmallVector<SILValue, 8> dirResults; |
141 | 11.9k | if (auto *tupleInst = |
142 | 11.9k | dyn_cast_or_null<TupleInst>(retVal->getDefiningInstruction())) |
143 | 3.39k | dirResults.append(tupleInst->getElements().begin(), |
144 | 3.39k | tupleInst->getElements().end()); |
145 | 8.54k | else |
146 | 8.54k | dirResults.push_back(retVal); |
147 | 11.9k | unsigned indResIdx = 0, dirResIdx = 0; |
148 | 11.9k | for (auto &resInfo : convs.getResults()) |
149 | 11.4k | results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++] |
150 | 11.4k | : indResults[indResIdx++]); |
151 | | // Treat semantic result parameters as semantic results. |
152 | | // Append them` parameters after formal results. |
153 | 19.2k | for (auto i : range(convs.getNumParameters())) { |
154 | 19.2k | auto paramInfo = convs.getParameters()[i]; |
155 | 19.2k | if (!paramInfo.isAutoDiffSemanticResult()) |
156 | 18.3k | continue; |
157 | 816 | auto *argument = function.getArgumentsWithoutIndirectResults()[i]; |
158 | 816 | results.push_back(argument); |
159 | 816 | } |
160 | 11.9k | } |
161 | | |
162 | | void collectAllDirectResultsInTypeOrder(SILFunction &function, |
163 | 1.35k | SmallVectorImpl<SILValue> &results) { |
164 | 1.35k | SILFunctionConventions convs(function.getLoweredFunctionType(), |
165 | 1.35k | function.getModule()); |
166 | 1.35k | auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator()); |
167 | 1.35k | auto retVal = retInst->getOperand(); |
168 | 1.35k | if (auto *tupleInst = dyn_cast<TupleInst>(retVal)) |
169 | 188 | results.append(tupleInst->getElements().begin(), |
170 | 188 | tupleInst->getElements().end()); |
171 | 1.16k | else |
172 | 1.16k | results.push_back(retVal); |
173 | 1.35k | } |
174 | | |
175 | | void collectAllActualResultsInTypeOrder( |
176 | | ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults, |
177 | 15.1k | SmallVectorImpl<SILValue> &results) { |
178 | 15.1k | auto calleeConvs = ai->getSubstCalleeConv(); |
179 | 15.1k | unsigned indResIdx = 0, dirResIdx = 0; |
180 | 17.4k | for (auto &resInfo : calleeConvs.getResults()) { |
181 | 17.4k | results.push_back(resInfo.isFormalDirect() |
182 | 17.4k | ? extractedDirectResults[dirResIdx++] |
183 | 17.4k | : ai->getIndirectSILResults()[indResIdx++]); |
184 | 17.4k | } |
185 | 15.1k | } |
186 | | |
187 | | void collectMinimalIndicesForFunctionCall( |
188 | | ApplyInst *ai, const AutoDiffConfig &parentConfig, |
189 | | const DifferentiableActivityInfo &activityInfo, |
190 | | SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> ¶mIndices, |
191 | 14.6k | SmallVectorImpl<unsigned> &resultIndices) { |
192 | 14.6k | auto calleeFnTy = ai->getSubstCalleeType(); |
193 | 14.6k | auto calleeConvs = ai->getSubstCalleeConv(); |
194 | | |
195 | | // Parameter indices are indices (in the callee type signature) of parameter |
196 | | // arguments that are varied or are arguments. |
197 | | // Record all parameter indices in type order. |
198 | 14.6k | unsigned currentParamIdx = 0; |
199 | 35.6k | for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) { |
200 | 35.6k | if (activityInfo.isActive(applyArg, parentConfig)) |
201 | 22.8k | paramIndices.push_back(currentParamIdx); |
202 | 35.6k | ++currentParamIdx; |
203 | 35.6k | } |
204 | | |
205 | | // Result indices are indices (in the callee type signature) of results that |
206 | | // are useful. |
207 | 14.6k | SmallVector<SILValue, 8> directResults; |
208 | 14.6k | forEachApplyDirectResult(ai, [&](SILValue directResult) { |
209 | 9.00k | directResults.push_back(directResult); |
210 | 9.00k | }); |
211 | 14.6k | auto indirectResults = ai->getIndirectSILResults(); |
212 | | // Record all results and result indices in type order. |
213 | 14.6k | results.reserve(calleeFnTy->getNumResults()); |
214 | 14.6k | unsigned dirResIdx = 0; |
215 | 14.6k | unsigned indResIdx = calleeConvs.getSILArgIndexOfFirstIndirectResult(); |
216 | 14.6k | for (const auto &resAndIdx : enumerate(calleeConvs.getResults())) { |
217 | 13.8k | const auto &res = resAndIdx.value(); |
218 | 13.8k | unsigned idx = resAndIdx.index(); |
219 | 13.8k | if (res.isFormalDirect()) { |
220 | 9.00k | results.push_back(directResults[dirResIdx]); |
221 | 9.00k | if (auto dirRes = directResults[dirResIdx]) |
222 | 9.00k | if (dirRes && activityInfo.isActive(dirRes, parentConfig)) |
223 | 8.93k | resultIndices.push_back(idx); |
224 | 9.00k | ++dirResIdx; |
225 | 9.00k | } else { |
226 | 4.79k | results.push_back(indirectResults[indResIdx]); |
227 | 4.79k | if (activityInfo.isActive(indirectResults[indResIdx], parentConfig)) |
228 | 4.76k | resultIndices.push_back(idx); |
229 | 4.79k | ++indResIdx; |
230 | 4.79k | } |
231 | 13.8k | } |
232 | | |
233 | | // Record all semantic result parameters as results. |
234 | 14.6k | auto semanticResultParamResultIndex = calleeFnTy->getNumResults(); |
235 | 35.6k | for (const auto ¶mAndIdx : enumerate(calleeConvs.getParameters())) { |
236 | 35.6k | const auto ¶m = paramAndIdx.value(); |
237 | 35.6k | if (!param.isAutoDiffSemanticResult()) |
238 | 34.4k | continue; |
239 | 1.21k | unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults(); |
240 | 1.21k | results.push_back(ai->getArgument(idx)); |
241 | 1.21k | resultIndices.push_back(semanticResultParamResultIndex++); |
242 | 1.21k | } |
243 | | |
244 | | // Make sure the function call has active results. |
245 | 14.6k | #ifndef NDEBUG |
246 | 14.6k | assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults()); |
247 | 0 | assert(llvm::any_of(results, [&](SILValue result) { |
248 | 14.6k | return activityInfo.isActive(result, parentConfig); |
249 | 14.6k | })); |
250 | 14.6k | #endif |
251 | 14.6k | } |
252 | | |
253 | | llvm::Optional<std::pair<SILDebugLocation, SILDebugVariable>> |
254 | 34.9k | findDebugLocationAndVariable(SILValue originalValue) { |
255 | 34.9k | if (auto *asi = dyn_cast<AllocStackInst>(originalValue)) |
256 | 6.66k | return swift::transform(asi->getVarInfo(), [&](SILDebugVariable var) { |
257 | 2.90k | return std::make_pair(asi->getDebugLocation(), var); |
258 | 2.90k | }); |
259 | 58.3k | for (auto *use : originalValue->getUses()) { |
260 | 58.3k | if (auto *dvi = dyn_cast<DebugValueInst>(use->getUser())) |
261 | 14.1k | return swift::transform(dvi->getVarInfo(), [&](SILDebugVariable var) { |
262 | | // We need to drop `op_deref` here as we're transferring debug info |
263 | | // location from debug_value instruction (which describes how to get value) |
264 | | // into alloc_stack (which describes the location) |
265 | 14.1k | if (var.DIExpr.startsWithDeref()) |
266 | 2.26k | var.DIExpr.eraseElement(var.DIExpr.element_begin()); |
267 | 14.1k | return std::make_pair(dvi->getDebugLocation(), var); |
268 | 14.1k | }); |
269 | 58.3k | } |
270 | 14.1k | return llvm::None; |
271 | 28.2k | } |
272 | | |
273 | | //===----------------------------------------------------------------------===// |
274 | | // Diagnostic utilities |
275 | | //===----------------------------------------------------------------------===// |
276 | | |
277 | 92 | SILLocation getValidLocation(SILValue v) { |
278 | 92 | auto loc = v.getLoc(); |
279 | 92 | if (loc.isNull() || loc.getSourceLoc().isInvalid()) |
280 | 4 | loc = v->getFunction()->getLocation(); |
281 | 92 | return loc; |
282 | 92 | } |
283 | | |
284 | 4.52k | SILLocation getValidLocation(SILInstruction *inst) { |
285 | 4.52k | auto loc = inst->getLoc(); |
286 | 4.52k | if (loc.isNull() || loc.getSourceLoc().isInvalid()) |
287 | 484 | loc = inst->getFunction()->getLocation(); |
288 | 4.52k | return loc; |
289 | 4.52k | } |
290 | | |
291 | | //===----------------------------------------------------------------------===// |
292 | | // Tangent property lookup utilities |
293 | | //===----------------------------------------------------------------------===// |
294 | | |
295 | | VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField, |
296 | | CanType baseType, SILLocation loc, |
297 | 4.10k | DifferentiationInvoker invoker) { |
298 | 4.10k | auto &astCtx = context.getASTContext(); |
299 | 4.10k | auto tanFieldInfo = evaluateOrDefault( |
300 | 4.10k | astCtx.evaluator, TangentStoredPropertyRequest{originalField, baseType}, |
301 | 4.10k | TangentPropertyInfo(nullptr)); |
302 | | // If no error, return the tangent property. |
303 | 4.10k | if (tanFieldInfo) |
304 | 4.04k | return tanFieldInfo.tangentProperty; |
305 | | // Otherwise, diagnose error and return nullptr. |
306 | 52 | assert(tanFieldInfo.error); |
307 | 0 | auto *parentDC = originalField->getDeclContext(); |
308 | 52 | assert(parentDC->isTypeContext()); |
309 | 0 | auto parentDeclName = parentDC->getSelfNominalTypeDecl()->getNameStr(); |
310 | 52 | auto fieldName = originalField->getNameStr(); |
311 | 52 | auto sourceLoc = loc.getSourceLoc(); |
312 | 52 | switch (tanFieldInfo.error->kind) { |
313 | 0 | case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty: |
314 | 0 | llvm_unreachable( |
315 | 0 | "`@noDerivative` stored property accesses should not be " |
316 | 0 | "differentiated; activity analysis should not mark as varied"); |
317 | 0 | case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable: |
318 | 0 | context.emitNondifferentiabilityError( |
319 | 0 | sourceLoc, invoker, |
320 | 0 | diag::autodiff_stored_property_parent_not_differentiable, |
321 | 0 | parentDeclName, fieldName); |
322 | 0 | break; |
323 | 8 | case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable: |
324 | 8 | context.emitNondifferentiabilityError( |
325 | 8 | sourceLoc, invoker, diag::autodiff_stored_property_not_differentiable, |
326 | 8 | parentDeclName, fieldName, originalField->getInterfaceType()); |
327 | 8 | break; |
328 | 8 | case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct: |
329 | 8 | context.emitNondifferentiabilityError( |
330 | 8 | sourceLoc, invoker, diag::autodiff_stored_property_tangent_not_struct, |
331 | 8 | parentDeclName, fieldName); |
332 | 8 | break; |
333 | 12 | case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound: |
334 | 12 | context.emitNondifferentiabilityError( |
335 | 12 | sourceLoc, invoker, |
336 | 12 | diag::autodiff_stored_property_no_corresponding_tangent, parentDeclName, |
337 | 12 | fieldName); |
338 | 12 | break; |
339 | 12 | case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType: |
340 | 12 | context.emitNondifferentiabilityError( |
341 | 12 | sourceLoc, invoker, diag::autodiff_tangent_property_wrong_type, |
342 | 12 | parentDeclName, fieldName, tanFieldInfo.error->getType()); |
343 | 12 | break; |
344 | 12 | case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored: |
345 | 12 | context.emitNondifferentiabilityError( |
346 | 12 | sourceLoc, invoker, diag::autodiff_tangent_property_not_stored, |
347 | 12 | parentDeclName, fieldName); |
348 | 12 | break; |
349 | 52 | } |
350 | 52 | return nullptr; |
351 | 52 | } |
352 | | |
353 | | VarDecl *getTangentStoredProperty(ADContext &context, |
354 | | SingleValueInstruction *projectionInst, |
355 | | CanType baseType, |
356 | 3.57k | DifferentiationInvoker invoker) { |
357 | 3.57k | assert(isa<StructExtractInst>(projectionInst) || |
358 | 3.57k | isa<StructElementAddrInst>(projectionInst) || |
359 | 3.57k | isa<RefElementAddrInst>(projectionInst)); |
360 | 0 | Projection proj(projectionInst); |
361 | 3.57k | auto loc = getValidLocation(projectionInst); |
362 | 3.57k | auto *field = proj.getVarDecl(projectionInst->getOperand(0)->getType()); |
363 | 3.57k | return getTangentStoredProperty(context, field, baseType, |
364 | 3.57k | loc, invoker); |
365 | 3.57k | } |
366 | | |
367 | | //===----------------------------------------------------------------------===// |
368 | | // Code emission utilities |
369 | | //===----------------------------------------------------------------------===// |
370 | | |
371 | | SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder, |
372 | 22.2k | SILLocation loc) { |
373 | 22.2k | if (elements.size() == 1) |
374 | 10.2k | return elements.front(); |
375 | 11.9k | return builder.createTuple(loc, elements); |
376 | 22.2k | } |
377 | | |
378 | | void extractAllElements(SILValue value, SILBuilder &builder, |
379 | 23.3k | SmallVectorImpl<SILValue> &results) { |
380 | 23.3k | auto tupleType = value->getType().getAs<TupleType>(); |
381 | 23.3k | if (!tupleType) { |
382 | 11.4k | results.push_back(value); |
383 | 11.4k | return; |
384 | 11.4k | } |
385 | 11.9k | if (builder.hasOwnership()) { |
386 | 11.9k | auto *dti = builder.createDestructureTuple(value.getLoc(), value); |
387 | 11.9k | results.append(dti->getResults().begin(), dti->getResults().end()); |
388 | 11.9k | return; |
389 | 11.9k | } |
390 | 0 | for (auto i : range(tupleType->getNumElements())) |
391 | 0 | results.push_back(builder.createTupleExtract(value.getLoc(), value, i)); |
392 | 0 | } |
393 | | |
394 | | SILValue emitMemoryLayoutSize( |
395 | 0 | SILBuilder &builder, SILLocation loc, CanType type) { |
396 | 0 | auto &ctx = builder.getASTContext(); |
397 | 0 | auto id = ctx.getIdentifier(getBuiltinName(BuiltinValueKind::Sizeof)); |
398 | 0 | auto *builtin = cast<FuncDecl>(getBuiltinValueDecl(ctx, id)); |
399 | 0 | auto metatypeTy = SILType::getPrimitiveObjectType( |
400 | 0 | CanMetatypeType::get(type, MetatypeRepresentation::Thin)); |
401 | 0 | auto metatypeVal = builder.createMetatype(loc, metatypeTy); |
402 | 0 | return builder.createBuiltin( |
403 | 0 | loc, id, SILType::getBuiltinWordType(ctx), |
404 | 0 | SubstitutionMap::get( |
405 | 0 | builtin->getGenericSignature(), ArrayRef<Type>{type}, {}), |
406 | 0 | {metatypeVal}); |
407 | 0 | } |
408 | | |
409 | | SILValue emitProjectTopLevelSubcontext( |
410 | | SILBuilder &builder, SILLocation loc, SILValue context, |
411 | 204 | SILType subcontextType) { |
412 | 204 | assert(context->getOwnershipKind() == OwnershipKind::Guaranteed); |
413 | 0 | auto &ctx = builder.getASTContext(); |
414 | 204 | auto id = ctx.getIdentifier( |
415 | 204 | getBuiltinName(BuiltinValueKind::AutoDiffProjectTopLevelSubcontext)); |
416 | 204 | assert(context->getType() == SILType::getNativeObjectType(ctx)); |
417 | 0 | auto *subcontextAddr = builder.createBuiltin( |
418 | 204 | loc, id, SILType::getRawPointerType(ctx), SubstitutionMap(), {context}); |
419 | 204 | return builder.createPointerToAddress( |
420 | 204 | loc, subcontextAddr, subcontextType.getAddressType(), /*isStrict*/ true); |
421 | 204 | } |
422 | | |
423 | | //===----------------------------------------------------------------------===// |
424 | | // Utilities for looking up derivatives of functions |
425 | | //===----------------------------------------------------------------------===// |
426 | | |
427 | | /// Returns the AbstractFunctionDecl corresponding to `F`. If there isn't one, |
428 | | /// returns `nullptr`. |
429 | 6.34k | static AbstractFunctionDecl *findAbstractFunctionDecl(SILFunction *F) { |
430 | 6.34k | auto *DC = F->getDeclContext(); |
431 | 6.34k | if (!DC) |
432 | 88 | return nullptr; |
433 | 6.25k | auto *D = DC->getAsDecl(); |
434 | 6.25k | if (!D) |
435 | 1.49k | return nullptr; |
436 | 4.76k | return dyn_cast<AbstractFunctionDecl>(D); |
437 | 6.25k | } |
438 | | |
439 | | SILDifferentiabilityWitness * |
440 | | getExactDifferentiabilityWitness(SILModule &module, SILFunction *original, |
441 | | IndexSubset *parameterIndices, |
442 | 22.5k | IndexSubset *resultIndices) { |
443 | 22.5k | for (auto *w : module.lookUpDifferentiabilityWitnessesForFunction( |
444 | 22.5k | original->getName())) { |
445 | 18.4k | if (w->getParameterIndices() == parameterIndices && |
446 | 18.4k | w->getResultIndices() == resultIndices) |
447 | 16.2k | return w; |
448 | 18.4k | } |
449 | 6.34k | return nullptr; |
450 | 22.5k | } |
451 | | |
452 | | llvm::Optional<AutoDiffConfig> |
453 | | findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, |
454 | | IndexSubset *parameterIndices, |
455 | 5.43k | IndexSubset *&minimalASTParameterIndices) { |
456 | 5.43k | llvm::Optional<AutoDiffConfig> minimalConfig = llvm::None; |
457 | 5.43k | auto configs = original->getDerivativeFunctionConfigurations(); |
458 | 5.43k | for (auto &config : configs) { |
459 | 3.85k | auto *silParameterIndices = autodiff::getLoweredParameterIndices( |
460 | 3.85k | config.parameterIndices, |
461 | 3.85k | original->getInterfaceType()->castTo<AnyFunctionType>()); |
462 | | |
463 | 3.85k | if (silParameterIndices->getCapacity() < parameterIndices->getCapacity()) { |
464 | 0 | assert(original->getCaptureInfo().hasLocalCaptures()); |
465 | 0 | silParameterIndices = |
466 | 0 | silParameterIndices->extendingCapacity(original->getASTContext(), |
467 | 0 | parameterIndices->getCapacity()); |
468 | 0 | } |
469 | | |
470 | | // If all indices in `parameterIndices` are in `daParameterIndices`, and |
471 | | // it has fewer indices than our current candidate and a primitive VJP, |
472 | | // then `attr` is our new candidate. |
473 | | // |
474 | | // NOTE(TF-642): `attr` may come from a un-partial-applied function and |
475 | | // have larger capacity than the desired indices. We expect this logic to |
476 | | // go away when `partial_apply` supports `@differentiable` callees. |
477 | 3.85k | if (silParameterIndices->isSupersetOf(parameterIndices->extendingCapacity( |
478 | 3.85k | original->getASTContext(), silParameterIndices->getCapacity())) && |
479 | | // fewer parameters than before |
480 | 3.85k | (!minimalConfig || |
481 | 3.63k | silParameterIndices->getNumIndices() < |
482 | 3.61k | minimalConfig->parameterIndices->getNumIndices())) { |
483 | 3.61k | minimalASTParameterIndices = config.parameterIndices; |
484 | 3.61k | minimalConfig = |
485 | 3.61k | AutoDiffConfig(silParameterIndices, config.resultIndices, |
486 | 3.61k | autodiff::getDifferentiabilityWitnessGenericSignature( |
487 | 3.61k | original->getGenericSignature(), |
488 | 3.61k | config.derivativeGenericSignature)); |
489 | 3.61k | } |
490 | 3.85k | } |
491 | 5.43k | return minimalConfig; |
492 | 5.43k | } |
493 | | |
494 | | SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( |
495 | | SILModule &module, SILFunction *original, DifferentiabilityKind kind, |
496 | 6.34k | IndexSubset *parameterIndices, IndexSubset *resultIndices) { |
497 | | // Explicit differentiability witnesses only exist on SIL functions that come |
498 | | // from AST functions. |
499 | 6.34k | auto *originalAFD = findAbstractFunctionDecl(original); |
500 | 6.34k | if (!originalAFD) |
501 | 1.58k | return nullptr; |
502 | | |
503 | 4.76k | IndexSubset *minimalASTParameterIndices = nullptr; |
504 | 4.76k | auto minimalConfig = findMinimalDerivativeConfiguration( |
505 | 4.76k | originalAFD, parameterIndices, minimalASTParameterIndices); |
506 | 4.76k | if (!minimalConfig) |
507 | 1.82k | return nullptr; |
508 | | |
509 | 2.94k | std::string originalName = original->getName().str(); |
510 | | // If original function requires a foreign entry point, use the foreign SIL |
511 | | // function to get or create the minimal differentiability witness. |
512 | 2.94k | if (requiresForeignEntryPoint(originalAFD)) { |
513 | 304 | originalName = SILDeclRef(originalAFD).asForeign().mangle(); |
514 | 304 | original = module.lookUpFunction(SILDeclRef(originalAFD).asForeign()); |
515 | 304 | } |
516 | | |
517 | 2.94k | auto *existingWitness = module.lookUpDifferentiabilityWitness( |
518 | 2.94k | {originalName, kind, *minimalConfig}); |
519 | 2.94k | if (existingWitness) |
520 | 1.90k | return existingWitness; |
521 | | |
522 | 1.03k | assert(original->isExternalDeclaration() && |
523 | 1.03k | "SILGen should create differentiability witnesses for all function " |
524 | 1.03k | "definitions with explicit differentiable attributes"); |
525 | | |
526 | 0 | return SILDifferentiabilityWitness::createDeclaration( |
527 | 1.03k | module, SILLinkage::PublicExternal, original, kind, |
528 | 1.03k | minimalConfig->parameterIndices, minimalConfig->resultIndices, |
529 | 1.03k | minimalConfig->derivativeGenericSignature); |
530 | 2.94k | } |
531 | | |
532 | | } // end namespace autodiff |
533 | | } // end namespace swift |