/Volumes/compiler/apple/swift/lib/AST/AutoDiff.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- AutoDiff.cpp - Swift automatic differentiation utilities ---------===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | |
13 | | #include "swift/AST/AutoDiff.h" |
14 | | #include "swift/AST/ASTContext.h" |
15 | | #include "swift/AST/GenericEnvironment.h" |
16 | | #include "swift/AST/ImportCache.h" |
17 | | #include "swift/AST/Module.h" |
18 | | #include "swift/AST/TypeCheckRequests.h" |
19 | | #include "swift/AST/Types.h" |
20 | | |
21 | | using namespace swift; |
22 | | |
23 | | AutoDiffDerivativeFunctionKind::AutoDiffDerivativeFunctionKind( |
24 | 48 | StringRef string) { |
25 | 48 | llvm::Optional<innerty> result = |
26 | 48 | llvm::StringSwitch<llvm::Optional<innerty>>(string) |
27 | 48 | .Case("jvp", JVP) |
28 | 48 | .Case("vjp", VJP); |
29 | 48 | assert(result && "Invalid string"); |
30 | 0 | rawValue = *result; |
31 | 48 | } |
32 | | |
33 | | NormalDifferentiableFunctionTypeComponent:: |
34 | | NormalDifferentiableFunctionTypeComponent( |
35 | 6.63k | AutoDiffDerivativeFunctionKind kind) { |
36 | 6.63k | switch (kind) { |
37 | 3.31k | case AutoDiffDerivativeFunctionKind::JVP: |
38 | 3.31k | rawValue = JVP; |
39 | 3.31k | return; |
40 | 3.32k | case AutoDiffDerivativeFunctionKind::VJP: |
41 | 3.32k | rawValue = VJP; |
42 | 3.32k | return; |
43 | 6.63k | } |
44 | 6.63k | } |
45 | | |
46 | | NormalDifferentiableFunctionTypeComponent:: |
47 | 144 | NormalDifferentiableFunctionTypeComponent(StringRef string) { |
48 | 144 | llvm::Optional<innerty> result = |
49 | 144 | llvm::StringSwitch<llvm::Optional<innerty>>(string) |
50 | 144 | .Case("original", Original) |
51 | 144 | .Case("jvp", JVP) |
52 | 144 | .Case("vjp", VJP); |
53 | 144 | assert(result && "Invalid string"); |
54 | 0 | rawValue = *result; |
55 | 144 | } |
56 | | |
57 | | llvm::Optional<AutoDiffDerivativeFunctionKind> |
58 | 29.4k | NormalDifferentiableFunctionTypeComponent::getAsDerivativeFunctionKind() const { |
59 | 29.4k | switch (rawValue) { |
60 | 3.19k | case Original: |
61 | 3.19k | return llvm::None; |
62 | 10.8k | case JVP: |
63 | 10.8k | return {AutoDiffDerivativeFunctionKind::JVP}; |
64 | 15.4k | case VJP: |
65 | 15.4k | return {AutoDiffDerivativeFunctionKind::VJP}; |
66 | 29.4k | } |
67 | 0 | llvm_unreachable("invalid derivative kind"); |
68 | 0 | } |
69 | | |
70 | | LinearDifferentiableFunctionTypeComponent:: |
71 | 24 | LinearDifferentiableFunctionTypeComponent(StringRef string) { |
72 | 24 | llvm::Optional<innerty> result = |
73 | 24 | llvm::StringSwitch<llvm::Optional<innerty>>(string) |
74 | 24 | .Case("original", Original) |
75 | 24 | .Case("transpose", Transpose); |
76 | 24 | assert(result && "Invalid string"); |
77 | 0 | rawValue = *result; |
78 | 24 | } |
79 | | |
80 | | DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind( |
81 | 208 | StringRef string) { |
82 | 208 | llvm::Optional<innerty> result = |
83 | 208 | llvm::StringSwitch<llvm::Optional<innerty>>(string) |
84 | 208 | .Case("jvp", JVP) |
85 | 208 | .Case("vjp", VJP) |
86 | 208 | .Case("transpose", Transpose); |
87 | 208 | assert(result && "Invalid string"); |
88 | 0 | rawValue = *result; |
89 | 208 | } |
90 | | |
91 | | llvm::Optional<AutoDiffDerivativeFunctionKind> |
92 | 85.4k | DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const { |
93 | 85.4k | switch (rawValue) { |
94 | 42.4k | case JVP: |
95 | 42.4k | return {AutoDiffDerivativeFunctionKind::JVP}; |
96 | 42.6k | case VJP: |
97 | 42.6k | return {AutoDiffDerivativeFunctionKind::VJP}; |
98 | 336 | case Transpose: |
99 | 336 | return llvm::None; |
100 | 85.4k | } |
101 | 0 | llvm_unreachable("invalid derivative kind"); |
102 | 0 | } |
103 | | |
104 | 616 | void AutoDiffConfig::print(llvm::raw_ostream &s) const { |
105 | 616 | s << "(parameters="; |
106 | 616 | parameterIndices->print(s); |
107 | 616 | s << " results="; |
108 | 616 | resultIndices->print(s); |
109 | 616 | if (derivativeGenericSignature) { |
110 | 112 | s << " where="; |
111 | 112 | derivativeGenericSignature->print(s); |
112 | 112 | } |
113 | 616 | s << ')'; |
114 | 616 | } |
115 | | |
116 | 70.9k | bool swift::isDifferentiableProgrammingEnabled(SourceFile &SF) { |
117 | 70.9k | auto &ctx = SF.getASTContext(); |
118 | | // Return true if differentiable programming is explicitly enabled. |
119 | 70.9k | if (ctx.LangOpts.hasFeature(Feature::DifferentiableProgramming)) |
120 | 0 | return true; |
121 | | // Otherwise, return true iff the `_Differentiation` module is imported in |
122 | | // the given source file. |
123 | 70.9k | bool importsDifferentiationModule = false; |
124 | 501k | for (auto import : namelookup::getAllImports(&SF)) { |
125 | 501k | if (import.importedModule->getName() == ctx.Id_Differentiation) { |
126 | 18.5k | importsDifferentiationModule = true; |
127 | 18.5k | break; |
128 | 18.5k | } |
129 | 501k | } |
130 | 70.9k | return importsDifferentiationModule; |
131 | 70.9k | } |
132 | | |
133 | | // TODO(TF-874): This helper is inefficient and should be removed. Unwrapping at |
134 | | // most once (for curried method types) is sufficient. |
135 | | static void unwrapCurryLevels(AnyFunctionType *fnTy, |
136 | 39.0k | SmallVectorImpl<AnyFunctionType *> &results) { |
137 | 89.5k | while (fnTy != nullptr) { |
138 | 50.4k | results.push_back(fnTy); |
139 | 50.4k | fnTy = fnTy->getResult()->getAs<AnyFunctionType>(); |
140 | 50.4k | } |
141 | 39.0k | } |
142 | | |
143 | 19.3k | static unsigned countNumFlattenedElementTypes(Type type) { |
144 | 19.3k | if (auto *tupleTy = type->getCanonicalType()->getAs<TupleType>()) |
145 | 20 | return accumulate(tupleTy->getElementTypes(), 0, |
146 | 40 | [&](unsigned num, Type type) { |
147 | 40 | return num + countNumFlattenedElementTypes(type); |
148 | 40 | }); |
149 | 19.3k | return 1; |
150 | 19.3k | } |
151 | | |
152 | | // TODO(TF-874): Simplify this helper and remove the `reverseCurryLevels` flag. |
153 | | void AnyFunctionType::getSubsetParameters( |
154 | | IndexSubset *parameterIndices, |
155 | 29.7k | SmallVectorImpl<AnyFunctionType::Param> &results, bool reverseCurryLevels) { |
156 | 29.7k | SmallVector<AnyFunctionType *, 2> curryLevels; |
157 | 29.7k | unwrapCurryLevels(this, curryLevels); |
158 | | |
159 | 29.7k | SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size()); |
160 | 29.7k | unsigned currentOffset = 0; |
161 | 34.1k | for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) { |
162 | 34.1k | curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset; |
163 | 34.1k | currentOffset += curryLevels[curryLevelIndex]->getNumParams(); |
164 | 34.1k | } |
165 | | |
166 | | // If `reverseCurryLevels` is true, reverse the curry levels and offsets. |
167 | 29.7k | if (reverseCurryLevels) { |
168 | 25.0k | std::reverse(curryLevels.begin(), curryLevels.end()); |
169 | 25.0k | std::reverse(curryLevelParameterIndexOffsets.begin(), |
170 | 25.0k | curryLevelParameterIndexOffsets.end()); |
171 | 25.0k | } |
172 | | |
173 | 34.1k | for (unsigned curryLevelIndex : indices(curryLevels)) { |
174 | 34.1k | auto *curryLevel = curryLevels[curryLevelIndex]; |
175 | 34.1k | unsigned parameterIndexOffset = |
176 | 34.1k | curryLevelParameterIndexOffsets[curryLevelIndex]; |
177 | 34.1k | for (unsigned paramIndex : range(curryLevel->getNumParams())) |
178 | 45.3k | if (parameterIndices->contains(parameterIndexOffset + paramIndex)) |
179 | 41.8k | results.push_back(curryLevel->getParams()[paramIndex]); |
180 | 34.1k | } |
181 | 29.7k | } |
182 | | |
183 | | void autodiff::getFunctionSemanticResults( |
184 | | const AnyFunctionType *functionType, |
185 | | const IndexSubset *parameterIndices, |
186 | 39.9k | SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes) { |
187 | 39.9k | auto &ctx = functionType->getASTContext(); |
188 | | |
189 | | // Collect formal result type as a semantic result, unless it is |
190 | | // `Void`. |
191 | 39.9k | auto formalResultType = functionType->getResult(); |
192 | 39.9k | if (auto *resultFunctionType = |
193 | 39.9k | functionType->getResult()->getAs<AnyFunctionType>()) |
194 | 11.2k | formalResultType = resultFunctionType->getResult(); |
195 | | |
196 | 39.9k | unsigned resultIdx = 0; |
197 | 39.9k | if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) { |
198 | | // Separate tuple elements into individual results. |
199 | 38.1k | if (formalResultType->is<TupleType>()) { |
200 | 416 | for (auto elt : formalResultType->castTo<TupleType>()->getElements()) { |
201 | 416 | resultTypes.emplace_back(elt.getType(), resultIdx++, |
202 | 416 | /*isParameter*/ false); |
203 | 416 | } |
204 | 37.9k | } else { |
205 | 37.9k | resultTypes.emplace_back(formalResultType, resultIdx++, |
206 | 37.9k | /*isParameter*/ false); |
207 | 37.9k | } |
208 | 38.1k | } |
209 | | |
210 | | // Collect wrt semantic result (`inout`) parameters as |
211 | | // semantic results |
212 | 39.9k | auto collectSemanticResults = [&](const AnyFunctionType *functionType, |
213 | 51.1k | unsigned curryOffset = 0) { |
214 | 63.8k | for (auto paramAndIndex : enumerate(functionType->getParams())) { |
215 | 63.8k | if (!paramAndIndex.value().isAutoDiffSemanticResult()) |
216 | 61.7k | continue; |
217 | | |
218 | 2.12k | unsigned idx = paramAndIndex.index() + curryOffset; |
219 | 2.12k | assert(idx < parameterIndices->getCapacity() && |
220 | 2.12k | "invalid parameter index"); |
221 | 2.12k | if (parameterIndices->contains(idx)) |
222 | 2.00k | resultTypes.emplace_back(paramAndIndex.value().getPlainType(), |
223 | 2.00k | resultIdx, /*isParameter*/ true); |
224 | 2.12k | resultIdx += 1; |
225 | 2.12k | } |
226 | 51.1k | }; |
227 | | |
228 | 39.9k | if (auto *resultFnType = |
229 | 39.9k | functionType->getResult()->getAs<AnyFunctionType>()) { |
230 | | // Here we assume that the input is a function type with curried `Self` |
231 | 11.2k | assert(functionType->getNumParams() == 1 && "unexpected function type"); |
232 | | |
233 | 0 | collectSemanticResults(resultFnType); |
234 | 11.2k | collectSemanticResults(functionType, resultFnType->getNumParams()); |
235 | 11.2k | } else |
236 | 28.7k | collectSemanticResults(functionType); |
237 | 39.9k | } |
238 | | |
239 | | IndexSubset * |
240 | | autodiff::getFunctionSemanticResultIndices(const AnyFunctionType *functionType, |
241 | 10.7k | const IndexSubset *parameterIndices) { |
242 | 10.7k | auto &ctx = functionType->getASTContext(); |
243 | | |
244 | 10.7k | SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults; |
245 | 10.7k | autodiff::getFunctionSemanticResults(functionType, parameterIndices, |
246 | 10.7k | semanticResults); |
247 | 10.7k | SmallVector<unsigned> resultIndices; |
248 | 10.7k | unsigned cap = 0; |
249 | 11.0k | for (const auto& result : semanticResults) { |
250 | 11.0k | resultIndices.push_back(result.index); |
251 | 11.0k | cap = std::max(cap, result.index + 1U); |
252 | 11.0k | } |
253 | | |
254 | 10.7k | return IndexSubset::get(ctx, cap, resultIndices); |
255 | 10.7k | } |
256 | | |
257 | | IndexSubset * |
258 | | autodiff::getFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD, |
259 | 7.34k | const IndexSubset *parameterIndices) { |
260 | 7.34k | return getFunctionSemanticResultIndices(AFD->getInterfaceType()->castTo<AnyFunctionType>(), |
261 | 7.34k | parameterIndices); |
262 | 7.34k | } |
263 | | |
264 | | // TODO(TF-874): Simplify this helper. See TF-874 for WIP. |
265 | | IndexSubset * |
266 | | autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices, |
267 | 9.36k | AnyFunctionType *functionType) { |
268 | 9.36k | SmallVector<AnyFunctionType *, 2> curryLevels; |
269 | 9.36k | unwrapCurryLevels(functionType, curryLevels); |
270 | | |
271 | | // Compute the lowered sizes of all AST parameter types. |
272 | 9.36k | SmallVector<unsigned, 8> paramLoweredSizes; |
273 | 9.36k | unsigned totalLoweredSize = 0; |
274 | 19.3k | auto addLoweredParamInfo = [&](Type type) { |
275 | 19.3k | unsigned paramLoweredSize = countNumFlattenedElementTypes(type); |
276 | 19.3k | paramLoweredSizes.push_back(paramLoweredSize); |
277 | 19.3k | totalLoweredSize += paramLoweredSize; |
278 | 19.3k | }; |
279 | 9.36k | for (auto *curryLevel : llvm::reverse(curryLevels)) |
280 | 16.3k | for (auto ¶m : curryLevel->getParams()) |
281 | 19.3k | addLoweredParamInfo(param.getPlainType()); |
282 | | |
283 | | // Build lowered SIL parameter indices by setting the range of bits that |
284 | | // corresponds to each "set" AST parameter. |
285 | 9.36k | llvm::SmallVector<unsigned, 8> loweredSILIndices; |
286 | 9.36k | unsigned currentBitIndex = 0; |
287 | 19.3k | for (unsigned i : range(parameterIndices->getCapacity())) { |
288 | 19.3k | auto paramLoweredSize = paramLoweredSizes[i]; |
289 | 19.3k | if (parameterIndices->contains(i)) { |
290 | 14.2k | auto indices = range(currentBitIndex, currentBitIndex + paramLoweredSize); |
291 | 14.2k | loweredSILIndices.append(indices.begin(), indices.end()); |
292 | 14.2k | } |
293 | 19.3k | currentBitIndex += paramLoweredSize; |
294 | 19.3k | } |
295 | | |
296 | 9.36k | return IndexSubset::get(functionType->getASTContext(), totalLoweredSize, |
297 | 9.36k | loweredSILIndices); |
298 | 9.36k | } |
299 | | |
300 | | /// Collects the semantic results of the given function type in |
301 | | /// `originalResults`. The semantic results are formal results followed by |
302 | | /// semantic result parameters, in type order. |
303 | | void |
304 | | autodiff::getSemanticResults(SILFunctionType *functionType, |
305 | | IndexSubset *parameterIndices, |
306 | 29.5k | SmallVectorImpl<SILResultInfo> &originalResults) { |
307 | | // Collect original formal results. |
308 | 29.5k | originalResults.append(functionType->getResults().begin(), |
309 | 29.5k | functionType->getResults().end()); |
310 | | |
311 | | // Collect original semantic result parameters. |
312 | 51.3k | for (auto i : range(functionType->getNumParameters())) { |
313 | 51.3k | auto param = functionType->getParameters()[i]; |
314 | 51.3k | if (!param.isAutoDiffSemanticResult()) |
315 | 49.3k | continue; |
316 | 1.94k | if (param.getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) |
317 | 1.94k | originalResults.emplace_back(param.getInterfaceType(), ResultConvention::Indirect); |
318 | 1.94k | } |
319 | 29.5k | } |
320 | | |
321 | | GenericSignature autodiff::getConstrainedDerivativeGenericSignature( |
322 | | SILFunctionType *originalFnTy, |
323 | | IndexSubset *diffParamIndices, IndexSubset *diffResultIndices, |
324 | | GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance, |
325 | 11.0k | bool isTranspose) { |
326 | 11.0k | if (!derivativeGenSig) |
327 | 3.54k | derivativeGenSig = originalFnTy->getInvocationGenericSignature(); |
328 | 11.0k | if (!derivativeGenSig) |
329 | 3.02k | return nullptr; |
330 | 8.06k | auto &ctx = originalFnTy->getASTContext(); |
331 | 8.06k | auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable); |
332 | 8.06k | SmallVector<Requirement, 4> requirements; |
333 | | |
334 | 19.6k | auto addRequirement = [&](CanType type) { |
335 | 19.6k | Requirement req(RequirementKind::Conformance, type, |
336 | 19.6k | diffableProto->getDeclaredInterfaceType()); |
337 | 19.6k | requirements.push_back(req); |
338 | 19.6k | if (isTranspose) { |
339 | | // Require linearity parameters to additionally satisfy |
340 | | // `Self == Self.TangentVector`. |
341 | 90 | auto tanSpace = type->getAutoDiffTangentSpace(lookupConformance); |
342 | 90 | auto tanType = tanSpace->getCanonicalType(); |
343 | 90 | Requirement req(RequirementKind::SameType, type, tanType); |
344 | 90 | requirements.push_back(req); |
345 | 90 | } |
346 | 19.6k | }; |
347 | | |
348 | | // Require differentiability parameters to conform to `Differentiable`. |
349 | 11.6k | for (unsigned paramIdx : diffParamIndices->getIndices()) { |
350 | 11.6k | auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType(); |
351 | 11.6k | addRequirement(paramType); |
352 | 11.6k | } |
353 | | |
354 | | // Require differentiability results to conform to `Differentiable`. |
355 | 8.06k | SmallVector<SILResultInfo, 2> originalResults; |
356 | 8.06k | getSemanticResults(originalFnTy, diffParamIndices, originalResults); |
357 | 8.06k | for (unsigned resultIdx : diffResultIndices->getIndices()) { |
358 | | // Handle formal original result. |
359 | 8.03k | if (resultIdx < originalFnTy->getNumResults()) { |
360 | 7.58k | auto resultType = originalResults[resultIdx].getInterfaceType(); |
361 | 7.58k | addRequirement(resultType); |
362 | 7.58k | continue; |
363 | 7.58k | } |
364 | | // Handle original semantic result parameters. |
365 | | // FIXME: Constraint generic yields when we will start supporting them |
366 | 456 | auto resultParamIndex = resultIdx - originalFnTy->getNumResults(); |
367 | 456 | auto resultParamIt = std::next( |
368 | 456 | originalFnTy->getAutoDiffSemanticResultsParameters().begin(), |
369 | 456 | resultParamIndex); |
370 | 456 | auto paramIndex = |
371 | 456 | std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); |
372 | 456 | addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType()); |
373 | 456 | } |
374 | | |
375 | 8.06k | return buildGenericSignature(ctx, derivativeGenSig, |
376 | 8.06k | /*addedGenericParams*/ {}, |
377 | 8.06k | std::move(requirements)); |
378 | 11.0k | } |
379 | | |
380 | | // Given the rest of a `Builtin.applyDerivative_{jvp|vjp}` or |
381 | | // `Builtin.applyTranspose` operation name, attempts to parse the arity and |
382 | | // throwing-ness from the operation name. Modifies the operation name argument |
383 | | // in place as substrings get dropped. |
384 | | static void parseAutoDiffBuiltinCommonConfig( |
385 | 28 | StringRef &operationName, unsigned &arity, bool &throws) { |
386 | | // Parse '_arity'. |
387 | 28 | constexpr char arityPrefix[] = "_arity"; |
388 | 28 | if (operationName.startswith(arityPrefix)) { |
389 | 8 | operationName = operationName.drop_front(sizeof(arityPrefix) - 1); |
390 | 8 | auto arityStr = operationName.take_while(llvm::isDigit); |
391 | 8 | operationName = operationName.drop_front(arityStr.size()); |
392 | 8 | auto converted = llvm::to_integer(arityStr, arity); |
393 | 8 | assert(converted); (void)converted; |
394 | 8 | assert(arity > 0); |
395 | 20 | } else { |
396 | 20 | arity = 1; |
397 | 20 | } |
398 | | // Parse '_throws'. |
399 | 0 | constexpr char throwsPrefix[] = "_throws"; |
400 | 28 | if (operationName.startswith(throwsPrefix)) { |
401 | 0 | operationName = operationName.drop_front(sizeof(throwsPrefix) - 1); |
402 | 0 | throws = true; |
403 | 28 | } else { |
404 | 28 | throws = false; |
405 | 28 | } |
406 | 28 | } |
407 | | |
408 | | bool autodiff::getBuiltinApplyDerivativeConfig( |
409 | | StringRef operationName, AutoDiffDerivativeFunctionKind &kind, |
410 | 28 | unsigned &arity, bool &throws) { |
411 | 28 | constexpr char prefix[] = "applyDerivative"; |
412 | 28 | if (!operationName.startswith(prefix)) |
413 | 0 | return false; |
414 | 28 | operationName = operationName.drop_front(sizeof(prefix) - 1); |
415 | | // Parse 'jvp' or 'vjp'. |
416 | 28 | constexpr char jvpPrefix[] = "_jvp"; |
417 | 28 | constexpr char vjpPrefix[] = "_vjp"; |
418 | 28 | if (operationName.startswith(jvpPrefix)) |
419 | 8 | kind = AutoDiffDerivativeFunctionKind::JVP; |
420 | 20 | else if (operationName.startswith(vjpPrefix)) |
421 | 20 | kind = AutoDiffDerivativeFunctionKind::VJP; |
422 | 28 | operationName = operationName.drop_front(sizeof(jvpPrefix) - 1); |
423 | 28 | parseAutoDiffBuiltinCommonConfig(operationName, arity, throws); |
424 | 28 | return operationName.empty(); |
425 | 28 | } |
426 | | |
427 | | bool autodiff::getBuiltinApplyTransposeConfig( |
428 | 0 | StringRef operationName, unsigned &arity, bool &throws) { |
429 | 0 | constexpr char prefix[] = "applyTranspose"; |
430 | 0 | if (!operationName.startswith(prefix)) |
431 | 0 | return false; |
432 | 0 | operationName = operationName.drop_front(sizeof(prefix) - 1); |
433 | 0 | parseAutoDiffBuiltinCommonConfig(operationName, arity, throws); |
434 | 0 | return operationName.empty(); |
435 | 0 | } |
436 | | |
437 | | bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig( |
438 | 0 | StringRef operationName, unsigned &arity, bool &throws) { |
439 | 0 | constexpr char differentiablePrefix[] = "differentiableFunction"; |
440 | 0 | constexpr char linearPrefix[] = "linearFunction"; |
441 | 0 | if (operationName.startswith(differentiablePrefix)) |
442 | 0 | operationName = operationName.drop_front(sizeof(differentiablePrefix) - 1); |
443 | 0 | else if (operationName.startswith(linearPrefix)) |
444 | 0 | operationName = operationName.drop_front(sizeof(linearPrefix) - 1); |
445 | 0 | else |
446 | 0 | return false; |
447 | 0 | parseAutoDiffBuiltinCommonConfig(operationName, arity, throws); |
448 | 0 | return operationName.empty(); |
449 | 0 | } |
450 | | |
451 | | GenericSignature autodiff::getDifferentiabilityWitnessGenericSignature( |
452 | 7.50k | GenericSignature origGenSig, GenericSignature derivativeGenSig) { |
453 | | // If there is no derivative generic signature, return the original generic |
454 | | // signature. |
455 | 7.50k | if (!derivativeGenSig) |
456 | 5.17k | return origGenSig; |
457 | | // If derivative generic signature has all concrete generic parameters and is |
458 | | // equal to the original generic signature, return `nullptr`. |
459 | 2.33k | auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature(); |
460 | 2.33k | auto origCanGenSig = origGenSig.getCanonicalSignature(); |
461 | 2.33k | if (origCanGenSig == derivativeCanGenSig && |
462 | 2.33k | derivativeCanGenSig->areAllParamsConcrete()) |
463 | 100 | return GenericSignature(); |
464 | | // Otherwise, return the derivative generic signature. |
465 | 2.23k | return derivativeGenSig; |
466 | 2.33k | } |
467 | | |
468 | 317k | Type TangentSpace::getType() const { |
469 | 317k | switch (kind) { |
470 | 314k | case Kind::TangentVector: |
471 | 314k | return value.tangentVectorType; |
472 | 3.48k | case Kind::Tuple: |
473 | 3.48k | return value.tupleType; |
474 | 317k | } |
475 | 0 | llvm_unreachable("invalid tangent space kind"); |
476 | 0 | } |
477 | | |
478 | 230k | CanType TangentSpace::getCanonicalType() const { |
479 | 230k | return getType()->getCanonicalType(); |
480 | 230k | } |
481 | | |
482 | 0 | NominalTypeDecl *TangentSpace::getNominal() const { |
483 | 0 | assert(isTangentVector()); |
484 | 0 | return getTangentVector()->getNominalOrBoundGenericNominal(); |
485 | 0 | } |
486 | | |
487 | | const char DerivativeFunctionTypeError::ID = '\0'; |
488 | | |
489 | 0 | void DerivativeFunctionTypeError::log(raw_ostream &OS) const { |
490 | 0 | OS << "original function type '"; |
491 | 0 | functionType->print(OS); |
492 | 0 | OS << "' "; |
493 | 0 | switch (kind) { |
494 | 0 | case Kind::NoSemanticResults: |
495 | 0 | OS << "has no semantic results ('Void' result)"; |
496 | 0 | break; |
497 | 0 | case Kind::NoDifferentiabilityParameters: |
498 | 0 | OS << "has no differentiability parameters"; |
499 | 0 | break; |
500 | 0 | case Kind::NonDifferentiableDifferentiabilityParameter: { |
501 | 0 | auto nonDiffParam = getNonDifferentiableTypeAndIndex(); |
502 | 0 | OS << "has non-differentiable differentiability parameter " |
503 | 0 | << nonDiffParam.second << ": " << nonDiffParam.first; |
504 | 0 | break; |
505 | 0 | } |
506 | 0 | case Kind::NonDifferentiableResult: { |
507 | 0 | auto nonDiffResult = getNonDifferentiableTypeAndIndex(); |
508 | 0 | OS << "has non-differentiable result " << nonDiffResult.second << ": " |
509 | 0 | << nonDiffResult.first; |
510 | 0 | break; |
511 | 0 | } |
512 | 0 | } |
513 | 0 | } |
514 | | |
515 | | inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, |
516 | 0 | const DeclNameRefWithLoc &name) { |
517 | 0 | os << name.Name; |
518 | 0 | if (auto accessorKind = name.AccessorKind) |
519 | 0 | os << '.' << getAccessorLabel(*accessorKind); |
520 | 0 | return os; |
521 | 0 | } |
522 | | |
523 | | bool swift::operator==(const TangentPropertyInfo::Error &lhs, |
524 | 0 | const TangentPropertyInfo::Error &rhs) { |
525 | 0 | if (lhs.kind != rhs.kind) |
526 | 0 | return false; |
527 | 0 | switch (lhs.kind) { |
528 | 0 | case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty: |
529 | 0 | case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable: |
530 | 0 | case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable: |
531 | 0 | case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct: |
532 | 0 | case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound: |
533 | 0 | case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored: |
534 | 0 | return true; |
535 | 0 | case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType: |
536 | 0 | return lhs.getType()->isEqual(rhs.getType()); |
537 | 0 | } |
538 | 0 | llvm_unreachable("unhandled tangent property!"); |
539 | 0 | } |
540 | | |
541 | 0 | void swift::simple_display(llvm::raw_ostream &os, TangentPropertyInfo info) { |
542 | 0 | os << "{ "; |
543 | 0 | os << "tangent property: " |
544 | 0 | << (info.tangentProperty ? info.tangentProperty->printRef() : "null"); |
545 | 0 | if (info.error) { |
546 | 0 | os << ", error: "; |
547 | 0 | switch (info.error->kind) { |
548 | 0 | case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty: |
549 | 0 | os << "'@noDerivative' original property has no tangent property"; |
550 | 0 | break; |
551 | 0 | case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable: |
552 | 0 | os << "nominal parent does not conform to 'Differentiable'"; |
553 | 0 | break; |
554 | 0 | case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable: |
555 | 0 | os << "original property type does not conform to 'Differentiable'"; |
556 | 0 | break; |
557 | 0 | case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct: |
558 | 0 | os << "'TangentVector' type is not a struct"; |
559 | 0 | break; |
560 | 0 | case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound: |
561 | 0 | os << "'TangentVector' struct does not have stored property with the " |
562 | 0 | "same name as the original property"; |
563 | 0 | break; |
564 | 0 | case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType: |
565 | 0 | os << "tangent property's type is not equal to the original property's " |
566 | 0 | "'TangentVector' type"; |
567 | 0 | break; |
568 | 0 | case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored: |
569 | 0 | os << "'TangentVector' property '" << info.tangentProperty->getName() |
570 | 0 | << "' is not a stored property"; |
571 | 0 | break; |
572 | 0 | } |
573 | 0 | } |
574 | 0 | os << " }"; |
575 | 0 | } |
576 | | |
577 | | TangentPropertyInfo TangentStoredPropertyRequest::evaluate( |
578 | 748 | Evaluator &evaluator, VarDecl *originalField, CanType baseType) const { |
579 | 748 | assert(((originalField->hasStorage() && originalField->isInstanceMember()) || |
580 | 748 | originalField->hasAttachedPropertyWrapper()) && |
581 | 748 | "Expected a stored property or a property-wrapped property"); |
582 | 0 | auto *parentDC = originalField->getDeclContext(); |
583 | 748 | assert(parentDC->isTypeContext()); |
584 | 0 | auto *moduleDecl = originalField->getModuleContext(); |
585 | 748 | auto parentTan = |
586 | 748 | baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl)); |
587 | | // Error if parent nominal type does not conform to `Differentiable`. |
588 | 748 | if (!parentTan) { |
589 | 0 | return TangentPropertyInfo( |
590 | 0 | TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable); |
591 | 0 | } |
592 | | // Error if original stored property is `@noDerivative`. |
593 | 748 | if (originalField->getAttrs().hasAttribute<NoDerivativeAttr>()) { |
594 | 0 | return TangentPropertyInfo( |
595 | 0 | TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty); |
596 | 0 | } |
597 | | // Error if original property's type does not conform to `Differentiable`. |
598 | 748 | auto originalFieldType = baseType->getTypeOfMember( |
599 | 748 | originalField->getModuleContext(), originalField); |
600 | 748 | auto originalFieldTan = originalFieldType->getAutoDiffTangentSpace( |
601 | 748 | LookUpConformanceInModule(moduleDecl)); |
602 | 748 | if (!originalFieldTan) { |
603 | 8 | return TangentPropertyInfo( |
604 | 8 | TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable); |
605 | 8 | } |
606 | | // Get the parent `TangentVector` type. |
607 | 740 | auto parentTanType = |
608 | 740 | baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl)) |
609 | 740 | ->getType(); |
610 | 740 | auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct(); |
611 | | // Error if parent `TangentVector` is not a struct. |
612 | 740 | if (!parentTanStruct) { |
613 | 8 | return TangentPropertyInfo( |
614 | 8 | TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct); |
615 | 8 | } |
616 | | // Find the corresponding field in the tangent space. |
617 | 732 | VarDecl *tanField = nullptr; |
618 | | // If `TangentVector` is the original struct, then the tangent property is the |
619 | | // original property. |
620 | 732 | if (parentTanStruct == parentDC->getSelfStructDecl()) { |
621 | 148 | tanField = originalField; |
622 | 148 | } |
623 | | // Otherwise, look up the field by name. |
624 | 584 | else { |
625 | 584 | auto tanFieldLookup = |
626 | 584 | parentTanStruct->lookupDirect(originalField->getName()); |
627 | 584 | llvm::erase_if(tanFieldLookup, |
628 | 584 | [](ValueDecl *v) { return !isa<VarDecl>(v); }); |
629 | | // Error if tangent property could not be found. |
630 | 584 | if (tanFieldLookup.empty()) { |
631 | 12 | return TangentPropertyInfo( |
632 | 12 | TangentPropertyInfo::Error::Kind::TangentPropertyNotFound); |
633 | 12 | } |
634 | 572 | tanField = cast<VarDecl>(tanFieldLookup.front()); |
635 | 572 | } |
636 | | // Error if tangent property's type is not equal to the original property's |
637 | | // `TangentVector` type. |
638 | 720 | auto originalFieldTanType = originalFieldTan->getType(); |
639 | 720 | auto tanFieldType = |
640 | 720 | parentTanType->getTypeOfMember(tanField->getModuleContext(), tanField); |
641 | 720 | if (!originalFieldTanType->isEqual(tanFieldType)) { |
642 | 12 | return TangentPropertyInfo( |
643 | 12 | TangentPropertyInfo::Error::Kind::TangentPropertyWrongType, |
644 | 12 | originalFieldTanType); |
645 | 12 | } |
646 | | // Error if tangent property is not a stored property. |
647 | 708 | if (!tanField->hasStorage()) { |
648 | 12 | return TangentPropertyInfo( |
649 | 12 | TangentPropertyInfo::Error::Kind::TangentPropertyNotStored); |
650 | 12 | } |
651 | | // Otherwise, tangent property is valid. |
652 | 696 | return TangentPropertyInfo(tanField); |
653 | 708 | } |
654 | | |
655 | 0 | void SILDifferentiabilityWitnessKey::print(llvm::raw_ostream &s) const { |
656 | 0 | s << "(original=@" << originalFunctionName << " kind="; |
657 | 0 | switch (kind) { |
658 | 0 | case DifferentiabilityKind::NonDifferentiable: |
659 | 0 | s << "nondifferentiable"; |
660 | 0 | break; |
661 | 0 | case DifferentiabilityKind::Forward: |
662 | 0 | s << "forward"; |
663 | 0 | break; |
664 | 0 | case DifferentiabilityKind::Reverse: |
665 | 0 | s << "reverse"; |
666 | 0 | break; |
667 | 0 | case DifferentiabilityKind::Normal: |
668 | 0 | s << "normal"; |
669 | 0 | break; |
670 | 0 | case DifferentiabilityKind::Linear: |
671 | 0 | s << "linear"; |
672 | 0 | break; |
673 | 0 | } |
674 | 0 | s << " config=" << config << ')'; |
675 | 0 | } |
676 | | |
677 | | Demangle::AutoDiffFunctionKind Demangle::getAutoDiffFunctionKind( |
678 | 14.4k | AutoDiffDerivativeFunctionKind kind) { |
679 | 14.4k | switch (kind) { |
680 | 7.24k | case AutoDiffDerivativeFunctionKind::JVP: |
681 | 7.24k | return Demangle::AutoDiffFunctionKind::JVP; |
682 | 7.20k | case AutoDiffDerivativeFunctionKind::VJP: return Demangle::AutoDiffFunctionKind::VJP; |
683 | 14.4k | } |
684 | 14.4k | } |
685 | | |
686 | | Demangle::AutoDiffFunctionKind Demangle::getAutoDiffFunctionKind( |
687 | 7.44k | AutoDiffLinearMapKind kind) { |
688 | 7.44k | switch (kind) { |
689 | 1.72k | case AutoDiffLinearMapKind::Differential: |
690 | 1.72k | return Demangle::AutoDiffFunctionKind::Differential; |
691 | 5.72k | case AutoDiffLinearMapKind::Pullback: |
692 | 5.72k | return Demangle::AutoDiffFunctionKind::Pullback; |
693 | 7.44k | } |
694 | 7.44k | } |
695 | | |
696 | | Demangle::MangledDifferentiabilityKind |
697 | 20.8k | Demangle::getMangledDifferentiabilityKind(DifferentiabilityKind kind) { |
698 | 20.8k | using namespace Demangle; |
699 | 20.8k | switch (kind) { |
700 | 0 | #define SIMPLE_CASE(CASE) \ |
701 | 20.8k | case DifferentiabilityKind::CASE: return MangledDifferentiabilityKind::CASE; |
702 | 0 | SIMPLE_CASE(NonDifferentiable) |
703 | 0 | SIMPLE_CASE(Forward) |
704 | 20.8k | SIMPLE_CASE(Reverse) |
705 | 0 | SIMPLE_CASE(Normal) |
706 | 15 | SIMPLE_CASE(Linear) |
707 | 20.8k | #undef SIMPLE_CASE |
708 | 20.8k | } |
709 | 20.8k | } |