/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/ADContext.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- ADContext.cpp - Differentiation Context --------------*- C++ -*---===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // Per-module contextual information for the differentiation transform. |
14 | | // |
15 | | //===----------------------------------------------------------------------===// |
16 | | |
17 | | #define DEBUG_TYPE "differentiation" |
18 | | |
19 | | #include "swift/SILOptimizer/Differentiation/ADContext.h" |
20 | | #include "swift/AST/DiagnosticsSIL.h" |
21 | | #include "swift/AST/SourceFile.h" |
22 | | #include "swift/SILOptimizer/PassManager/Transforms.h" |
23 | | |
24 | | using llvm::DenseMap; |
25 | | using llvm::SmallPtrSet; |
26 | | using llvm::SmallVector; |
27 | | |
28 | | namespace swift { |
29 | | namespace autodiff { |
30 | | |
31 | | //===----------------------------------------------------------------------===// |
32 | | // Local helpers |
33 | | //===----------------------------------------------------------------------===// |
34 | | |
35 | | /// Given an operator name, such as '+', and a protocol, returns the '+' |
36 | | /// operator. If the operator does not exist in the protocol, returns null. |
37 | | static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName, |
38 | 404 | ProtocolDecl *protocol) { |
39 | 404 | assert(operatorName.isOperator()); |
40 | | // Find the operator requirement in the given protocol declaration. |
41 | 0 | auto opLookup = protocol->lookupDirect(operatorName); |
42 | 404 | for (auto *decl : opLookup) { |
43 | 404 | if (!decl->isProtocolRequirement()) |
44 | 0 | continue; |
45 | 404 | auto *fd = dyn_cast<FuncDecl>(decl); |
46 | 404 | if (!fd || !fd->isStatic() || !fd->isOperator()) |
47 | 0 | continue; |
48 | 404 | return fd; |
49 | 404 | } |
50 | | // Not found. |
51 | 0 | return nullptr; |
52 | 404 | } |
53 | | |
54 | | //===----------------------------------------------------------------------===// |
55 | | // ADContext methods |
56 | | //===----------------------------------------------------------------------===// |
57 | | |
58 | | ADContext::ADContext(SILModuleTransform &transform) |
59 | | : transform(transform), module(*transform.getModule()), |
60 | 24.3k | passManager(*transform.getPassManager()) {} |
61 | | |
62 | | /// Get the source file for the given `SILFunction`. |
63 | 6.61k | static SourceFile &getSourceFile(SILFunction *f) { |
64 | 6.61k | if (f->hasLocation()) |
65 | 6.61k | if (auto *declContext = f->getLocation().getAsDeclContext()) |
66 | 6.49k | if (auto *parentSourceFile = declContext->getParentSourceFile()) |
67 | 6.49k | return *parentSourceFile; |
68 | 120 | for (auto *file : f->getModule().getSwiftModule()->getFiles()) |
69 | 120 | if (auto *sourceFile = dyn_cast<SourceFile>(file)) |
70 | 120 | return *sourceFile; |
71 | 0 | llvm_unreachable("Could not resolve SourceFile from SILFunction"); |
72 | 0 | } |
73 | | |
74 | | SynthesizedFileUnit & |
75 | 6.61k | ADContext::getOrCreateSynthesizedFile(SILFunction *original) { |
76 | 6.61k | auto &SF = getSourceFile(original); |
77 | 6.61k | return SF.getOrCreateSynthesizedFile(); |
78 | 6.61k | } |
79 | | |
80 | 2.34k | FuncDecl *ADContext::getPlusDecl() const { |
81 | 2.34k | if (!cachedPlusFn) { |
82 | 128 | cachedPlusFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+"), |
83 | 128 | additiveArithmeticProtocol); |
84 | 128 | assert(cachedPlusFn && "AdditiveArithmetic.+ not found"); |
85 | 128 | } |
86 | 0 | return cachedPlusFn; |
87 | 2.34k | } |
88 | | |
89 | 8.19k | FuncDecl *ADContext::getPlusEqualDecl() const { |
90 | 8.19k | if (!cachedPlusEqualFn) { |
91 | 276 | cachedPlusEqualFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+="), |
92 | 276 | additiveArithmeticProtocol); |
93 | 276 | assert(cachedPlusEqualFn && "AdditiveArithmetic.+= not found"); |
94 | 276 | } |
95 | 0 | return cachedPlusEqualFn; |
96 | 8.19k | } |
97 | | |
98 | 19.5k | AccessorDecl *ADContext::getAdditiveArithmeticZeroGetter() const { |
99 | 19.5k | if (cachedZeroGetter) |
100 | 19.2k | return cachedZeroGetter; |
101 | 296 | auto zeroDeclLookup = getAdditiveArithmeticProtocol() |
102 | 296 | ->lookupDirect(getASTContext().Id_zero); |
103 | 296 | auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front()); |
104 | 296 | assert(zeroDecl->isProtocolRequirement()); |
105 | 0 | cachedZeroGetter = zeroDecl->getOpaqueAccessor(AccessorKind::Get); |
106 | 296 | return cachedZeroGetter; |
107 | 19.5k | } |
108 | | |
109 | 24 | void ADContext::cleanUp() { |
110 | | // Delete all references to generated functions. |
111 | 996 | for (auto fnRef : generatedFunctionReferences) { |
112 | 996 | if (auto *fnRefInst = |
113 | 996 | peerThroughFunctionConversions<FunctionRefInst>(fnRef)) { |
114 | 4 | fnRefInst->replaceAllUsesWithUndef(); |
115 | 4 | fnRefInst->eraseFromParent(); |
116 | 4 | } |
117 | 996 | } |
118 | | // Delete all generated functions. |
119 | 1.78k | for (auto *generatedFunction : generatedFunctions) { |
120 | 1.78k | LLVM_DEBUG(getADDebugStream() << "Deleting generated function " |
121 | 1.78k | << generatedFunction->getName() << '\n'); |
122 | 1.78k | generatedFunction->dropAllReferences(); |
123 | 1.78k | transform.notifyWillDeleteFunction(generatedFunction); |
124 | 1.78k | module.eraseFunction(generatedFunction); |
125 | 1.78k | } |
126 | 24 | } |
127 | | |
128 | | DifferentiableFunctionInst *ADContext::createDifferentiableFunction( |
129 | | SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices, |
130 | | IndexSubset *resultIndices, SILValue original, |
131 | 19.0k | llvm::Optional<std::pair<SILValue, SILValue>> derivativeFunctions) { |
132 | 19.0k | auto *dfi = builder.createDifferentiableFunction( |
133 | 19.0k | loc, parameterIndices, resultIndices, original, derivativeFunctions); |
134 | 19.0k | processedDifferentiableFunctionInsts.erase(dfi); |
135 | 19.0k | return dfi; |
136 | 19.0k | } |
137 | | |
138 | | LinearFunctionInst *ADContext::createLinearFunction( |
139 | | SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices, |
140 | 12 | SILValue original, llvm::Optional<SILValue> transposeFunction) { |
141 | 12 | auto *lfi = builder.createLinearFunction(loc, parameterIndices, original, |
142 | 12 | transposeFunction); |
143 | 12 | processedLinearFunctionInsts.erase(lfi); |
144 | 12 | return lfi; |
145 | 12 | } |
146 | | |
147 | | DifferentiableFunctionExpr * |
148 | 44 | ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) { |
149 | 44 | return inst->getLoc().getAsASTNode<DifferentiableFunctionExpr>(); |
150 | 44 | } |
151 | | |
152 | | LinearFunctionExpr * |
153 | 0 | ADContext::findDifferentialOperator(LinearFunctionInst *inst) { |
154 | 0 | return inst->getLoc().getAsASTNode<LinearFunctionExpr>(); |
155 | 0 | } |
156 | | |
157 | | } // end namespace autodiff |
158 | | } // end namespace swift |