/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/Thunk.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- Thunk.cpp - Automatic differentiation thunks ---------*- C++ -*---===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // Automatic differentiation thunk generation utilities. |
14 | | // |
15 | | //===----------------------------------------------------------------------===// |
16 | | |
17 | | #define DEBUG_TYPE "differentiation" |
18 | | |
19 | | #include "swift/SILOptimizer/Differentiation/Thunk.h" |
20 | | #include "swift/SILOptimizer/Differentiation/Common.h" |
21 | | |
22 | | #include "swift/AST/AnyFunctionRef.h" |
23 | | #include "swift/AST/Requirement.h" |
24 | | #include "swift/AST/SubstitutionMap.h" |
25 | | #include "swift/AST/TypeCheckRequests.h" |
26 | | #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" |
27 | | #include "swift/SILOptimizer/Utils/DifferentiationMangler.h" |
28 | | |
29 | | namespace swift { |
30 | | namespace autodiff { |
31 | | |
32 | | //===----------------------------------------------------------------------===// |
33 | | // Thunk helpers |
34 | | //===----------------------------------------------------------------------===// |
35 | | // These helpers are copied/adapted from SILGen. They should be refactored and |
36 | | // moved to a shared location. |
37 | | //===----------------------------------------------------------------------===// |
38 | | |
39 | | CanSILFunctionType buildThunkType(SILFunction *fn, |
40 | | CanSILFunctionType &sourceType, |
41 | | CanSILFunctionType &expectedType, |
42 | | GenericEnvironment *&genericEnv, |
43 | | SubstitutionMap &interfaceSubs, |
44 | | bool withoutActuallyEscaping, |
45 | 6.30k | DifferentiationThunkKind thunkKind) { |
46 | 6.30k | CanType inputSubstType; |
47 | 6.30k | CanType outputSubstType; |
48 | 6.30k | CanType dynamicSelfType; |
49 | 6.30k | return buildSILFunctionThunkType( |
50 | 6.30k | fn, sourceType, expectedType, inputSubstType, outputSubstType, genericEnv, |
51 | 6.30k | interfaceSubs, dynamicSelfType, withoutActuallyEscaping, thunkKind); |
52 | 6.30k | } |
53 | | |
54 | | /// Forward function arguments, handling ownership convention mismatches. |
55 | | /// Adapted from `forwardFunctionArguments` in SILGenPoly.cpp. |
56 | | /// |
57 | | /// Forwarded arguments are appended to `forwardedArgs`. |
58 | | /// |
59 | | /// Local allocations are appended to `localAllocations`. They need to be |
60 | | /// deallocated via `dealloc_stack`. |
61 | | /// |
62 | | /// Local values requiring cleanup are appended to `valuesToCleanup`. |
63 | | static void forwardFunctionArgumentsConvertingOwnership( |
64 | | SILBuilder &builder, SILLocation loc, CanSILFunctionType fromTy, |
65 | | CanSILFunctionType toTy, ArrayRef<SILArgument *> originalArgs, |
66 | | SmallVectorImpl<SILValue> &forwardedArgs, |
67 | | SmallVectorImpl<AllocStackInst *> &localAllocations, |
68 | 940 | SmallVectorImpl<SILValue> &valuesToCleanup) { |
69 | 940 | auto fromParameters = fromTy->getParameters(); |
70 | 940 | auto toParameters = toTy->getParameters(); |
71 | 940 | assert(fromParameters.size() == toParameters.size()); |
72 | 0 | assert(fromParameters.size() == originalArgs.size()); |
73 | 1.04k | for (auto index : indices(originalArgs)) { |
74 | 1.04k | auto &arg = originalArgs[index]; |
75 | 1.04k | auto fromParam = fromParameters[index]; |
76 | 1.04k | auto toParam = toParameters[index]; |
77 | | // To convert guaranteed argument to be owned, create a copy. |
78 | 1.04k | if (fromParam.isConsumed() && !toParam.isConsumed()) { |
79 | | // If the argument has an object type, create a `copy_value`. |
80 | 28 | if (arg->getType().isObject()) { |
81 | 28 | auto argCopy = builder.emitCopyValueOperation(loc, arg); |
82 | 28 | forwardedArgs.push_back(argCopy); |
83 | 28 | continue; |
84 | 28 | } |
85 | | // If the argument has an address type, create a local allocation and |
86 | | // `copy_addr` its contents to the local allocation. |
87 | 0 | auto *alloc = builder.createAllocStack(loc, arg->getType()); |
88 | 0 | builder.createCopyAddr(loc, arg, alloc, IsNotTake, IsInitialization); |
89 | 0 | localAllocations.push_back(alloc); |
90 | 0 | forwardedArgs.push_back(alloc); |
91 | 0 | continue; |
92 | 28 | } |
93 | | // To convert owned argument to be guaranteed, borrow the argument. |
94 | 1.01k | if (fromParam.isGuaranteed() && !toParam.isGuaranteed()) { |
95 | 268 | auto bbi = builder.emitBeginBorrowOperation(loc, arg); |
96 | 268 | forwardedArgs.push_back(bbi); |
97 | 268 | valuesToCleanup.push_back(bbi); |
98 | 268 | valuesToCleanup.push_back(arg); |
99 | 268 | continue; |
100 | 268 | } |
101 | | // Otherwise, simply forward the argument. |
102 | 744 | forwardedArgs.push_back(arg); |
103 | 744 | } |
104 | 940 | } |
105 | | |
106 | | SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb, |
107 | | SILModule &module, SILLocation loc, |
108 | | SILFunction *caller, |
109 | | CanSILFunctionType fromType, |
110 | 3.84k | CanSILFunctionType toType) { |
111 | 3.84k | assert(!fromType->getCombinedSubstitutions()); |
112 | 0 | assert(!toType->getCombinedSubstitutions()); |
113 | | |
114 | 0 | SubstitutionMap interfaceSubs; |
115 | 3.84k | GenericEnvironment *genericEnv = nullptr; |
116 | 3.84k | auto thunkType = |
117 | 3.84k | buildThunkType(caller, fromType, toType, genericEnv, interfaceSubs, |
118 | 3.84k | /*withoutActuallyEscaping*/ false, |
119 | 3.84k | DifferentiationThunkKind::Reabstraction); |
120 | 3.84k | auto thunkDeclType = |
121 | 3.84k | thunkType->getWithExtInfo(thunkType->getExtInfo().withNoEscape(false)); |
122 | | |
123 | 3.84k | auto fromInterfaceType = fromType->mapTypeOutOfContext()->getCanonicalType(); |
124 | 3.84k | auto toInterfaceType = toType->mapTypeOutOfContext()->getCanonicalType(); |
125 | | |
126 | 3.84k | Mangle::ASTMangler mangler; |
127 | 3.84k | std::string name = mangler.mangleReabstractionThunkHelper( |
128 | 3.84k | thunkType, fromInterfaceType, toInterfaceType, Type(), Type(), |
129 | 3.84k | module.getSwiftModule()); |
130 | | |
131 | 3.84k | auto *thunk = fb.getOrCreateSharedFunction( |
132 | 3.84k | loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized, |
133 | 3.84k | ProfileCounter(), IsReabstractionThunk, IsNotDynamic, IsNotDistributed, |
134 | 3.84k | IsNotRuntimeAccessible); |
135 | 3.84k | if (!thunk->empty()) |
136 | 2.90k | return thunk; |
137 | | |
138 | 940 | thunk->setGenericEnvironment(genericEnv); |
139 | 940 | auto *entry = thunk->createBasicBlock(); |
140 | 940 | SILBuilder builder(entry); |
141 | 940 | createEntryArguments(thunk); |
142 | | |
143 | 940 | SILFunctionConventions fromConv(fromType, module); |
144 | 940 | SILFunctionConventions toConv(toType, module); |
145 | 940 | assert(toConv.useLoweredAddresses()); |
146 | | |
147 | | // Forward thunk arguments, handling ownership convention mismatches. |
148 | 0 | SmallVector<SILValue, 4> forwardedArgs; |
149 | 940 | for (auto indRes : thunk->getIndirectResults()) |
150 | 596 | forwardedArgs.push_back(indRes); |
151 | 940 | SmallVector<AllocStackInst *, 4> localAllocations; |
152 | 940 | SmallVector<SILValue, 4> valuesToCleanup; |
153 | 940 | forwardFunctionArgumentsConvertingOwnership( |
154 | 940 | builder, loc, fromType, toType, |
155 | 940 | thunk->getArgumentsWithoutIndirectResults().drop_back(), forwardedArgs, |
156 | 940 | localAllocations, valuesToCleanup); |
157 | | |
158 | 940 | SmallVector<SILValue, 4> arguments; |
159 | 940 | auto toArgIter = forwardedArgs.begin(); |
160 | 964 | auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); }; |
161 | | |
162 | 940 | auto createAllocStack = [&](SILType type) { |
163 | 688 | auto *alloc = builder.createAllocStack(loc, type); |
164 | 688 | localAllocations.push_back(alloc); |
165 | 688 | return alloc; |
166 | 688 | }; |
167 | | |
168 | | // Handle indirect results. |
169 | 940 | assert(fromType->getNumResults() == toType->getNumResults()); |
170 | 1.28k | for (unsigned resIdx : range(toType->getNumResults())) { |
171 | 1.28k | auto fromRes = fromConv.getResults()[resIdx]; |
172 | 1.28k | auto toRes = toConv.getResults()[resIdx]; |
173 | | // No abstraction mismatch. |
174 | 1.28k | if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) { |
175 | | // If result types are indirect, directly pass as next argument. |
176 | 720 | if (toRes.isFormalIndirect()) |
177 | 412 | useNextArgument(); |
178 | 720 | continue; |
179 | 720 | } |
180 | | // Convert indirect result to direct result. |
181 | 560 | if (fromRes.isFormalIndirect()) { |
182 | 376 | SILType resultTy = |
183 | 376 | fromConv.getSILType(fromRes, builder.getTypeExpansionContext()); |
184 | 376 | assert(resultTy.isAddress()); |
185 | 0 | auto *indRes = createAllocStack(resultTy); |
186 | 376 | arguments.push_back(indRes); |
187 | 376 | continue; |
188 | 376 | } |
189 | | // Convert direct result to indirect result. |
190 | | // Increment thunk argument iterator; reabstraction handled later. |
191 | 184 | ++toArgIter; |
192 | 184 | } |
193 | | |
194 | | // Reabstract parameters. |
195 | 940 | assert(toType->getNumParameters() == fromType->getNumParameters()); |
196 | 1.04k | for (unsigned paramIdx : range(toType->getNumParameters())) { |
197 | 1.04k | auto fromParam = fromConv.getParameters()[paramIdx]; |
198 | 1.04k | auto toParam = toConv.getParameters()[paramIdx]; |
199 | | // No abstraction mismatch. Directly use next argument. |
200 | 1.04k | if (fromParam.isFormalIndirect() == toParam.isFormalIndirect()) { |
201 | 552 | useNextArgument(); |
202 | 552 | continue; |
203 | 552 | } |
204 | | // Convert indirect parameter to direct parameter. |
205 | 488 | if (fromParam.isFormalIndirect()) { |
206 | 312 | auto paramTy = fromConv.getSILType(fromType->getParameters()[paramIdx], |
207 | 312 | builder.getTypeExpansionContext()); |
208 | 312 | if (!paramTy.hasArchetype()) |
209 | 308 | paramTy = thunk->mapTypeIntoContext(paramTy); |
210 | 312 | assert(paramTy.isAddress()); |
211 | 0 | auto toArg = *toArgIter++; |
212 | 312 | auto *buf = createAllocStack(toArg->getType()); |
213 | 312 | toArg = builder.emitCopyValueOperation(loc, toArg); |
214 | 312 | builder.emitStoreValueOperation(loc, toArg, buf, |
215 | 312 | StoreOwnershipQualifier::Init); |
216 | 312 | valuesToCleanup.push_back(buf); |
217 | 312 | arguments.push_back(buf); |
218 | 312 | continue; |
219 | 312 | } |
220 | | // Convert direct parameter to indirect parameter. |
221 | 176 | assert(toParam.isFormalIndirect()); |
222 | 0 | auto toArg = *toArgIter++; |
223 | 176 | auto load = builder.emitLoadBorrowOperation(loc, toArg); |
224 | 176 | if (isa<LoadBorrowInst>(load)) |
225 | 20 | valuesToCleanup.push_back(load); |
226 | 176 | arguments.push_back(load); |
227 | 176 | } |
228 | | |
229 | 940 | auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back(); |
230 | 940 | auto *apply = builder.createApply(loc, fnArg, SubstitutionMap(), arguments); |
231 | | |
232 | | // Get return elements. |
233 | 940 | SmallVector<SILValue, 4> results; |
234 | | // Extract all direct results. |
235 | 940 | SmallVector<SILValue, 4> directResults; |
236 | 940 | extractAllElements(apply, builder, directResults); |
237 | | |
238 | 940 | auto fromDirResultsIter = directResults.begin(); |
239 | 940 | auto fromIndResultsIter = apply->getIndirectSILResults().begin(); |
240 | 940 | auto toIndResultsIter = thunk->getIndirectResults().begin(); |
241 | | // Reabstract results. |
242 | 1.28k | for (unsigned resIdx : range(toType->getNumResults())) { |
243 | 1.28k | auto fromRes = fromConv.getResults()[resIdx]; |
244 | 1.28k | auto toRes = toConv.getResults()[resIdx]; |
245 | | // Check function-typed results. |
246 | 1.28k | if (isa<SILFunctionType>(fromRes.getInterfaceType()) && |
247 | 1.28k | isa<SILFunctionType>(toRes.getInterfaceType())) { |
248 | 40 | auto fromFnType = cast<SILFunctionType>(fromRes.getInterfaceType()); |
249 | 40 | auto toFnType = cast<SILFunctionType>(toRes.getInterfaceType()); |
250 | 40 | auto fromUnsubstFnType = fromFnType->getUnsubstitutedType(module); |
251 | 40 | auto toUnsubstFnType = toFnType->getUnsubstitutedType(module); |
252 | | // If unsubstituted function types are not equal, perform reabstraction. |
253 | 40 | if (fromUnsubstFnType != toUnsubstFnType) { |
254 | 40 | auto fromFn = *fromDirResultsIter++; |
255 | 40 | auto newFromFn = reabstractFunction( |
256 | 40 | builder, fb, loc, fromFn, toFnType, |
257 | 40 | [](SubstitutionMap substMap) { return substMap; }); |
258 | 40 | results.push_back(newFromFn); |
259 | 40 | continue; |
260 | 40 | } |
261 | 40 | } |
262 | | // No abstraction mismatch. |
263 | 1.24k | if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) { |
264 | | // If result types are direct, add call result as direct thunk result. |
265 | 680 | if (toRes.isFormalDirect()) |
266 | 268 | results.push_back(*fromDirResultsIter++); |
267 | | // If result types are indirect, increment indirect result iterators. |
268 | 412 | else { |
269 | 412 | ++fromIndResultsIter; |
270 | 412 | ++toIndResultsIter; |
271 | 412 | } |
272 | 680 | continue; |
273 | 680 | } |
274 | | // Load direct results from indirect results. |
275 | 560 | if (fromRes.isFormalIndirect()) { |
276 | 376 | auto indRes = *fromIndResultsIter++; |
277 | 376 | auto load = builder.emitLoadValueOperation(loc, indRes, |
278 | 376 | LoadOwnershipQualifier::Take); |
279 | 376 | results.push_back(load); |
280 | 376 | continue; |
281 | 376 | } |
282 | | // Store direct results to indirect results. |
283 | 184 | assert(toRes.isFormalIndirect()); |
284 | 0 | #ifndef NDEBUG |
285 | 0 | SILType resultTy = |
286 | 184 | toConv.getSILType(toRes, builder.getTypeExpansionContext()); |
287 | 184 | assert(resultTy.isAddress()); |
288 | 0 | #endif |
289 | 0 | auto indRes = *toIndResultsIter++; |
290 | 184 | auto dirRes = *fromDirResultsIter++; |
291 | 184 | builder.emitStoreValueOperation(loc, dirRes, indRes, |
292 | 184 | StoreOwnershipQualifier::Init); |
293 | 184 | } |
294 | 940 | auto retVal = joinElements(results, builder, loc); |
295 | | |
296 | | // Clean up local values. |
297 | | // Guaranteed values need an `end_borrow`. |
298 | | // Owned values need to be destroyed. |
299 | 940 | for (auto arg : valuesToCleanup) { |
300 | 868 | switch (arg->getOwnershipKind()) { |
301 | 0 | case OwnershipKind::Any: |
302 | 0 | llvm_unreachable("value with any ownership kind?!"); |
303 | 24 | case OwnershipKind::Guaranteed: |
304 | 24 | builder.emitEndBorrowOperation(loc, arg); |
305 | 24 | break; |
306 | 4 | case OwnershipKind::Owned: |
307 | 4 | case OwnershipKind::Unowned: |
308 | 844 | case OwnershipKind::None: |
309 | 844 | builder.emitDestroyOperation(loc, arg); |
310 | 844 | break; |
311 | 868 | } |
312 | 868 | } |
313 | | |
314 | | // Deallocate local allocations. |
315 | 940 | for (auto *alloc : llvm::reverse(localAllocations)) |
316 | 688 | builder.createDeallocStack(loc, alloc); |
317 | | |
318 | | // Create return. |
319 | 940 | builder.createReturn(loc, retVal); |
320 | | |
321 | 940 | LLVM_DEBUG(auto &s = getADDebugStream() << "Created reabstraction thunk.\n"; |
322 | 940 | s << " From type: " << fromType << '\n'; |
323 | 940 | s << " To type: " << toType << '\n'; s << '\n' |
324 | 940 | << *thunk); |
325 | | |
326 | 940 | return thunk; |
327 | 940 | } |
328 | | |
329 | | SILValue reabstractFunction( |
330 | | SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc, |
331 | | SILValue fn, CanSILFunctionType toType, |
332 | 3.84k | std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions) { |
333 | 3.84k | auto &module = *fn->getModule(); |
334 | 3.84k | auto fromType = fn->getType().getAs<SILFunctionType>(); |
335 | 3.84k | auto unsubstFromType = fromType->getUnsubstitutedType(module); |
336 | 3.84k | auto unsubstToType = toType->getUnsubstitutedType(module); |
337 | | |
338 | 3.84k | auto *thunk = getOrCreateReabstractionThunk(fb, module, loc, |
339 | 3.84k | /*caller*/ fn->getFunction(), |
340 | 3.84k | unsubstFromType, unsubstToType); |
341 | 3.84k | auto *thunkRef = builder.createFunctionRef(loc, thunk); |
342 | | |
343 | 3.84k | if (fromType != unsubstFromType) |
344 | 656 | fn = builder.createConvertFunction( |
345 | 656 | loc, fn, SILType::getPrimitiveObjectType(unsubstFromType), |
346 | 656 | /*withoutActuallyEscaping*/ false); |
347 | | |
348 | 3.84k | fn = builder.createPartialApply( |
349 | 3.84k | loc, thunkRef, remapSubstitutions(thunk->getForwardingSubstitutionMap()), |
350 | 3.84k | {fn}, fromType->getCalleeConvention()); |
351 | | |
352 | 3.84k | if (toType != unsubstToType) |
353 | 652 | fn = builder.createConvertFunction(loc, fn, |
354 | 652 | SILType::getPrimitiveObjectType(toType), |
355 | 652 | /*withoutActuallyEscaping*/ false); |
356 | | |
357 | 3.84k | return fn; |
358 | 3.84k | } |
359 | | |
360 | | std::pair<SILFunction *, SubstitutionMap> |
361 | | getOrCreateSubsetParametersThunkForLinearMap( |
362 | | SILOptFunctionBuilder &fb, SILFunction *parentThunk, |
363 | | CanSILFunctionType origFnType, CanSILFunctionType linearMapType, |
364 | | CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind, |
365 | | const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig, |
366 | 648 | ADContext &adContext) { |
367 | 648 | LLVM_DEBUG(getADDebugStream() |
368 | 648 | << "Getting a subset parameters thunk for " |
369 | 648 | << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp") |
370 | 648 | << " linear map " << linearMapType |
371 | 648 | << " from " << actualConfig << " to " << desiredConfig << '\n'); |
372 | | |
373 | 648 | assert(!linearMapType->getCombinedSubstitutions()); |
374 | 0 | assert(!targetType->getCombinedSubstitutions()); |
375 | 0 | SubstitutionMap interfaceSubs; |
376 | 648 | GenericEnvironment *genericEnv = nullptr; |
377 | 648 | auto thunkType = buildThunkType(parentThunk, linearMapType, targetType, |
378 | 648 | genericEnv, interfaceSubs, |
379 | 648 | /*withoutActuallyEscaping*/ true, |
380 | 648 | DifferentiationThunkKind::Reabstraction); |
381 | | |
382 | 648 | Mangle::DifferentiationMangler mangler; |
383 | 648 | auto fromInterfaceType = |
384 | 648 | linearMapType->mapTypeOutOfContext()->getCanonicalType(); |
385 | | |
386 | 648 | auto thunkName = mangler.mangleLinearMapSubsetParametersThunk( |
387 | 648 | fromInterfaceType, kind.getLinearMapKind(), |
388 | 648 | actualConfig.parameterIndices, actualConfig.resultIndices, |
389 | 648 | desiredConfig.parameterIndices); |
390 | | |
391 | 648 | auto loc = parentThunk->getLocation(); |
392 | 648 | auto *thunk = fb.getOrCreateSharedFunction( |
393 | 648 | loc, thunkName, thunkType, IsBare, IsTransparent, IsSerialized, |
394 | 648 | ProfileCounter(), IsThunk, IsNotDynamic, IsNotDistributed, |
395 | 648 | IsNotRuntimeAccessible); |
396 | | |
397 | 648 | if (!thunk->empty()) |
398 | 56 | return {thunk, interfaceSubs}; |
399 | | |
400 | 592 | thunk->setGenericEnvironment(genericEnv); |
401 | 592 | auto *entry = thunk->createBasicBlock(); |
402 | 592 | TangentBuilder builder(entry, adContext); |
403 | 592 | createEntryArguments(thunk); |
404 | | |
405 | | // Get arguments. |
406 | 592 | SmallVector<SILValue, 4> arguments; |
407 | 592 | SmallVector<AllocStackInst *, 4> localAllocations; |
408 | 592 | SmallVector<SILValue, 4> valuesToCleanup; |
409 | 592 | auto cleanupValues = [&]() { |
410 | 592 | for (auto value : llvm::reverse(valuesToCleanup)) |
411 | 156 | builder.emitDestroyOperation(loc, value); |
412 | | |
413 | 592 | for (auto *alloc : llvm::reverse(localAllocations)) |
414 | 456 | builder.createDeallocStack(loc, alloc); |
415 | 592 | }; |
416 | | |
417 | | // Build a `.zero` argument for the given `Differentiable`-conforming type. |
418 | 592 | auto buildZeroArgument = [&](SILParameterInfo zeroSILParameter) { |
419 | 308 | auto zeroSILType = zeroSILParameter.getSILStorageInterfaceType(); |
420 | 308 | auto zeroSILObjType = zeroSILType.getObjectType(); |
421 | 308 | auto zeroType = zeroSILType.getASTType(); |
422 | 308 | auto *swiftMod = parentThunk->getModule().getSwiftModule(); |
423 | 308 | auto tangentSpace = |
424 | 308 | zeroType->getAutoDiffTangentSpace(LookUpConformanceInModule(swiftMod)); |
425 | 308 | assert(tangentSpace && "No tangent space for this type"); |
426 | 0 | switch (tangentSpace->getKind()) { |
427 | 308 | case TangentSpace::Kind::TangentVector: { |
428 | 308 | auto *buf = builder.createAllocStack(loc, zeroSILObjType); |
429 | 308 | localAllocations.push_back(buf); |
430 | 308 | builder.emitZeroIntoBuffer(loc, buf, IsInitialization); |
431 | 308 | if (zeroSILType.isAddress()) { |
432 | 148 | arguments.push_back(buf); |
433 | 148 | if (zeroSILParameter.isGuaranteed()) { |
434 | 144 | valuesToCleanup.push_back(buf); |
435 | 144 | } |
436 | 160 | } else { |
437 | 160 | auto arg = builder.emitLoadValueOperation(loc, buf, |
438 | 160 | LoadOwnershipQualifier::Take); |
439 | 160 | arguments.push_back(arg); |
440 | 160 | if (zeroSILParameter.isGuaranteed()) { |
441 | 12 | valuesToCleanup.push_back(arg); |
442 | 12 | } |
443 | 160 | } |
444 | 308 | break; |
445 | 0 | } |
446 | 0 | case TangentSpace::Kind::Tuple: { |
447 | 0 | llvm_unreachable("Unimplemented: Handle zero initialization for tuples"); |
448 | 0 | } |
449 | 308 | } |
450 | 308 | }; |
451 | | |
452 | | // The indices in `actualConfig` and `desiredConfig` are with respect to the |
453 | | // original function. However, the differential parameters and pullback |
454 | | // results may already be w.r.t. a subset. We create a map between the |
455 | | // original function's actual parameter indices and the linear map's actual |
456 | | // indices. |
457 | | // Example: |
458 | | // Original: (T0, T1, T2) -> R |
459 | | // Actual indices: 0, 2 |
460 | | // Original differential: (T0, T2) -> R |
461 | | // Original pullback: R -> (T0, T2) |
462 | | // Desired indices w.r.t. original: 2 |
463 | | // Desired indices w.r.t. linear map: 1 |
464 | 592 | SmallVector<unsigned, 4> actualParamIndicesMap( |
465 | 592 | actualConfig.parameterIndices->getCapacity(), UINT_MAX); |
466 | 592 | { |
467 | 592 | unsigned indexInBitVec = 0; |
468 | 1.28k | for (auto index : actualConfig.parameterIndices->getIndices()) { |
469 | 1.28k | actualParamIndicesMap[index] = indexInBitVec; |
470 | 1.28k | ++indexInBitVec; |
471 | 1.28k | } |
472 | 592 | } |
473 | 980 | auto mapOriginalParameterIndex = [&](unsigned index) -> unsigned { |
474 | 980 | auto mappedIndex = actualParamIndicesMap[index]; |
475 | 980 | assert(mappedIndex < actualConfig.parameterIndices->getCapacity()); |
476 | 0 | return mappedIndex; |
477 | 980 | }; |
478 | | |
479 | 592 | auto toIndirectResultsIter = thunk->getIndirectResults().begin(); |
480 | 592 | auto useNextIndirectResult = [&]() { |
481 | 248 | assert(toIndirectResultsIter != thunk->getIndirectResults().end()); |
482 | 0 | arguments.push_back(*toIndirectResultsIter++); |
483 | 248 | }; |
484 | | |
485 | 592 | switch (kind) { |
486 | | // Differential arguments are: |
487 | | // - All indirect results, followed by: |
488 | | // - An interleaving of: |
489 | | // - Thunk arguments (when parameter index is in both desired and actual |
490 | | // indices). |
491 | | // - Zeros (when parameter is not in desired indices). |
492 | 296 | case AutoDiffDerivativeFunctionKind::JVP: { |
493 | 296 | unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults(); |
494 | | // Forward desired indirect results |
495 | 296 | for (unsigned idx : *actualConfig.resultIndices) { |
496 | 296 | if (idx >= numIndirectResults) |
497 | 184 | break; |
498 | | |
499 | 112 | auto resultInfo = linearMapType->getResults()[idx]; |
500 | 112 | assert(idx < linearMapType->getNumResults()); |
501 | | |
502 | | // Forward result argument in case we do not need to thunk it away |
503 | 112 | if (desiredConfig.resultIndices->contains(idx)) { |
504 | 112 | useNextIndirectResult(); |
505 | 112 | continue; |
506 | 112 | } |
507 | | |
508 | | // Otherwise, allocate and use an uninitialized indirect result |
509 | 0 | auto *indirectResult = builder.createAllocStack( |
510 | 0 | loc, resultInfo.getSILStorageInterfaceType()); |
511 | 0 | localAllocations.push_back(indirectResult); |
512 | 0 | arguments.push_back(indirectResult); |
513 | 0 | } |
514 | 296 | assert(toIndirectResultsIter == thunk->getIndirectResults().end()); |
515 | | |
516 | 0 | auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin(); |
517 | 340 | auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); }; |
518 | | // Iterate over actual indices. |
519 | 644 | for (unsigned i : actualConfig.parameterIndices->getIndices()) { |
520 | | // If index is desired, use next argument. |
521 | 644 | if (desiredConfig.isWrtParameter(i)) { |
522 | 340 | useNextArgument(); |
523 | 340 | } |
524 | | // Otherwise, construct and use a zero argument. |
525 | 304 | else { |
526 | 304 | auto zeroSILParameter = |
527 | 304 | linearMapType->getParameters()[mapOriginalParameterIndex(i)]; |
528 | 304 | buildZeroArgument(zeroSILParameter); |
529 | 304 | } |
530 | 644 | } |
531 | 296 | break; |
532 | 0 | } |
533 | | // Pullback arguments are: |
534 | | // - An interleaving of: |
535 | | // - Thunk indirect results (when parameter index is in both desired and |
536 | | // actual indices). |
537 | | // - Zeros (when parameter is not in desired indices). |
538 | | // - All actual arguments. |
539 | 296 | case AutoDiffDerivativeFunctionKind::VJP: { |
540 | | // Collect pullback arguments. |
541 | 296 | unsigned pullbackResultIndex = 0; |
542 | 644 | for (unsigned i : actualConfig.parameterIndices->getIndices()) { |
543 | 644 | auto origParamInfo = origFnType->getParameters()[i]; |
544 | | // Skip original semantic result parameters. All non-indirect-result pullback |
545 | | // arguments (including semantic result` arguments) are appended to `arguments` |
546 | | // later. |
547 | 644 | if (origParamInfo.isAutoDiffSemanticResult()) |
548 | 32 | continue; |
549 | 612 | auto resultInfo = linearMapType->getResults()[pullbackResultIndex]; |
550 | 612 | assert(pullbackResultIndex < linearMapType->getNumResults()); |
551 | 0 | ++pullbackResultIndex; |
552 | | // Skip pullback direct results. Only indirect results are relevant as |
553 | | // arguments. |
554 | 612 | if (resultInfo.isFormalDirect()) |
555 | 328 | continue; |
556 | | // If index is desired, use next pullback indirect result. |
557 | 284 | if (desiredConfig.isWrtParameter(i)) { |
558 | 136 | useNextIndirectResult(); |
559 | 136 | continue; |
560 | 136 | } |
561 | | // Otherwise, allocate and use an uninitialized pullback indirect result. |
562 | 148 | auto *indirectResult = builder.createAllocStack( |
563 | 148 | loc, resultInfo.getSILStorageInterfaceType()); |
564 | 148 | localAllocations.push_back(indirectResult); |
565 | 148 | arguments.push_back(indirectResult); |
566 | 148 | } |
567 | | // Forward all actual non-indirect-result arguments. |
568 | 296 | auto thunkArgs = thunk->getArgumentsWithoutIndirectResults(); |
569 | | // Slice out the function to be called |
570 | 296 | thunkArgs = thunkArgs.slice(0, thunkArgs.size() - 1); |
571 | 296 | unsigned thunkArg = 0; |
572 | 300 | for (unsigned idx : *actualConfig.resultIndices) { |
573 | | // Forward result argument in case we do not need to thunk it away |
574 | 300 | if (desiredConfig.resultIndices->contains(idx)) |
575 | 296 | arguments.push_back(thunkArgs[thunkArg++]); |
576 | 4 | else // otherwise, zero it out |
577 | 4 | buildZeroArgument(linearMapType->getParameters()[arguments.size()]); |
578 | 300 | } |
579 | 296 | break; |
580 | 0 | } |
581 | 592 | } |
582 | | |
583 | | // Get the linear map thunk argument and apply it. |
584 | 592 | auto *linearMap = thunk->getArguments().back(); |
585 | 592 | auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments); |
586 | | |
587 | | // If differential thunk, deallocate local allocations and directly return |
588 | | // `apply` result (if it is desired). |
589 | 592 | if (kind == AutoDiffDerivativeFunctionKind::JVP) { |
590 | 296 | SmallVector<SILValue, 8> differentialDirectResults; |
591 | 296 | extractAllElements(ai, builder, differentialDirectResults); |
592 | 296 | SmallVector<SILValue, 8> allResults; |
593 | 296 | collectAllActualResultsInTypeOrder(ai, differentialDirectResults, allResults); |
594 | 296 | unsigned numResults = thunk->getConventions().getNumDirectSILResults() + |
595 | 296 | thunk->getConventions().getNumDirectSILResults(); |
596 | 296 | SmallVector<SILValue, 8> results; |
597 | 300 | for (unsigned idx : *actualConfig.resultIndices) { |
598 | 300 | if (idx >= numResults) |
599 | 144 | break; |
600 | | |
601 | 156 | auto result = allResults[idx]; |
602 | 156 | if (desiredConfig.isWrtResult(idx)) |
603 | 152 | results.push_back(result); |
604 | 4 | else { |
605 | 4 | if (result->getType().isAddress()) |
606 | 0 | builder.emitDestroyAddrAndFold(loc, result); |
607 | 4 | else |
608 | 4 | builder.emitDestroyValueOperation(loc, result); |
609 | 4 | } |
610 | 156 | } |
611 | | |
612 | 296 | cleanupValues(); |
613 | 296 | auto result = joinElements(results, builder, loc); |
614 | 296 | builder.createReturn(loc, result); |
615 | 296 | return {thunk, interfaceSubs}; |
616 | 296 | } |
617 | | |
618 | | // If pullback thunk, return only the desired results and clean up the |
619 | | // undesired results. |
620 | 296 | SmallVector<SILValue, 8> pullbackDirectResults; |
621 | 296 | extractAllElements(ai, builder, pullbackDirectResults); |
622 | 296 | SmallVector<SILValue, 8> allResults; |
623 | 296 | collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults); |
624 | | // Collect pullback semantic result arguments in type order. |
625 | 296 | unsigned semanticResultArgIdx = 0; |
626 | 296 | SILFunctionConventions origConv(origFnType, thunk->getModule()); |
627 | 644 | for (auto paramIdx : actualConfig.parameterIndices->getIndices()) { |
628 | 644 | auto paramInfo = origConv.getParameters()[paramIdx]; |
629 | 644 | if (!paramInfo.isAutoDiffSemanticResult()) |
630 | 612 | continue; |
631 | 32 | auto semanticResultArg = |
632 | 32 | *std::next(ai->getAutoDiffSemanticResultArguments().begin(), |
633 | 32 | semanticResultArgIdx++); |
634 | 32 | unsigned mappedParamIdx = mapOriginalParameterIndex(paramIdx); |
635 | 32 | allResults.insert(allResults.begin() + mappedParamIdx, semanticResultArg); |
636 | 32 | } |
637 | 296 | assert(allResults.size() == actualConfig.parameterIndices->getNumIndices() && |
638 | 296 | "Number of pullback results should match number of differentiability " |
639 | 296 | "parameters"); |
640 | | |
641 | 0 | SmallVector<SILValue, 8> results; |
642 | 644 | for (unsigned i : actualConfig.parameterIndices->getIndices()) { |
643 | 644 | unsigned mappedIndex = mapOriginalParameterIndex(i); |
644 | | // If result is desired: |
645 | | // - Do nothing if result is indirect. |
646 | | // (It was already forwarded to the `apply` instruction). |
647 | | // - Push it to `results` if result is direct. |
648 | 644 | auto result = allResults[mappedIndex]; |
649 | 644 | if (desiredConfig.isWrtParameter(i)) { |
650 | 340 | if (result->getType().isObject()) |
651 | 172 | results.push_back(result); |
652 | 340 | } |
653 | | // Otherwise, cleanup the unused results. |
654 | 304 | else { |
655 | 304 | if (result->getType().isAddress()) |
656 | 148 | builder.emitDestroyAddrAndFold(loc, result); |
657 | 156 | else |
658 | 156 | builder.emitDestroyValueOperation(loc, result); |
659 | 304 | } |
660 | 644 | } |
661 | | // Deallocate local allocations and return final direct result. |
662 | 296 | cleanupValues(); |
663 | 296 | auto result = joinElements(results, builder, loc); |
664 | 296 | builder.createReturn(loc, result); |
665 | | |
666 | 296 | return {thunk, interfaceSubs}; |
667 | 592 | } |
668 | | |
669 | | std::pair<SILFunction *, SubstitutionMap> |
670 | | getOrCreateSubsetParametersThunkForDerivativeFunction( |
671 | | SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn, |
672 | | AutoDiffDerivativeFunctionKind kind, const AutoDiffConfig &desiredConfig, |
673 | 1.81k | const AutoDiffConfig &actualConfig, ADContext &adContext) { |
674 | 1.81k | LLVM_DEBUG(getADDebugStream() |
675 | 1.81k | << "Getting a subset parameters thunk for derivative " |
676 | 1.81k | << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp") |
677 | 1.81k | << " function " << derivativeFn |
678 | 1.81k | << " of the original function " << origFnOperand |
679 | 1.81k | << " from " << actualConfig << " to " << desiredConfig << '\n'); |
680 | | |
681 | 1.81k | auto origFnType = origFnOperand->getType().castTo<SILFunctionType>(); |
682 | 1.81k | auto &module = fb.getModule(); |
683 | 1.81k | auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule()); |
684 | | |
685 | | // Compute target type for thunking. |
686 | 1.81k | auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>(); |
687 | 1.81k | auto targetType = origFnType->getAutoDiffDerivativeFunctionType( |
688 | 1.81k | desiredConfig.parameterIndices, desiredConfig.resultIndices, kind, |
689 | 1.81k | module.Types, lookupConformance); |
690 | 1.81k | auto *caller = derivativeFn->getFunction(); |
691 | 1.81k | if (targetType->hasArchetype()) { |
692 | 96 | auto substTargetType = |
693 | 96 | caller->mapTypeIntoContext(targetType->mapTypeOutOfContext()) |
694 | 96 | ->getCanonicalType(); |
695 | 96 | targetType = SILType::getPrimitiveObjectType(substTargetType) |
696 | 96 | .castTo<SILFunctionType>(); |
697 | 96 | } |
698 | 1.81k | assert(derivativeFnType->getNumParameters() == |
699 | 1.81k | targetType->getNumParameters()); |
700 | 0 | assert(derivativeFnType->getNumResults() == targetType->getNumResults()); |
701 | | |
702 | | // Build thunk type. |
703 | 0 | SubstitutionMap interfaceSubs; |
704 | 1.81k | GenericEnvironment *genericEnv = nullptr; |
705 | 1.81k | auto thunkType = buildThunkType(derivativeFn->getFunction(), derivativeFnType, |
706 | 1.81k | targetType, genericEnv, interfaceSubs, |
707 | 1.81k | /*withoutActuallyEscaping*/ false, |
708 | 1.81k | DifferentiationThunkKind::IndexSubset); |
709 | | |
710 | | // FIXME: The logic for resolving `assocRef` does not reapply function |
711 | | // conversions, which is problematic if `derivativeFn` is a `partial_apply` |
712 | | // instruction. |
713 | 1.81k | StringRef origName; |
714 | 1.81k | if (auto *origFnRef = |
715 | 1.81k | peerThroughFunctionConversions<FunctionRefInst>(origFnOperand)) { |
716 | 1.76k | origName = origFnRef->getReferencedFunction()->getName(); |
717 | 1.76k | } else if (auto *origMethodInst = |
718 | 48 | peerThroughFunctionConversions<MethodInst>(origFnOperand)) { |
719 | 48 | origName = origMethodInst->getMember() |
720 | 48 | .getAnyFunctionRef() |
721 | 48 | ->getAbstractFunctionDecl() |
722 | 48 | ->getNameStr(); |
723 | 48 | } |
724 | 1.81k | assert(!origName.empty() && "Original function name could not be resolved"); |
725 | 0 | Mangle::DifferentiationMangler mangler; |
726 | 1.81k | auto thunkName = mangler.mangleDerivativeFunctionSubsetParametersThunk( |
727 | 1.81k | origName, targetType->mapTypeOutOfContext()->getCanonicalType(), |
728 | 1.81k | kind, actualConfig.parameterIndices, actualConfig.resultIndices, |
729 | 1.81k | desiredConfig.parameterIndices); |
730 | | |
731 | 1.81k | auto loc = origFnOperand.getLoc(); |
732 | 1.81k | auto *thunk = fb.getOrCreateSharedFunction( |
733 | 1.81k | loc, thunkName, thunkType, IsBare, IsTransparent, caller->isSerialized(), |
734 | 1.81k | ProfileCounter(), IsThunk, IsNotDynamic, IsNotDistributed, |
735 | 1.81k | IsNotRuntimeAccessible); |
736 | | |
737 | 1.81k | if (!thunk->empty()) |
738 | 1.16k | return {thunk, interfaceSubs}; |
739 | | |
740 | 648 | thunk->setGenericEnvironment(genericEnv); |
741 | 648 | auto *entry = thunk->createBasicBlock(); |
742 | 648 | SILBuilder builder(entry); |
743 | 648 | createEntryArguments(thunk); |
744 | | |
745 | 648 | SubstitutionMap assocSubstMap; |
746 | 648 | if (auto *partialApply = dyn_cast<PartialApplyInst>(derivativeFn)) |
747 | 280 | assocSubstMap = partialApply->getSubstitutionMap(); |
748 | | |
749 | | // FIXME: The logic for resolving `assocRef` does not reapply function |
750 | | // conversions, which is problematic if `derivativeFn` is a `partial_apply` |
751 | | // instruction. |
752 | 648 | SILValue assocRef; |
753 | 648 | if (auto *derivativeFnRef = |
754 | 648 | peerThroughFunctionConversions<FunctionRefInst>(derivativeFn)) { |
755 | 0 | auto *assoc = derivativeFnRef->getReferencedFunction(); |
756 | 0 | assocRef = builder.createFunctionRef(loc, assoc); |
757 | 648 | } else if (auto *assocMethodInst = |
758 | 648 | peerThroughFunctionConversions<WitnessMethodInst>( |
759 | 648 | derivativeFn)) { |
760 | 24 | assocRef = builder.createWitnessMethod( |
761 | 24 | loc, assocMethodInst->getLookupType(), |
762 | 24 | assocMethodInst->getConformance(), assocMethodInst->getMember(), |
763 | 24 | thunk->mapTypeIntoContext(assocMethodInst->getType())); |
764 | 624 | } else if (auto *assocMethodInst = |
765 | 624 | peerThroughFunctionConversions<ClassMethodInst>( |
766 | 624 | derivativeFn)) { |
767 | 8 | auto classOperand = thunk->getArgumentsWithoutIndirectResults().back(); |
768 | 8 | #ifndef NDEBUG |
769 | 8 | auto classOperandType = assocMethodInst->getOperand()->getType(); |
770 | 8 | assert(classOperand->getType() == classOperandType); |
771 | 0 | #endif |
772 | 0 | assocRef = builder.createClassMethod( |
773 | 8 | loc, classOperand, assocMethodInst->getMember(), |
774 | 8 | thunk->mapTypeIntoContext(assocMethodInst->getType())); |
775 | 616 | } else if (auto *diffWitFn = peerThroughFunctionConversions< |
776 | 616 | DifferentiabilityWitnessFunctionInst>(derivativeFn)) { |
777 | 616 | assocRef = builder.createDifferentiabilityWitnessFunction( |
778 | 616 | loc, diffWitFn->getWitnessKind(), diffWitFn->getWitness()); |
779 | 616 | } |
780 | 0 | assert(assocRef && "Expected derivative function to be resolved"); |
781 | | |
782 | 0 | assocSubstMap = assocSubstMap.subst(thunk->getForwardingSubstitutionMap()); |
783 | 648 | derivativeFnType = assocRef->getType().castTo<SILFunctionType>(); |
784 | | |
785 | 648 | SmallVector<SILValue, 4> arguments; |
786 | 648 | arguments.append(thunk->getArguments().begin(), thunk->getArguments().end()); |
787 | 648 | assert(arguments.size() == |
788 | 648 | derivativeFnType->getNumParameters() + |
789 | 648 | derivativeFnType->getNumIndirectFormalResults()); |
790 | 0 | auto *apply = builder.createApply(loc, assocRef, assocSubstMap, arguments); |
791 | | |
792 | | // Extract all direct results. |
793 | 648 | SmallVector<SILValue, 8> directResults; |
794 | 648 | extractAllElements(apply, builder, directResults); |
795 | 648 | auto linearMap = directResults.back(); |
796 | 648 | directResults.pop_back(); |
797 | | |
798 | 648 | auto linearMapType = linearMap->getType().castTo<SILFunctionType>(); |
799 | 648 | auto linearMapTargetType = targetType->getResults() |
800 | 648 | .back() |
801 | 648 | .getSILStorageInterfaceType() |
802 | 648 | .castTo<SILFunctionType>(); |
803 | 648 | auto unsubstLinearMapType = linearMapType->getUnsubstitutedType(module); |
804 | 648 | auto unsubstLinearMapTargetType = |
805 | 648 | linearMapTargetType->getUnsubstitutedType(module); |
806 | | |
807 | 648 | SILFunction *linearMapThunk; |
808 | 648 | SubstitutionMap linearMapSubs; |
809 | 648 | std::tie(linearMapThunk, linearMapSubs) = |
810 | 648 | getOrCreateSubsetParametersThunkForLinearMap( |
811 | 648 | fb, thunk, origFnType, unsubstLinearMapType, |
812 | 648 | unsubstLinearMapTargetType, kind, desiredConfig, actualConfig, |
813 | 648 | adContext); |
814 | | |
815 | 648 | auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk); |
816 | 648 | SILValue thunkedLinearMap = linearMap; |
817 | 648 | if (linearMapType != unsubstLinearMapType) { |
818 | 280 | thunkedLinearMap = builder.createConvertFunction( |
819 | 280 | loc, thunkedLinearMap, |
820 | 280 | SILType::getPrimitiveObjectType(unsubstLinearMapType), |
821 | 280 | /*withoutActuallyEscaping*/ false); |
822 | 280 | } |
823 | 648 | thunkedLinearMap = builder.createPartialApply( |
824 | 648 | loc, linearMapThunkFRI, linearMapSubs, {thunkedLinearMap}, |
825 | 648 | ParameterConvention::Direct_Guaranteed); |
826 | 648 | if (linearMapTargetType != unsubstLinearMapTargetType) { |
827 | 64 | thunkedLinearMap = builder.createConvertFunction( |
828 | 64 | loc, thunkedLinearMap, |
829 | 64 | SILType::getPrimitiveObjectType(linearMapTargetType), |
830 | 64 | /*withoutActuallyEscaping*/ false); |
831 | 64 | } |
832 | 648 | assert(origFnType->getNumAutoDiffSemanticResults() > 0); |
833 | 648 | if (origFnType->getNumResults() > 0 && |
834 | 648 | origFnType->getResults().front().isFormalDirect()) { |
835 | 352 | directResults.push_back(thunkedLinearMap); |
836 | 352 | auto result = joinElements(directResults, builder, loc); |
837 | 352 | builder.createReturn(loc, result); |
838 | 352 | } else { |
839 | 296 | builder.createReturn(loc, thunkedLinearMap); |
840 | 296 | } |
841 | | |
842 | 648 | return {thunk, interfaceSubs}; |
843 | 1.81k | } |
844 | | |
845 | | } // end namespace autodiff |
846 | | } // end namespace swift |