/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- LinearMapInfo.cpp ------------------------------------*- C++ -*---===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // Linear map tuple and branching trace enum information for differentiation. |
14 | | // |
15 | | //===----------------------------------------------------------------------===// |
16 | | |
17 | | #define DEBUG_TYPE "differentiation" |
18 | | |
19 | | #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" |
20 | | #include "swift/SILOptimizer/Differentiation/ADContext.h" |
21 | | |
22 | | #include "swift/AST/DeclContext.h" |
23 | | #include "swift/AST/ParameterList.h" |
24 | | #include "swift/AST/SourceFile.h" |
25 | | #include "swift/SIL/LoopInfo.h" |
26 | | |
27 | | namespace swift { |
28 | | namespace autodiff { |
29 | | |
30 | | //===----------------------------------------------------------------------===// |
31 | | // Local helpers |
32 | | //===----------------------------------------------------------------------===// |
33 | | |
34 | | /// Clone the generic parameters of the given generic signature and return a new |
35 | | /// `GenericParamList`. |
36 | | static GenericParamList *cloneGenericParameters(ASTContext &ctx, |
37 | | DeclContext *dc, |
38 | 1.54k | CanGenericSignature sig) { |
39 | 1.54k | SmallVector<GenericTypeParamDecl *, 2> clonedParams; |
40 | 1.71k | for (auto paramType : sig.getGenericParams()) { |
41 | 1.71k | auto *clonedParam = GenericTypeParamDecl::createImplicit( |
42 | 1.71k | dc, paramType->getName(), paramType->getDepth(), paramType->getIndex(), |
43 | 1.71k | paramType->isParameterPack()); |
44 | 1.71k | clonedParam->setDeclContext(dc); |
45 | 1.71k | clonedParams.push_back(clonedParam); |
46 | 1.71k | } |
47 | 1.54k | return GenericParamList::create(ctx, SourceLoc(), clonedParams, SourceLoc()); |
48 | 1.54k | } |
49 | | |
50 | | //===----------------------------------------------------------------------===// |
51 | | // LinearMapInfo methods |
52 | | //===----------------------------------------------------------------------===// |
53 | | |
54 | | LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind, |
55 | | SILFunction *original, SILFunction *derivative, |
56 | | const AutoDiffConfig &config, |
57 | | const DifferentiableActivityInfo &activityInfo, |
58 | | SILLoopInfo *loopInfo) |
59 | | : kind(kind), original(original), derivative(derivative), |
60 | | activityInfo(activityInfo), loopInfo(loopInfo), config(config), |
61 | | synthesizedFile(context.getOrCreateSynthesizedFile(original)), |
62 | 6.61k | typeConverter(context.getTypeConverter()) { |
63 | 6.61k | generateDifferentiationDataStructures(context, derivative); |
64 | 6.61k | } |
65 | | |
66 | 7.35k | SILType LinearMapInfo::remapTypeInDerivative(SILType ty) { |
67 | 7.35k | if (ty.hasArchetype()) |
68 | 656 | return derivative->mapTypeIntoContext(ty.mapTypeOutOfContext()); |
69 | 6.69k | return derivative->mapTypeIntoContext(ty); |
70 | 7.35k | } |
71 | | |
72 | | EnumDecl * |
73 | | LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB, |
74 | 8.60k | CanGenericSignature genericSig) { |
75 | 8.60k | assert(originalBB->getParent() == original); |
76 | 0 | auto &astCtx = original->getASTContext(); |
77 | 8.60k | auto &file = getSynthesizedFile(); |
78 | | // Create a branching trace enum. |
79 | 8.60k | Mangle::ASTMangler mangler; |
80 | 8.60k | auto config = this->config.withGenericSignature(genericSig); |
81 | 8.60k | auto enumName = mangler.mangleAutoDiffGeneratedDeclaration( |
82 | 8.60k | AutoDiffGeneratedDeclarationKind::BranchingTraceEnum, |
83 | 8.60k | original->getName().str(), originalBB->getDebugID(), kind, config); |
84 | 8.60k | auto enumId = astCtx.getIdentifier(enumName); |
85 | 8.60k | auto loc = original->getLocation().getSourceLoc(); |
86 | 8.60k | GenericParamList *genericParams = nullptr; |
87 | 8.60k | if (genericSig) |
88 | 1.54k | genericParams = cloneGenericParameters(astCtx, &file, genericSig); |
89 | 8.60k | auto *branchingTraceDecl = new (astCtx) EnumDecl( |
90 | 8.60k | /*EnumLoc*/ SourceLoc(), /*Name*/ enumId, /*NameLoc*/ loc, |
91 | 8.60k | /*Inherited*/ {}, /*GenericParams*/ genericParams, /*DC*/ &file); |
92 | | // Note: must mark enum as implicit to satisfy assertion in |
93 | | // `Parser::parseDeclListDelayed`. |
94 | 8.60k | branchingTraceDecl->setImplicit(); |
95 | 8.60k | if (genericSig) |
96 | 1.54k | branchingTraceDecl->setGenericSignature(genericSig); |
97 | 8.60k | switch (original->getEffectiveSymbolLinkage()) { |
98 | 252 | case swift::SILLinkage::Public: |
99 | 252 | case swift::SILLinkage::PublicNonABI: |
100 | | // Branching trace enums shall not be resilient. |
101 | 252 | branchingTraceDecl->getAttrs().add(new (astCtx) FrozenAttr(/*implicit*/ true)); |
102 | 252 | branchingTraceDecl->getAttrs().add(new (astCtx) UsableFromInlineAttr(/*Implicit*/ true)); |
103 | 252 | LLVM_FALLTHROUGH; |
104 | 3.08k | case swift::SILLinkage::Hidden: |
105 | 3.19k | case swift::SILLinkage::Shared: |
106 | 3.19k | branchingTraceDecl->setAccess(AccessLevel::Internal); |
107 | 3.19k | break; |
108 | 5.41k | case swift::SILLinkage::Private: |
109 | 5.41k | branchingTraceDecl->setAccess(AccessLevel::FilePrivate); |
110 | 5.41k | break; |
111 | 0 | default: |
112 | | // When the original function has external linkage, we create an internal |
113 | | // struct for use by our own module. This is necessary for cross-cell |
114 | | // differentiation in Jupyter. |
115 | | // TODO: Add a test in the compiler that exercises a similar situation as |
116 | | // cross-cell differentiation in Jupyter. |
117 | 0 | branchingTraceDecl->setAccess(AccessLevel::Internal); |
118 | 8.60k | } |
119 | 8.60k | file.addTopLevelDecl(branchingTraceDecl); |
120 | 8.60k | file.getParentModule()->clearLookupCache(); |
121 | | |
122 | 8.60k | return branchingTraceDecl; |
123 | 8.60k | } |
124 | | |
125 | | void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB, |
126 | 1.99k | SILLoopInfo *loopInfo) { |
127 | 1.99k | auto &astCtx = original->getASTContext(); |
128 | 1.99k | auto *moduleDecl = original->getModule().getSwiftModule(); |
129 | 1.99k | auto loc = original->getLocation().getSourceLoc(); |
130 | 1.99k | auto *branchingTraceDecl = getBranchingTraceDecl(originalBB); |
131 | | |
132 | | // Add basic block enum cases. |
133 | 2.64k | for (auto *predBB : originalBB->getPredecessorBlocks()) { |
134 | | // Create dummy declaration representing enum case parameter. |
135 | 2.64k | auto *decl = new (astCtx) |
136 | 2.64k | ParamDecl(loc, loc, Identifier(), loc, Identifier(), moduleDecl); |
137 | 2.64k | decl->setSpecifier(ParamDecl::Specifier::Default); |
138 | | // If predecessor block is in a loop, its linear map tuple will be |
139 | | // indirectly referenced in memory owned by the context object. The payload |
140 | | // is just a raw pointer. |
141 | 2.64k | if (loopInfo->getLoopFor(predBB)) { |
142 | 408 | heapAllocatedContext = true; |
143 | 408 | decl->setInterfaceType(astCtx.TheRawPointerType); |
144 | 2.24k | } else { // Otherwise the payload is the linear map tuple. |
145 | 2.24k | auto *linearMapStructTy = getLinearMapTupleType(predBB); |
146 | 2.24k | assert(linearMapStructTy && "must have linear map struct type for predecessor BB"); |
147 | 0 | auto canLinearMapStructTy = linearMapStructTy->getCanonicalType(); |
148 | 2.24k | decl->setInterfaceType( |
149 | 2.24k | canLinearMapStructTy->hasArchetype() |
150 | 2.24k | ? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy); |
151 | 2.24k | } |
152 | | // Create enum element and enum case declarations. |
153 | 0 | auto *paramList = ParameterList::create(astCtx, {decl}); |
154 | 2.64k | auto bbId = "bb" + std::to_string(predBB->getDebugID()); |
155 | 2.64k | auto *enumEltDecl = new (astCtx) EnumElementDecl( |
156 | 2.64k | /*IdentifierLoc*/ loc, DeclName(astCtx.getIdentifier(bbId)), paramList, |
157 | 2.64k | loc, /*RawValueExpr*/ nullptr, branchingTraceDecl); |
158 | 2.64k | enumEltDecl->setImplicit(); |
159 | 2.64k | auto *enumCaseDecl = EnumCaseDecl::create( |
160 | 2.64k | /*CaseLoc*/ loc, {enumEltDecl}, branchingTraceDecl); |
161 | 2.64k | enumCaseDecl->setImplicit(); |
162 | 2.64k | branchingTraceDecl->addMember(enumEltDecl); |
163 | 2.64k | branchingTraceDecl->addMember(enumCaseDecl); |
164 | | // Record enum element declaration. |
165 | 2.64k | branchingTraceEnumCases.insert({{predBB, originalBB}, enumEltDecl}); |
166 | 2.64k | } |
167 | 1.99k | } |
168 | | |
169 | | |
170 | 7.35k | Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { |
171 | 7.35k | SmallVector<SILValue, 4> allResults; |
172 | 7.35k | SmallVector<unsigned, 8> activeParamIndices; |
173 | 7.35k | SmallVector<unsigned, 8> activeResultIndices; |
174 | 7.35k | collectMinimalIndicesForFunctionCall(ai, config, activityInfo, allResults, |
175 | 7.35k | activeParamIndices, activeResultIndices); |
176 | | |
177 | | // Check if there are any active results or arguments. If not, skip |
178 | | // this instruction. |
179 | 7.37k | auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) { |
180 | 7.37k | return activityInfo.isActive(res, config); |
181 | 7.37k | }); |
182 | 7.35k | bool hasActiveSemanticResultArgument = false; |
183 | 7.35k | bool hasActiveArguments = false; |
184 | 7.35k | auto numIndirectResults = ai->getNumIndirectResults(); |
185 | 17.8k | for (auto argIdx : range(ai->getSubstCalleeConv().getNumParameters())) { |
186 | 17.8k | auto arg = ai->getArgumentsWithoutIndirectResults()[argIdx]; |
187 | 17.8k | if (activityInfo.isActive(arg, config)) { |
188 | 11.4k | hasActiveArguments = true; |
189 | 11.4k | auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( |
190 | 11.4k | numIndirectResults + argIdx); |
191 | 11.4k | if (paramInfo.isAutoDiffSemanticResult()) |
192 | 608 | hasActiveSemanticResultArgument = true; |
193 | 11.4k | } |
194 | 17.8k | } |
195 | 7.35k | if (!hasActiveArguments) |
196 | 0 | return {}; |
197 | 7.35k | if (!hasActiveResults && !hasActiveSemanticResultArgument) |
198 | 0 | return {}; |
199 | | |
200 | | // Compute differentiability parameters. |
201 | | // - If the callee has `@differentiable` function type, use differentiation |
202 | | // parameters from the function type. |
203 | | // - Otherwise, use the active parameters. |
204 | 7.35k | IndexSubset *parameters; |
205 | 7.35k | auto origFnSubstTy = ai->getSubstCalleeType(); |
206 | 7.35k | auto remappedOrigFnSubstTy = |
207 | 7.35k | remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy)) |
208 | 7.35k | .castTo<SILFunctionType>() |
209 | 7.35k | ->getUnsubstitutedType(original->getModule()); |
210 | 7.35k | if (remappedOrigFnSubstTy->isDifferentiable()) { |
211 | 80 | parameters = remappedOrigFnSubstTy->getDifferentiabilityParameterIndices(); |
212 | 7.27k | } else { |
213 | 7.27k | parameters = IndexSubset::get( |
214 | 7.27k | original->getASTContext(), |
215 | 7.27k | ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices); |
216 | 7.27k | } |
217 | | // Compute differentiability results. |
218 | 7.35k | auto *results = IndexSubset::get(original->getASTContext(), |
219 | 7.35k | remappedOrigFnSubstTy->getNumAutoDiffSemanticResults(), |
220 | 7.35k | activeResultIndices); |
221 | | // Create autodiff indices for the `apply` instruction. |
222 | 7.35k | AutoDiffConfig applyConfig(parameters, results); |
223 | | |
224 | | // Check for non-differentiable original function type. |
225 | 7.35k | auto checkNondifferentiableOriginalFunctionType = [&](CanSILFunctionType |
226 | 7.35k | origFnTy) { |
227 | | // Check non-differentiable arguments. |
228 | 11.4k | for (auto paramIndex : applyConfig.parameterIndices->getIndices()) { |
229 | 11.4k | auto remappedParamType = |
230 | 11.4k | origFnTy->getParameters()[paramIndex].getSILStorageInterfaceType(); |
231 | 11.4k | if (!remappedParamType.isDifferentiable(derivative->getModule())) |
232 | 20 | return true; |
233 | 11.4k | } |
234 | | // Check non-differentiable results. |
235 | 7.44k | for (auto resultIndex : applyConfig.resultIndices->getIndices()) { |
236 | 7.44k | SILType remappedResultType; |
237 | 7.44k | if (resultIndex >= origFnTy->getNumResults()) { |
238 | 604 | auto semanticResultArgIdx = resultIndex - origFnTy->getNumResults(); |
239 | 604 | auto semanticResultArg = |
240 | 604 | *std::next(ai->getAutoDiffSemanticResultArguments().begin(), |
241 | 604 | semanticResultArgIdx); |
242 | 604 | remappedResultType = semanticResultArg->getType(); |
243 | 6.83k | } else { |
244 | 6.83k | remappedResultType = |
245 | 6.83k | origFnTy->getResults()[resultIndex].getSILStorageInterfaceType(); |
246 | 6.83k | } |
247 | 7.44k | if (!remappedResultType.isDifferentiable(derivative->getModule())) |
248 | 12 | return true; |
249 | 7.44k | } |
250 | 7.32k | return false; |
251 | 7.33k | }; |
252 | 7.35k | if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy)) |
253 | 32 | return nullptr; |
254 | | |
255 | 7.32k | AutoDiffDerivativeFunctionKind derivativeFnKind(kind); |
256 | 7.32k | auto derivativeFnType = |
257 | 7.32k | remappedOrigFnSubstTy |
258 | 7.32k | ->getAutoDiffDerivativeFunctionType( |
259 | 7.32k | parameters, results, derivativeFnKind, context.getTypeConverter(), |
260 | 7.32k | LookUpConformanceInModule( |
261 | 7.32k | derivative->getModule().getSwiftModule())) |
262 | 7.32k | ->getUnsubstitutedType(original->getModule()); |
263 | | |
264 | 7.32k | auto derivativeFnResultTypes = derivativeFnType->getAllResultsInterfaceType(); |
265 | 7.32k | auto linearMapSILType = derivativeFnResultTypes; |
266 | 7.32k | if (auto tupleType = linearMapSILType.getAs<TupleType>()) { |
267 | 6.78k | linearMapSILType = SILType::getPrimitiveObjectType( |
268 | 6.78k | tupleType.getElementType(tupleType->getElements().size() - 1)); |
269 | 6.78k | } |
270 | 7.32k | if (auto fnTy = linearMapSILType.getAs<SILFunctionType>()) { |
271 | 7.32k | linearMapSILType = SILType::getPrimitiveObjectType( |
272 | 7.32k | fnTy->getUnsubstitutedType(original->getModule())); |
273 | 7.32k | } |
274 | | |
275 | | // IRGen requires decls to have AST types (not `SILFunctionType`), so we |
276 | | // convert the `SILFunctionType` of the linear map to a `FunctionType` with |
277 | | // the same parameters and results. |
278 | 7.32k | auto silFnTy = linearMapSILType.castTo<SILFunctionType>(); |
279 | 7.32k | SmallVector<AnyFunctionType::Param, 8> params; |
280 | 8.37k | for (auto ¶m : silFnTy->getParameters()) { |
281 | 8.37k | ParameterTypeFlags flags; |
282 | 8.37k | if (param.isAutoDiffSemanticResult()) |
283 | 604 | flags = flags.withInOut(true); |
284 | | |
285 | 8.37k | params.push_back( |
286 | 8.37k | AnyFunctionType::Param(param.getInterfaceType(), Identifier(), flags)); |
287 | 8.37k | } |
288 | | |
289 | 7.32k | AnyFunctionType *astFnTy; |
290 | 7.32k | if (auto genSig = silFnTy->getSubstGenericSignature()) { |
291 | | // FIXME: Verify ExtInfo state is correct, not working by accident. |
292 | 0 | GenericFunctionType::ExtInfo info; |
293 | 0 | astFnTy = GenericFunctionType::get( |
294 | 0 | genSig, params, silFnTy->getAllResultsInterfaceType().getASTType(), |
295 | 0 | info); |
296 | 7.32k | } else { |
297 | 7.32k | FunctionType::ExtInfo info; |
298 | 7.32k | astFnTy = FunctionType::get( |
299 | 7.32k | params, silFnTy->getAllResultsInterfaceType().getASTType(), info); |
300 | 7.32k | } |
301 | | |
302 | 7.32k | if (astFnTy->hasArchetype()) |
303 | 628 | return astFnTy->mapTypeOutOfContext(); |
304 | | |
305 | 6.69k | return astFnTy; |
306 | 7.32k | } |
307 | | |
308 | | void LinearMapInfo::generateDifferentiationDataStructures( |
309 | 6.61k | ADContext &context, SILFunction *derivativeFn) { |
310 | 6.61k | auto &astCtx = original->getASTContext(); |
311 | | // Get the derivative function generic signature. |
312 | 6.61k | CanGenericSignature derivativeFnGenSig = nullptr; |
313 | 6.61k | if (auto *derivativeFnGenEnv = derivativeFn->getGenericEnvironment()) |
314 | 1.10k | derivativeFnGenSig = |
315 | 1.10k | derivativeFnGenEnv->getGenericSignature().getCanonicalSignature(); |
316 | | |
317 | | // Create branching trace enum for each original block and add it as a field |
318 | | // in the corresponding struct. |
319 | 6.61k | StringRef traceEnumFieldName; |
320 | 6.61k | switch (kind) { |
321 | 1.35k | case AutoDiffLinearMapKind::Differential: |
322 | 1.35k | traceEnumFieldName = "successor"; |
323 | 1.35k | break; |
324 | 5.25k | case AutoDiffLinearMapKind::Pullback: |
325 | 5.25k | traceEnumFieldName = "predecessor"; |
326 | 5.25k | break; |
327 | 6.61k | } |
328 | | |
329 | 8.60k | for (auto &origBB : *original) { |
330 | 8.60k | auto *traceEnum = |
331 | 8.60k | createBranchingTraceDecl(&origBB, derivativeFnGenSig); |
332 | 8.60k | branchingTraceDecls.insert({&origBB, traceEnum}); |
333 | 8.60k | } |
334 | | |
335 | | // Add linear map fields to the linear map tuples. |
336 | | // |
337 | | // Now we need to be very careful as we're having a very subtle |
338 | | // chicken-and-egg problem. We need lowered branch trace enum type for the |
339 | | // linear map typle type. However branch trace enum type lowering depends on |
340 | | // the lowering of its elements (at very least, the type classification of |
341 | | // being trivial / non-trivial). As the lowering is cached we need to ensure |
342 | | // we compute lowered type for the branch trace enum when the corresponding |
343 | | // EnumDecl is fully complete: we cannot add more entries without causing some |
344 | | // very subtle issues later on. However, the elements of the enum are linear |
345 | | // map tuples of predecessors, that correspondingly may contain branch trace |
346 | | // enums of corresponding predecessor BBs. |
347 | | // |
348 | | // Traverse all BBs in reverse post-order traversal order to ensure we process |
349 | | // each BB before its predecessors. |
350 | 6.61k | llvm::ReversePostOrderTraversal<SILFunction *> RPOT(original); |
351 | 15.2k | for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) { |
352 | 8.60k | auto *origBB = *Iter; |
353 | 8.60k | SmallVector<TupleTypeElt, 4> linearTupleTypes; |
354 | 8.60k | if (!origBB->isEntry()) { |
355 | 1.99k | populateBranchingTraceDecl(origBB, loopInfo); |
356 | | |
357 | 1.99k | CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType(); |
358 | 1.99k | linearTupleTypes.emplace_back(traceEnumType, |
359 | 1.99k | astCtx.getIdentifier(traceEnumFieldName)); |
360 | 1.99k | } |
361 | | |
362 | 8.60k | if (isSemanticMemberAccessor(original)) { |
363 | | // Do not add linear map fields for semantic member accessors, which have |
364 | | // special-case pullback generation. Linear map tuples should be empty. |
365 | 8.33k | } else { |
366 | 100k | for (auto &inst : *origBB) { |
367 | 100k | if (auto *ai = dyn_cast<ApplyInst>(&inst)) { |
368 | | // Add linear map field to struct for active `apply` instructions. |
369 | | // Skip array literal intrinsic applications since array literal |
370 | | // initialization is linear and handled separately. |
371 | 11.5k | if (!shouldDifferentiateApplySite(ai) || |
372 | 11.5k | ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) |
373 | 3.99k | continue; |
374 | 7.56k | if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) |
375 | 212 | continue; |
376 | 7.35k | LLVM_DEBUG(getADDebugStream() |
377 | 7.35k | << "Adding linear map tuple field for " << *ai); |
378 | 7.35k | if (Type linearMapType = getLinearMapType(context, ai)) { |
379 | 7.32k | linearMapIndexMap.insert({ai, linearTupleTypes.size()}); |
380 | 7.32k | linearTupleTypes.emplace_back(linearMapType); |
381 | 7.32k | } |
382 | 7.35k | } |
383 | 100k | } |
384 | 8.33k | } |
385 | | |
386 | 8.60k | linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)}); |
387 | 8.60k | } |
388 | | |
389 | | // Print generated linear map structs and branching trace enums. |
390 | | // These declarations do not show up with `-emit-sil` because they are |
391 | | // implicit. Instead, use `-Xllvm -debug-only=differentiation` to test |
392 | | // declarations with FileCheck. |
393 | 6.61k | LLVM_DEBUG({ |
394 | 6.61k | auto &s = getADDebugStream(); |
395 | 6.61k | PrintOptions printOptions; |
396 | 6.61k | printOptions.TypeDefinitions = true; |
397 | 6.61k | printOptions.ExplodePatternBindingDecls = true; |
398 | 6.61k | printOptions.SkipImplicit = false; |
399 | 6.61k | s << "Generated linear map tuples and branching trace enums for @" |
400 | 6.61k | << original->getName() << ":\n"; |
401 | 6.61k | for (auto &origBB : *original) { |
402 | 6.61k | auto *linearMapTuple = getLinearMapTupleType(&origBB); |
403 | 6.61k | linearMapTuple->print(s, printOptions); |
404 | 6.61k | s << '\n'; |
405 | 6.61k | } |
406 | | |
407 | 6.61k | for (auto &origBB : *original) { |
408 | 6.61k | auto *traceEnum = getBranchingTraceDecl(&origBB); |
409 | 6.61k | traceEnum->print(s, printOptions); |
410 | 6.61k | s << '\n'; |
411 | 6.61k | } |
412 | 6.61k | }); |
413 | 6.61k | } |
414 | | |
415 | | /// Returns a flag that indicates whether the `apply` instruction should be |
416 | | /// differentiated, given the differentiation indices of the instruction's |
417 | | /// parent function. Whether the `apply` should be differentiated is determined |
418 | | /// sequentially from the following conditions: |
419 | | /// 1. The instruction has an active `inout` argument. |
420 | | /// 2. The instruction is a call to the array literal initialization intrinsic |
421 | | /// ("array.uninitialized_intrinsic"), where the result is active and where |
422 | | /// there is a `store` of an active value into the array's buffer. |
423 | | /// 3. The instruction has both an active result (direct or indirect) and an |
424 | | /// active argument. |
425 | 42.6k | bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) { |
426 | | // Function applications with an active inout argument should be |
427 | | // differentiated. |
428 | 42.6k | for (auto inoutArg : applySite.getInoutArguments()) |
429 | 2.80k | if (activityInfo.isActive(inoutArg, config)) |
430 | 2.28k | return true; |
431 | | |
432 | 40.3k | bool hasActiveDirectResults = false; |
433 | 40.3k | forEachApplyDirectResult(applySite, [&](SILValue directResult) { |
434 | 29.1k | hasActiveDirectResults |= activityInfo.isActive(directResult, config); |
435 | 29.1k | }); |
436 | 40.3k | bool hasActiveIndirectResults = |
437 | 40.3k | llvm::any_of(applySite.getIndirectSILResults(), [&](SILValue result) { |
438 | 12.2k | return activityInfo.isActive(result, config); |
439 | 12.2k | }); |
440 | 40.3k | bool hasActiveResults = hasActiveDirectResults || hasActiveIndirectResults; |
441 | | |
442 | | // TODO: Pattern match to make sure there is at least one `store` to the |
443 | | // array's active buffer. |
444 | 40.3k | if (ArraySemanticsCall(applySite.getInstruction(), |
445 | 40.3k | semantics::ARRAY_UNINITIALIZED_INTRINSIC) && |
446 | 40.3k | hasActiveResults) |
447 | 868 | return true; |
448 | | |
449 | 39.5k | auto arguments = applySite.getArgumentsWithoutIndirectResults(); |
450 | 53.7k | bool hasActiveArguments = llvm::any_of(arguments, [&](SILValue arg) { |
451 | 53.7k | return activityInfo.isActive(arg, config); |
452 | 53.7k | }); |
453 | 39.5k | return hasActiveResults && hasActiveArguments; |
454 | 40.3k | } |
455 | | |
456 | | static bool shouldDifferentiateInjectEnumAddr( |
457 | | const InjectEnumAddrInst &inject, |
458 | | const DifferentiableActivityInfo &activityInfo, |
459 | 12 | const AutoDiffConfig &config) { |
460 | 12 | SILValue en = inject.getOperand(); |
461 | 24 | for (auto use : en->getUses()) { |
462 | 24 | auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser()); |
463 | 24 | if (init && activityInfo.isActive(init, config)) |
464 | 8 | return true; |
465 | 24 | } |
466 | 4 | return false; |
467 | 12 | } |
468 | | |
469 | | /// Returns a flag indicating whether the instruction should be differentiated, |
470 | | /// given the differentiation indices of the instruction's parent function. |
471 | | /// Whether the instruction should be differentiated is determined sequentially |
472 | | /// from any of the following conditions: |
473 | | /// 1. The instruction is a full apply site and `shouldDifferentiateApplyInst` |
474 | | /// returns true. |
475 | | /// 2. The instruction has a source operand and a destination operand, both |
476 | | /// being active. |
477 | | /// 3. The instruction is an allocation instruction and has an active result. |
478 | | /// 4. The instruction performs reference counting, lifetime ending, access |
479 | | /// ending, or destroying on an active operand. |
480 | | /// 5. The instruction creates an SSA copy of an active operand. |
481 | 108k | bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) { |
482 | | // A full apply site with an active argument and an active result (direct or |
483 | | // indirect) should be differentiated. |
484 | 108k | if (FullApplySite::isa(inst)) |
485 | 11.7k | return shouldDifferentiateApplySite(FullApplySite(inst)); |
486 | | // Anything with an active result and an active operand should be |
487 | | // differentiated. |
488 | 97.1k | auto hasActiveOperands = |
489 | 97.1k | llvm::any_of(inst->getAllOperands(), [&](Operand &op) { |
490 | 65.5k | return activityInfo.isActive(op.get(), config); |
491 | 65.5k | }); |
492 | 97.1k | auto hasActiveResults = llvm::any_of(inst->getResults(), [&](SILValue val) { |
493 | 51.5k | return activityInfo.isActive(val, config); |
494 | 51.5k | }); |
495 | 97.1k | if (hasActiveOperands && hasActiveResults) |
496 | 14.1k | return true; |
497 | | // `store`-like instructions do not have an SSA result, but have two |
498 | | // operands that represent the source and the destination. We treat them as |
499 | | // the input and the output, respectively. |
500 | | // For `store`-like instructions whose destination is an element address |
501 | | // from an `array.uninitialized_intrinsic` application, return true if the |
502 | | // intrinsic application (representing the semantic destination) is active. |
503 | 82.9k | #define CHECK_INST_TYPE_ACTIVE_DEST(INST) \ |
504 | 311k | if (auto *castInst = dyn_cast<INST##Inst>(inst)) \ |
505 | 311k | return activityInfo.isActive(castInst->getDest(), config); |
506 | 82.9k | CHECK_INST_TYPE_ACTIVE_DEST(Store) |
507 | 77.1k | CHECK_INST_TYPE_ACTIVE_DEST(StoreBorrow) |
508 | 77.0k | CHECK_INST_TYPE_ACTIVE_DEST(CopyAddr) |
509 | 74.3k | CHECK_INST_TYPE_ACTIVE_DEST(UnconditionalCheckedCastAddr) |
510 | 74.2k | #undef CHECK_INST_TYPE_ACTIVE_DEST |
511 | | // Should differentiate any allocation instruction that has an active result. |
512 | 74.2k | if ((isa<AllocationInst>(inst) && hasActiveResults)) |
513 | 6.83k | return true; |
514 | 67.4k | if (hasActiveOperands) { |
515 | | // Should differentiate any instruction that performs reference counting, |
516 | | // lifetime ending, access ending, or destroying on an active operand. |
517 | 29.7k | if (isa<RefCountingInst>(inst) || isa<EndAccessInst>(inst) || |
518 | 29.7k | isa<EndBorrowInst>(inst) || isa<DeallocationInst>(inst) || |
519 | 29.7k | isa<DestroyValueInst>(inst) || isa<DestroyAddrInst>(inst)) |
520 | 14.6k | return true; |
521 | 29.7k | } |
522 | | |
523 | | // Should differentiate `inject_enum_addr` if the corresponding |
524 | | // `init_enum_addr` has an active operand. |
525 | 52.8k | if (auto inject = dyn_cast<InjectEnumAddrInst>(inst)) |
526 | 12 | if (shouldDifferentiateInjectEnumAddr(*inject, activityInfo, config)) |
527 | 8 | return true; |
528 | | |
529 | 52.8k | return false; |
530 | 52.8k | } |
531 | | |
532 | | } // end namespace autodiff |
533 | | } // end namespace swift |