Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/Thunk.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- Thunk.cpp - Automatic differentiation thunks ---------*- C++ -*---===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// Automatic differentiation thunk generation utilities.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#define DEBUG_TYPE "differentiation"
18
19
#include "swift/SILOptimizer/Differentiation/Thunk.h"
20
#include "swift/SILOptimizer/Differentiation/Common.h"
21
22
#include "swift/AST/AnyFunctionRef.h"
23
#include "swift/AST/Requirement.h"
24
#include "swift/AST/SubstitutionMap.h"
25
#include "swift/AST/TypeCheckRequests.h"
26
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
27
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
28
29
namespace swift {
30
namespace autodiff {
31
32
//===----------------------------------------------------------------------===//
33
// Thunk helpers
34
//===----------------------------------------------------------------------===//
35
// These helpers are copied/adapted from SILGen. They should be refactored and
36
// moved to a shared location.
37
//===----------------------------------------------------------------------===//
38
39
/// Builds the SIL function type of a thunk converting a function of type
/// `sourceType` into one of type `expectedType`, by delegating to
/// `buildSILFunctionThunkType`.
///
/// `sourceType`, `expectedType`, `genericEnv` and `interfaceSubs` are handed
/// to that helper by non-const reference, so it may update them. The
/// substituted input/output types and the dynamic `Self` type slots required
/// by the helper are local to this function and are not used afterwards.
CanSILFunctionType buildThunkType(SILFunction *fn,
                                  CanSILFunctionType &sourceType,
                                  CanSILFunctionType &expectedType,
                                  GenericEnvironment *&genericEnv,
                                  SubstitutionMap &interfaceSubs,
                                  bool withoutActuallyEscaping,
                                  DifferentiationThunkKind thunkKind) {
  // Slots required by the shared helper that this caller ignores.
  CanType ignoredInputSubstType, ignoredOutputSubstType, ignoredDynamicSelf;
  CanSILFunctionType thunkType = buildSILFunctionThunkType(
      fn, sourceType, expectedType, ignoredInputSubstType,
      ignoredOutputSubstType, genericEnv, interfaceSubs, ignoredDynamicSelf,
      withoutActuallyEscaping, thunkKind);
  return thunkType;
}
53
54
/// Forward function arguments, handling ownership convention mismatches.
55
/// Adapted from `forwardFunctionArguments` in SILGenPoly.cpp.
56
///
57
/// Forwarded arguments are appended to `forwardedArgs`.
58
///
59
/// Local allocations are appended to `localAllocations`. They need to be
60
/// deallocated via `dealloc_stack`.
61
///
62
/// Local values requiring cleanup are appended to `valuesToCleanup`.
63
static void forwardFunctionArgumentsConvertingOwnership(
64
    SILBuilder &builder, SILLocation loc, CanSILFunctionType fromTy,
65
    CanSILFunctionType toTy, ArrayRef<SILArgument *> originalArgs,
66
    SmallVectorImpl<SILValue> &forwardedArgs,
67
    SmallVectorImpl<AllocStackInst *> &localAllocations,
68
940
    SmallVectorImpl<SILValue> &valuesToCleanup) {
69
940
  auto fromParameters = fromTy->getParameters();
70
940
  auto toParameters = toTy->getParameters();
71
940
  assert(fromParameters.size() == toParameters.size());
72
0
  assert(fromParameters.size() == originalArgs.size());
73
1.04k
  for (auto index : indices(originalArgs)) {
74
1.04k
    auto &arg = originalArgs[index];
75
1.04k
    auto fromParam = fromParameters[index];
76
1.04k
    auto toParam = toParameters[index];
77
    // To convert guaranteed argument to be owned, create a copy.
78
1.04k
    if (fromParam.isConsumed() && !toParam.isConsumed()) {
79
      // If the argument has an object type, create a `copy_value`.
80
28
      if (arg->getType().isObject()) {
81
28
        auto argCopy = builder.emitCopyValueOperation(loc, arg);
82
28
        forwardedArgs.push_back(argCopy);
83
28
        continue;
84
28
      }
85
      // If the argument has an address type, create a local allocation and
86
      // `copy_addr` its contents to the local allocation.
87
0
      auto *alloc = builder.createAllocStack(loc, arg->getType());
88
0
      builder.createCopyAddr(loc, arg, alloc, IsNotTake, IsInitialization);
89
0
      localAllocations.push_back(alloc);
90
0
      forwardedArgs.push_back(alloc);
91
0
      continue;
92
28
    }
93
    // To convert owned argument to be guaranteed, borrow the argument.
94
1.01k
    if (fromParam.isGuaranteed() && !toParam.isGuaranteed()) {
95
268
      auto bbi = builder.emitBeginBorrowOperation(loc, arg);
96
268
      forwardedArgs.push_back(bbi);
97
268
      valuesToCleanup.push_back(bbi);
98
268
      valuesToCleanup.push_back(arg);
99
268
      continue;
100
268
    }
101
    // Otherwise, simply forward the argument.
102
744
    forwardedArgs.push_back(arg);
103
744
  }
104
940
}
105
106
/// Returns a shared, transparent reabstraction thunk that adapts a function of
/// type `fromType` so it can be used as a function of type `toType`. The body
/// is created if the thunk does not have one yet.
///
/// The generated thunk takes the function to call as its last
/// (non-indirect-result) argument. It then:
///   1. forwards the remaining arguments, converting ownership conventions;
///   2. converts between direct and indirect passing of results and
///      parameters wherever `fromType` and `toType` disagree;
///   3. applies the function;
///   4. converts the results back, recursively reabstracting function-typed
///      results whose unsubstituted types differ.
///
/// Both types must be unsubstituted (no combined substitutions). The thunk is
/// looked up by its mangled name, so an existing thunk with a body is returned
/// unchanged.
SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
                                           SILModule &module, SILLocation loc,
                                           SILFunction *caller,
                                           CanSILFunctionType fromType,
                                           CanSILFunctionType toType) {
  assert(!fromType->getCombinedSubstitutions());
  assert(!toType->getCombinedSubstitutions());

  SubstitutionMap interfaceSubs;
  GenericEnvironment *genericEnv = nullptr;
  auto thunkType =
      buildThunkType(caller, fromType, toType, genericEnv, interfaceSubs,
                     /*withoutActuallyEscaping*/ false,
                     DifferentiationThunkKind::Reabstraction);
  // The thunk is declared with an escaping type, whatever `thunkType` says.
  auto thunkDeclType =
      thunkType->getWithExtInfo(thunkType->getExtInfo().withNoEscape(false));

  auto fromInterfaceType = fromType->mapTypeOutOfContext()->getCanonicalType();
  auto toInterfaceType = toType->mapTypeOutOfContext()->getCanonicalType();

  Mangle::ASTMangler mangler;
  std::string name = mangler.mangleReabstractionThunkHelper(
      thunkType, fromInterfaceType, toInterfaceType, Type(), Type(),
      module.getSwiftModule());

  auto *thunk = fb.getOrCreateSharedFunction(
      loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
      ProfileCounter(), IsReabstractionThunk, IsNotDynamic, IsNotDistributed,
      IsNotRuntimeAccessible);
  // A thunk with this mangled name already has a body: reuse it.
  if (!thunk->empty())
    return thunk;

  thunk->setGenericEnvironment(genericEnv);
  auto *entry = thunk->createBasicBlock();
  SILBuilder builder(entry);
  createEntryArguments(thunk);

  SILFunctionConventions fromConv(fromType, module);
  SILFunctionConventions toConv(toType, module);
  assert(toConv.useLoweredAddresses());

  // Forward thunk arguments, handling ownership convention mismatches.
  // `forwardedArgs` ends up laid out as:
  //   [thunk indirect results..., forwarded thunk parameters...]
  // The last thunk argument (the function to call) is excluded via
  // `drop_back()`.
  SmallVector<SILValue, 4> forwardedArgs;
  for (auto indRes : thunk->getIndirectResults())
    forwardedArgs.push_back(indRes);
  SmallVector<AllocStackInst *, 4> localAllocations;
  SmallVector<SILValue, 4> valuesToCleanup;
  forwardFunctionArgumentsConvertingOwnership(
      builder, loc, fromType, toType,
      thunk->getArgumentsWithoutIndirectResults().drop_back(), forwardedArgs,
      localAllocations, valuesToCleanup);

  // Arguments of the final `apply`, built in callee (`fromType`) order.
  // `toArgIter` walks `forwardedArgs` in the same order as the two loops
  // below.
  SmallVector<SILValue, 4> arguments;
  auto toArgIter = forwardedArgs.begin();
  auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };

  // Stack allocations made here are released at the end of the thunk.
  auto createAllocStack = [&](SILType type) {
    auto *alloc = builder.createAllocStack(loc, type);
    localAllocations.push_back(alloc);
    return alloc;
  };

  // Handle indirect results.
  assert(fromType->getNumResults() == toType->getNumResults());
  for (unsigned resIdx : range(toType->getNumResults())) {
    auto fromRes = fromConv.getResults()[resIdx];
    auto toRes = toConv.getResults()[resIdx];
    // No abstraction mismatch.
    if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
      // If result types are indirect, directly pass as next argument.
      if (toRes.isFormalIndirect())
        useNextArgument();
      continue;
    }
    // Convert indirect result to direct result.
    // The callee writes into a temporary buffer; it is loaded after the call.
    if (fromRes.isFormalIndirect()) {
      SILType resultTy =
          fromConv.getSILType(fromRes, builder.getTypeExpansionContext());
      assert(resultTy.isAddress());
      auto *indRes = createAllocStack(resultTy);
      arguments.push_back(indRes);
      continue;
    }
    // Convert direct result to indirect result.
    // Increment thunk argument iterator; reabstraction handled later.
    ++toArgIter;
  }

  // Reabstract parameters.
  assert(toType->getNumParameters() == fromType->getNumParameters());
  for (unsigned paramIdx : range(toType->getNumParameters())) {
    auto fromParam = fromConv.getParameters()[paramIdx];
    auto toParam = toConv.getParameters()[paramIdx];
    // No abstraction mismatch. Directly use next argument.
    if (fromParam.isFormalIndirect() == toParam.isFormalIndirect()) {
      useNextArgument();
      continue;
    }
    // Convert indirect parameter to direct parameter.
    // The thunk received the value directly; spill it into a stack buffer and
    // pass the buffer's address to the callee.
    if (fromParam.isFormalIndirect()) {
      // NOTE(review): `paramTy` is only consumed by the assert below, so it is
      // unused in NDEBUG builds.
      auto paramTy = fromConv.getSILType(fromType->getParameters()[paramIdx],
                                         builder.getTypeExpansionContext());
      if (!paramTy.hasArchetype())
        paramTy = thunk->mapTypeIntoContext(paramTy);
      assert(paramTy.isAddress());
      auto toArg = *toArgIter++;
      auto *buf = createAllocStack(toArg->getType());
      // The `store [init]` below stores a copy of the argument, not the
      // argument itself.
      toArg = builder.emitCopyValueOperation(loc, toArg);
      builder.emitStoreValueOperation(loc, toArg, buf,
                                      StoreOwnershipQualifier::Init);
      // NOTE(review): `buf` is also queued in `valuesToCleanup` and is
      // destroyed after the call. That balances a guaranteed (`@in_guaranteed`)
      // callee parameter. If the callee consumes the parameter (`@in`), it
      // looks like a double destroy -- TODO confirm which conventions reach
      // this path.
      valuesToCleanup.push_back(buf);
      arguments.push_back(buf);
      continue;
    }
    // Convert direct parameter to indirect parameter.
    // Borrow-load the value from the thunk's address argument. Only a real
    // `load_borrow` needs a matching `end_borrow` later.
    assert(toParam.isFormalIndirect());
    auto toArg = *toArgIter++;
    auto load = builder.emitLoadBorrowOperation(loc, toArg);
    if (isa<LoadBorrowInst>(load))
      valuesToCleanup.push_back(load);
    arguments.push_back(load);
  }

  // The function being reabstracted is the thunk's last argument.
  auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();
  auto *apply = builder.createApply(loc, fnArg, SubstitutionMap(), arguments);

  // Get return elements.
  SmallVector<SILValue, 4> results;
  // Extract all direct results.
  SmallVector<SILValue, 4> directResults;
  extractAllElements(apply, builder, directResults);

  auto fromDirResultsIter = directResults.begin();
  auto fromIndResultsIter = apply->getIndirectSILResults().begin();
  auto toIndResultsIter = thunk->getIndirectResults().begin();
  // Reabstract results.
  for (unsigned resIdx : range(toType->getNumResults())) {
    auto fromRes = fromConv.getResults()[resIdx];
    auto toRes = toConv.getResults()[resIdx];
    // Check function-typed results.
    if (isa<SILFunctionType>(fromRes.getInterfaceType()) &&
        isa<SILFunctionType>(toRes.getInterfaceType())) {
      auto fromFnType = cast<SILFunctionType>(fromRes.getInterfaceType());
      auto toFnType = cast<SILFunctionType>(toRes.getInterfaceType());
      auto fromUnsubstFnType = fromFnType->getUnsubstitutedType(module);
      auto toUnsubstFnType = toFnType->getUnsubstitutedType(module);
      // If unsubstituted function types are not equal, perform reabstraction.
      if (fromUnsubstFnType != toUnsubstFnType) {
        auto fromFn = *fromDirResultsIter++;
        auto newFromFn = reabstractFunction(
            builder, fb, loc, fromFn, toFnType,
            [](SubstitutionMap substMap) { return substMap; });
        results.push_back(newFromFn);
        continue;
      }
    }
    // No abstraction mismatch.
    if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
      // If result types are direct, add call result as direct thunk result.
      if (toRes.isFormalDirect())
        results.push_back(*fromDirResultsIter++);
      // If result types are indirect, increment indirect result iterators.
      else {
        ++fromIndResultsIter;
        ++toIndResultsIter;
      }
      continue;
    }
    // Load direct results from indirect results.
    if (fromRes.isFormalIndirect()) {
      auto indRes = *fromIndResultsIter++;
      auto load = builder.emitLoadValueOperation(loc, indRes,
                                                 LoadOwnershipQualifier::Take);
      results.push_back(load);
      continue;
    }
    // Store direct results to indirect results.
    assert(toRes.isFormalIndirect());
#ifndef NDEBUG
    SILType resultTy =
        toConv.getSILType(toRes, builder.getTypeExpansionContext());
    assert(resultTy.isAddress());
#endif
    auto indRes = *toIndResultsIter++;
    auto dirRes = *fromDirResultsIter++;
    builder.emitStoreValueOperation(loc, dirRes, indRes,
                                    StoreOwnershipQualifier::Init);
  }
  auto retVal = joinElements(results, builder, loc);

  // Clean up local values.
  // Guaranteed values need an `end_borrow`.
  // Owned values need to be destroyed.
  // Values are processed in insertion order.
  for (auto arg : valuesToCleanup) {
    switch (arg->getOwnershipKind()) {
    case OwnershipKind::Any:
      llvm_unreachable("value with any ownership kind?!");
    case OwnershipKind::Guaranteed:
      builder.emitEndBorrowOperation(loc, arg);
      break;
    case OwnershipKind::Owned:
    case OwnershipKind::Unowned:
    case OwnershipKind::None:
      builder.emitDestroyOperation(loc, arg);
      break;
    }
  }

  // Deallocate local allocations.
  // Deallocation runs in reverse allocation order.
  for (auto *alloc : llvm::reverse(localAllocations))
    builder.createDeallocStack(loc, alloc);

  // Create return.
  builder.createReturn(loc, retVal);

  LLVM_DEBUG(auto &s = getADDebugStream() << "Created reabstraction thunk.\n";
             s << "  From type: " << fromType << '\n';
             s << "  To type: " << toType << '\n'; s << '\n'
                                                     << *thunk);

  return thunk;
}
328
329
/// Reabstracts the function value `fn` to the function type `toType` and
/// returns the resulting value.
///
/// A reabstraction thunk is created (or reused) for the unsubstituted forms of
/// `fn`'s type and `toType`, and is partially applied to `fn`:
/// - If `fn`'s type is substituted, `fn` is first `convert_function`ed to its
///   unsubstituted form.
/// - If `toType` is substituted, the `partial_apply` is `convert_function`ed
///   to `toType` afterwards.
///
/// `remapSubstitutions` maps the thunk's forwarding substitution map to the
/// substitutions used by the `partial_apply`.
SILValue reabstractFunction(
    SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
    SILValue fn, CanSILFunctionType toType,
    std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions) {
  auto &silModule = *fn->getModule();
  auto sourceFnTy = fn->getType().getAs<SILFunctionType>();
  auto plainSourceFnTy = sourceFnTy->getUnsubstitutedType(silModule);
  auto plainTargetFnTy = toType->getUnsubstitutedType(silModule);

  SILFunction *thunk = getOrCreateReabstractionThunk(
      fb, silModule, loc, /*caller*/ fn->getFunction(), plainSourceFnTy,
      plainTargetFnTy);
  auto *calleeRef = builder.createFunctionRef(loc, thunk);

  // The thunk expects the source function in its unsubstituted form.
  SILValue source = fn;
  if (sourceFnTy != plainSourceFnTy) {
    source = builder.createConvertFunction(
        loc, source, SILType::getPrimitiveObjectType(plainSourceFnTy),
        /*withoutActuallyEscaping*/ false);
  }

  auto thunkSubs = remapSubstitutions(thunk->getForwardingSubstitutionMap());
  SILValue reabstracted = builder.createPartialApply(
      loc, calleeRef, thunkSubs, {source}, sourceFnTy->getCalleeConvention());

  // Nothing more to do when the requested type is already unsubstituted.
  if (toType == plainTargetFnTy)
    return reabstracted;

  return builder.createConvertFunction(loc, reabstracted,
                                       SILType::getPrimitiveObjectType(toType),
                                       /*withoutActuallyEscaping*/ false);
}
359
360
std::pair<SILFunction *, SubstitutionMap>
361
getOrCreateSubsetParametersThunkForLinearMap(
362
    SILOptFunctionBuilder &fb, SILFunction *parentThunk,
363
    CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
364
    CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
365
    const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig,
366
648
    ADContext &adContext) {
367
648
  LLVM_DEBUG(getADDebugStream()
368
648
             << "Getting a subset parameters thunk for "
369
648
             << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp")
370
648
             << " linear map " << linearMapType
371
648
             << " from " << actualConfig << " to " << desiredConfig << '\n');
372
373
648
  assert(!linearMapType->getCombinedSubstitutions());
374
0
  assert(!targetType->getCombinedSubstitutions());
375
0
  SubstitutionMap interfaceSubs;
376
648
  GenericEnvironment *genericEnv = nullptr;
377
648
  auto thunkType = buildThunkType(parentThunk, linearMapType, targetType,
378
648
                                  genericEnv, interfaceSubs,
379
648
                                  /*withoutActuallyEscaping*/ true,
380
648
                                  DifferentiationThunkKind::Reabstraction);
381
382
648
  Mangle::DifferentiationMangler mangler;
383
648
  auto fromInterfaceType =
384
648
      linearMapType->mapTypeOutOfContext()->getCanonicalType();
385
386
648
  auto thunkName = mangler.mangleLinearMapSubsetParametersThunk(
387
648
      fromInterfaceType, kind.getLinearMapKind(),
388
648
      actualConfig.parameterIndices, actualConfig.resultIndices,
389
648
      desiredConfig.parameterIndices);
390
391
648
  auto loc = parentThunk->getLocation();
392
648
  auto *thunk = fb.getOrCreateSharedFunction(
393
648
      loc, thunkName, thunkType, IsBare, IsTransparent, IsSerialized,
394
648
      ProfileCounter(), IsThunk, IsNotDynamic, IsNotDistributed,
395
648
      IsNotRuntimeAccessible);
396
397
648
  if (!thunk->empty())
398
56
    return {thunk, interfaceSubs};
399
400
592
  thunk->setGenericEnvironment(genericEnv);
401
592
  auto *entry = thunk->createBasicBlock();
402
592
  TangentBuilder builder(entry, adContext);
403
592
  createEntryArguments(thunk);
404
405
  // Get arguments.
406
592
  SmallVector<SILValue, 4> arguments;
407
592
  SmallVector<AllocStackInst *, 4> localAllocations;
408
592
  SmallVector<SILValue, 4> valuesToCleanup;
409
592
  auto cleanupValues = [&]() {
410
592
    for (auto value : llvm::reverse(valuesToCleanup))
411
156
      builder.emitDestroyOperation(loc, value);
412
413
592
    for (auto *alloc : llvm::reverse(localAllocations))
414
456
      builder.createDeallocStack(loc, alloc);
415
592
  };
416
417
  // Build a `.zero` argument for the given `Differentiable`-conforming type.
418
592
  auto buildZeroArgument = [&](SILParameterInfo zeroSILParameter) {
419
308
    auto zeroSILType = zeroSILParameter.getSILStorageInterfaceType();
420
308
    auto zeroSILObjType = zeroSILType.getObjectType();
421
308
    auto zeroType = zeroSILType.getASTType();
422
308
    auto *swiftMod = parentThunk->getModule().getSwiftModule();
423
308
    auto tangentSpace =
424
308
        zeroType->getAutoDiffTangentSpace(LookUpConformanceInModule(swiftMod));
425
308
    assert(tangentSpace && "No tangent space for this type");
426
0
    switch (tangentSpace->getKind()) {
427
308
    case TangentSpace::Kind::TangentVector: {
428
308
      auto *buf = builder.createAllocStack(loc, zeroSILObjType);
429
308
      localAllocations.push_back(buf);
430
308
      builder.emitZeroIntoBuffer(loc, buf, IsInitialization);
431
308
      if (zeroSILType.isAddress()) {
432
148
        arguments.push_back(buf);
433
148
        if (zeroSILParameter.isGuaranteed()) {
434
144
          valuesToCleanup.push_back(buf);
435
144
        }
436
160
      } else {
437
160
        auto arg = builder.emitLoadValueOperation(loc, buf,
438
160
                                                  LoadOwnershipQualifier::Take);
439
160
        arguments.push_back(arg);
440
160
        if (zeroSILParameter.isGuaranteed()) {
441
12
          valuesToCleanup.push_back(arg);
442
12
        }
443
160
      }
444
308
      break;
445
0
    }
446
0
    case TangentSpace::Kind::Tuple: {
447
0
      llvm_unreachable("Unimplemented: Handle zero initialization for tuples");
448
0
    }
449
308
    }
450
308
  };
451
452
  // The indices in `actualConfig` and `desiredConfig` are with respect to the
453
  // original function. However, the differential parameters and pullback
454
  // results may already be w.r.t. a subset. We create a map between the
455
  // original function's actual parameter indices and the linear map's actual
456
  // indices.
457
  // Example:
458
  //   Original: (T0, T1, T2) -> R
459
  //   Actual indices: 0, 2
460
  //   Original differential: (T0, T2) -> R
461
  //   Original pullback: R -> (T0, T2)
462
  //   Desired indices w.r.t. original: 2
463
  //   Desired indices w.r.t. linear map: 1
464
592
  SmallVector<unsigned, 4> actualParamIndicesMap(
465
592
      actualConfig.parameterIndices->getCapacity(), UINT_MAX);
466
592
  {
467
592
    unsigned indexInBitVec = 0;
468
1.28k
    for (auto index : actualConfig.parameterIndices->getIndices()) {
469
1.28k
      actualParamIndicesMap[index] = indexInBitVec;
470
1.28k
      ++indexInBitVec;
471
1.28k
    }
472
592
  }
473
980
  auto mapOriginalParameterIndex = [&](unsigned index) -> unsigned {
474
980
    auto mappedIndex = actualParamIndicesMap[index];
475
980
    assert(mappedIndex < actualConfig.parameterIndices->getCapacity());
476
0
    return mappedIndex;
477
980
  };
478
479
592
  auto toIndirectResultsIter = thunk->getIndirectResults().begin();
480
592
  auto useNextIndirectResult = [&]() {
481
248
      assert(toIndirectResultsIter != thunk->getIndirectResults().end());
482
0
      arguments.push_back(*toIndirectResultsIter++);
483
248
  };
484
485
592
  switch (kind) {
486
  // Differential arguments are:
487
  // - All indirect results, followed by:
488
  // - An interleaving of:
489
  //   - Thunk arguments (when parameter index is in both desired and actual
490
  //     indices).
491
  //   - Zeros (when parameter is not in desired indices).
492
296
  case AutoDiffDerivativeFunctionKind::JVP: {
493
296
    unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults();
494
    // Forward desired indirect results
495
296
    for (unsigned idx : *actualConfig.resultIndices) {
496
296
      if (idx >= numIndirectResults)
497
184
        break;
498
499
112
      auto resultInfo = linearMapType->getResults()[idx];
500
112
      assert(idx < linearMapType->getNumResults());
501
502
      // Forward result argument in case we do not need to thunk it away
503
112
      if (desiredConfig.resultIndices->contains(idx)) {
504
112
        useNextIndirectResult();
505
112
        continue;
506
112
      }
507
508
      // Otherwise, allocate and use an uninitialized indirect result
509
0
      auto *indirectResult = builder.createAllocStack(
510
0
        loc, resultInfo.getSILStorageInterfaceType());
511
0
      localAllocations.push_back(indirectResult);
512
0
      arguments.push_back(indirectResult);
513
0
    }
514
296
    assert(toIndirectResultsIter == thunk->getIndirectResults().end());
515
516
0
    auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin();
517
340
    auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
518
    // Iterate over actual indices.
519
644
    for (unsigned i : actualConfig.parameterIndices->getIndices()) {
520
      // If index is desired, use next argument.
521
644
      if (desiredConfig.isWrtParameter(i)) {
522
340
        useNextArgument();
523
340
      }
524
      // Otherwise, construct and use a zero argument.
525
304
      else {
526
304
        auto zeroSILParameter =
527
304
            linearMapType->getParameters()[mapOriginalParameterIndex(i)];
528
304
        buildZeroArgument(zeroSILParameter);
529
304
      }
530
644
    }
531
296
    break;
532
0
  }
533
  // Pullback arguments are:
534
  // - An interleaving of:
535
  //   - Thunk indirect results (when parameter index is in both desired and
536
  //     actual indices).
537
  //   - Zeros (when parameter is not in desired indices).
538
  // - All actual arguments.
539
296
  case AutoDiffDerivativeFunctionKind::VJP: {
540
    // Collect pullback arguments.
541
296
    unsigned pullbackResultIndex = 0;
542
644
    for (unsigned i : actualConfig.parameterIndices->getIndices()) {
543
644
      auto origParamInfo = origFnType->getParameters()[i];
544
      // Skip original semantic result parameters. All non-indirect-result pullback
545
      // arguments (including semantic result` arguments) are appended to `arguments`
546
      // later.
547
644
      if (origParamInfo.isAutoDiffSemanticResult())
548
32
        continue;
549
612
      auto resultInfo = linearMapType->getResults()[pullbackResultIndex];
550
612
      assert(pullbackResultIndex < linearMapType->getNumResults());
551
0
      ++pullbackResultIndex;
552
      // Skip pullback direct results. Only indirect results are relevant as
553
      // arguments.
554
612
      if (resultInfo.isFormalDirect())
555
328
        continue;
556
      // If index is desired, use next pullback indirect result.
557
284
      if (desiredConfig.isWrtParameter(i)) {
558
136
        useNextIndirectResult();
559
136
        continue;
560
136
      }
561
      // Otherwise, allocate and use an uninitialized pullback indirect result.
562
148
      auto *indirectResult = builder.createAllocStack(
563
148
          loc, resultInfo.getSILStorageInterfaceType());
564
148
      localAllocations.push_back(indirectResult);
565
148
      arguments.push_back(indirectResult);
566
148
    }
567
    // Forward all actual non-indirect-result arguments.
568
296
    auto thunkArgs = thunk->getArgumentsWithoutIndirectResults();
569
    // Slice out the function to be called
570
296
    thunkArgs = thunkArgs.slice(0, thunkArgs.size() - 1);
571
296
    unsigned thunkArg = 0;
572
300
    for (unsigned idx : *actualConfig.resultIndices) {
573
      // Forward result argument in case we do not need to thunk it away
574
300
      if (desiredConfig.resultIndices->contains(idx))
575
296
        arguments.push_back(thunkArgs[thunkArg++]);
576
4
      else // otherwise, zero it out
577
4
        buildZeroArgument(linearMapType->getParameters()[arguments.size()]);
578
300
    }
579
296
    break;
580
0
  }
581
592
  }
582
583
  // Get the linear map thunk argument and apply it.
584
592
  auto *linearMap = thunk->getArguments().back();
585
592
  auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments);
586
587
  // If differential thunk, deallocate local allocations and directly return
588
  // `apply` result (if it is desired).
589
592
  if (kind == AutoDiffDerivativeFunctionKind::JVP) {
590
296
    SmallVector<SILValue, 8> differentialDirectResults;
591
296
    extractAllElements(ai, builder, differentialDirectResults);
592
296
    SmallVector<SILValue, 8> allResults;
593
296
    collectAllActualResultsInTypeOrder(ai, differentialDirectResults, allResults);
594
296
    unsigned numResults = thunk->getConventions().getNumDirectSILResults() +
595
296
     thunk->getConventions().getNumDirectSILResults();
596
296
    SmallVector<SILValue, 8> results;
597
300
    for (unsigned idx : *actualConfig.resultIndices) {
598
300
      if (idx >= numResults)
599
144
        break;
600
601
156
      auto result = allResults[idx];
602
156
      if (desiredConfig.isWrtResult(idx))
603
152
        results.push_back(result);
604
4
      else {
605
4
        if (result->getType().isAddress())
606
0
          builder.emitDestroyAddrAndFold(loc, result);
607
4
        else
608
4
          builder.emitDestroyValueOperation(loc, result);
609
4
      }
610
156
    }
611
612
296
    cleanupValues();
613
296
    auto result = joinElements(results, builder, loc);
614
296
    builder.createReturn(loc, result);
615
296
    return {thunk, interfaceSubs};
616
296
  }
617
618
  // If pullback thunk, return only the desired results and clean up the
619
  // undesired results.
620
296
  SmallVector<SILValue, 8> pullbackDirectResults;
621
296
  extractAllElements(ai, builder, pullbackDirectResults);
622
296
  SmallVector<SILValue, 8> allResults;
623
296
  collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults);
624
  // Collect pullback semantic result arguments in type order.
625
296
  unsigned semanticResultArgIdx = 0;
626
296
  SILFunctionConventions origConv(origFnType, thunk->getModule());
627
644
  for (auto paramIdx : actualConfig.parameterIndices->getIndices()) {
628
644
    auto paramInfo = origConv.getParameters()[paramIdx];
629
644
    if (!paramInfo.isAutoDiffSemanticResult())
630
612
      continue;
631
32
    auto semanticResultArg =
632
32
      *std::next(ai->getAutoDiffSemanticResultArguments().begin(),
633
32
                 semanticResultArgIdx++);
634
32
    unsigned mappedParamIdx = mapOriginalParameterIndex(paramIdx);
635
32
    allResults.insert(allResults.begin() + mappedParamIdx, semanticResultArg);
636
32
  }
637
296
  assert(allResults.size() == actualConfig.parameterIndices->getNumIndices() &&
638
296
         "Number of pullback results should match number of differentiability "
639
296
         "parameters");
640
641
0
  SmallVector<SILValue, 8> results;
642
644
  for (unsigned i : actualConfig.parameterIndices->getIndices()) {
643
644
    unsigned mappedIndex = mapOriginalParameterIndex(i);
644
    // If result is desired:
645
    // - Do nothing if result is indirect.
646
    //   (It was already forwarded to the `apply` instruction).
647
    // - Push it to `results` if result is direct.
648
644
    auto result = allResults[mappedIndex];
649
644
    if (desiredConfig.isWrtParameter(i)) {
650
340
      if (result->getType().isObject())
651
172
        results.push_back(result);
652
340
    }
653
    // Otherwise, cleanup the unused results.
654
304
    else {
655
304
      if (result->getType().isAddress())
656
148
        builder.emitDestroyAddrAndFold(loc, result);
657
156
      else
658
156
        builder.emitDestroyValueOperation(loc, result);
659
304
    }
660
644
  }
661
  // Deallocate local allocations and return final direct result.
662
296
  cleanupValues();
663
296
  auto result = joinElements(results, builder, loc);
664
296
  builder.createReturn(loc, result);
665
666
296
  return {thunk, interfaceSubs};
667
592
}
668
669
std::pair<SILFunction *, SubstitutionMap>
670
getOrCreateSubsetParametersThunkForDerivativeFunction(
671
    SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
672
    AutoDiffDerivativeFunctionKind kind, const AutoDiffConfig &desiredConfig,
673
1.81k
    const AutoDiffConfig &actualConfig, ADContext &adContext) {
674
1.81k
  LLVM_DEBUG(getADDebugStream()
675
1.81k
             << "Getting a subset parameters thunk for derivative "
676
1.81k
             << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp")
677
1.81k
             << " function " << derivativeFn
678
1.81k
             << " of the original function " << origFnOperand
679
1.81k
             << " from " << actualConfig << " to " << desiredConfig << '\n');
680
681
1.81k
  auto origFnType = origFnOperand->getType().castTo<SILFunctionType>();
682
1.81k
  auto &module = fb.getModule();
683
1.81k
  auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
684
685
  // Compute target type for thunking.
686
1.81k
  auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();
687
1.81k
  auto targetType = origFnType->getAutoDiffDerivativeFunctionType(
688
1.81k
      desiredConfig.parameterIndices, desiredConfig.resultIndices, kind,
689
1.81k
      module.Types, lookupConformance);
690
1.81k
  auto *caller = derivativeFn->getFunction();
691
1.81k
  if (targetType->hasArchetype()) {
692
96
    auto substTargetType =
693
96
        caller->mapTypeIntoContext(targetType->mapTypeOutOfContext())
694
96
            ->getCanonicalType();
695
96
    targetType = SILType::getPrimitiveObjectType(substTargetType)
696
96
                     .castTo<SILFunctionType>();
697
96
  }
698
1.81k
  assert(derivativeFnType->getNumParameters() ==
699
1.81k
         targetType->getNumParameters());
700
0
  assert(derivativeFnType->getNumResults() == targetType->getNumResults());
701
702
  // Build thunk type.
703
0
  SubstitutionMap interfaceSubs;
704
1.81k
  GenericEnvironment *genericEnv = nullptr;
705
1.81k
  auto thunkType = buildThunkType(derivativeFn->getFunction(), derivativeFnType,
706
1.81k
                                  targetType, genericEnv, interfaceSubs,
707
1.81k
                                  /*withoutActuallyEscaping*/ false,
708
1.81k
                                  DifferentiationThunkKind::IndexSubset);
709
710
  // FIXME: The logic for resolving `assocRef` does not reapply function
711
  // conversions, which is problematic if `derivativeFn` is a `partial_apply`
712
  // instruction.
713
1.81k
  StringRef origName;
714
1.81k
  if (auto *origFnRef =
715
1.81k
          peerThroughFunctionConversions<FunctionRefInst>(origFnOperand)) {
716
1.76k
    origName = origFnRef->getReferencedFunction()->getName();
717
1.76k
  } else if (auto *origMethodInst =
718
48
                 peerThroughFunctionConversions<MethodInst>(origFnOperand)) {
719
48
    origName = origMethodInst->getMember()
720
48
                   .getAnyFunctionRef()
721
48
                   ->getAbstractFunctionDecl()
722
48
                   ->getNameStr();
723
48
  }
724
1.81k
  assert(!origName.empty() && "Original function name could not be resolved");
725
0
  Mangle::DifferentiationMangler mangler;
726
1.81k
  auto thunkName = mangler.mangleDerivativeFunctionSubsetParametersThunk(
727
1.81k
      origName, targetType->mapTypeOutOfContext()->getCanonicalType(),
728
1.81k
      kind, actualConfig.parameterIndices, actualConfig.resultIndices,
729
1.81k
      desiredConfig.parameterIndices);
730
731
1.81k
  auto loc = origFnOperand.getLoc();
732
1.81k
  auto *thunk = fb.getOrCreateSharedFunction(
733
1.81k
      loc, thunkName, thunkType, IsBare, IsTransparent, caller->isSerialized(),
734
1.81k
      ProfileCounter(), IsThunk, IsNotDynamic, IsNotDistributed,
735
1.81k
      IsNotRuntimeAccessible);
736
737
1.81k
  if (!thunk->empty())
738
1.16k
    return {thunk, interfaceSubs};
739
740
648
  thunk->setGenericEnvironment(genericEnv);
741
648
  auto *entry = thunk->createBasicBlock();
742
648
  SILBuilder builder(entry);
743
648
  createEntryArguments(thunk);
744
745
648
  SubstitutionMap assocSubstMap;
746
648
  if (auto *partialApply = dyn_cast<PartialApplyInst>(derivativeFn))
747
280
    assocSubstMap = partialApply->getSubstitutionMap();
748
749
  // FIXME: The logic for resolving `assocRef` does not reapply function
750
  // conversions, which is problematic if `derivativeFn` is a `partial_apply`
751
  // instruction.
752
648
  SILValue assocRef;
753
648
  if (auto *derivativeFnRef =
754
648
          peerThroughFunctionConversions<FunctionRefInst>(derivativeFn)) {
755
0
    auto *assoc = derivativeFnRef->getReferencedFunction();
756
0
    assocRef = builder.createFunctionRef(loc, assoc);
757
648
  } else if (auto *assocMethodInst =
758
648
                 peerThroughFunctionConversions<WitnessMethodInst>(
759
648
                     derivativeFn)) {
760
24
    assocRef = builder.createWitnessMethod(
761
24
        loc, assocMethodInst->getLookupType(),
762
24
        assocMethodInst->getConformance(), assocMethodInst->getMember(),
763
24
        thunk->mapTypeIntoContext(assocMethodInst->getType()));
764
624
  } else if (auto *assocMethodInst =
765
624
                 peerThroughFunctionConversions<ClassMethodInst>(
766
624
                     derivativeFn)) {
767
8
    auto classOperand = thunk->getArgumentsWithoutIndirectResults().back();
768
8
#ifndef NDEBUG
769
8
    auto classOperandType = assocMethodInst->getOperand()->getType();
770
8
    assert(classOperand->getType() == classOperandType);
771
0
#endif
772
0
    assocRef = builder.createClassMethod(
773
8
        loc, classOperand, assocMethodInst->getMember(),
774
8
        thunk->mapTypeIntoContext(assocMethodInst->getType()));
775
616
  } else if (auto *diffWitFn = peerThroughFunctionConversions<
776
616
                 DifferentiabilityWitnessFunctionInst>(derivativeFn)) {
777
616
    assocRef = builder.createDifferentiabilityWitnessFunction(
778
616
        loc, diffWitFn->getWitnessKind(), diffWitFn->getWitness());
779
616
  }
780
0
  assert(assocRef && "Expected derivative function to be resolved");
781
782
0
  assocSubstMap = assocSubstMap.subst(thunk->getForwardingSubstitutionMap());
783
648
  derivativeFnType = assocRef->getType().castTo<SILFunctionType>();
784
785
648
  SmallVector<SILValue, 4> arguments;
786
648
  arguments.append(thunk->getArguments().begin(), thunk->getArguments().end());
787
648
  assert(arguments.size() ==
788
648
         derivativeFnType->getNumParameters() +
789
648
             derivativeFnType->getNumIndirectFormalResults());
790
0
  auto *apply = builder.createApply(loc, assocRef, assocSubstMap, arguments);
791
792
  // Extract all direct results.
793
648
  SmallVector<SILValue, 8> directResults;
794
648
  extractAllElements(apply, builder, directResults);
795
648
  auto linearMap = directResults.back();
796
648
  directResults.pop_back();
797
798
648
  auto linearMapType = linearMap->getType().castTo<SILFunctionType>();
799
648
  auto linearMapTargetType = targetType->getResults()
800
648
                                 .back()
801
648
                                 .getSILStorageInterfaceType()
802
648
                                 .castTo<SILFunctionType>();
803
648
  auto unsubstLinearMapType = linearMapType->getUnsubstitutedType(module);
804
648
  auto unsubstLinearMapTargetType =
805
648
      linearMapTargetType->getUnsubstitutedType(module);
806
807
648
  SILFunction *linearMapThunk;
808
648
  SubstitutionMap linearMapSubs;
809
648
  std::tie(linearMapThunk, linearMapSubs) =
810
648
      getOrCreateSubsetParametersThunkForLinearMap(
811
648
          fb, thunk, origFnType, unsubstLinearMapType,
812
648
          unsubstLinearMapTargetType, kind, desiredConfig, actualConfig,
813
648
          adContext);
814
815
648
  auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk);
816
648
  SILValue thunkedLinearMap = linearMap;
817
648
  if (linearMapType != unsubstLinearMapType) {
818
280
    thunkedLinearMap = builder.createConvertFunction(
819
280
        loc, thunkedLinearMap,
820
280
        SILType::getPrimitiveObjectType(unsubstLinearMapType),
821
280
        /*withoutActuallyEscaping*/ false);
822
280
  }
823
648
  thunkedLinearMap = builder.createPartialApply(
824
648
      loc, linearMapThunkFRI, linearMapSubs, {thunkedLinearMap},
825
648
      ParameterConvention::Direct_Guaranteed);
826
648
  if (linearMapTargetType != unsubstLinearMapTargetType) {
827
64
    thunkedLinearMap = builder.createConvertFunction(
828
64
        loc, thunkedLinearMap,
829
64
        SILType::getPrimitiveObjectType(linearMapTargetType),
830
64
        /*withoutActuallyEscaping*/ false);
831
64
  }
832
648
  assert(origFnType->getNumAutoDiffSemanticResults() > 0);
833
648
  if (origFnType->getNumResults() > 0 &&
834
648
      origFnType->getResults().front().isFormalDirect()) {
835
352
    directResults.push_back(thunkedLinearMap);
836
352
    auto result = joinElements(directResults, builder, loc);
837
352
    builder.createReturn(loc, result);
838
352
  } else {
839
296
    builder.createReturn(loc, thunkedLinearMap);
840
296
  }
841
842
648
  return {thunk, interfaceSubs};
843
1.81k
}
844
845
} // end namespace autodiff
846
} // end namespace swift