Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/VJPCloner.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- VJPCloner.cpp - VJP function generation --------------*- C++ -*---===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// This file defines a helper class for generating VJP functions for automatic
14
// differentiation.
15
//
16
//===----------------------------------------------------------------------===//
17
18
#define DEBUG_TYPE "differentiation"
19
20
#include "swift/AST/Types.h"
21
22
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
23
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
24
#include "swift/SILOptimizer/Differentiation/ADContext.h"
25
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
26
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
27
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
28
#include "swift/SILOptimizer/Differentiation/Thunk.h"
29
30
#include "swift/SIL/TerminatorUtils.h"
31
#include "swift/SIL/TypeSubstCloner.h"
32
#include "swift/SILOptimizer/Analysis/LoopAnalysis.h"
33
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
34
#include "swift/SILOptimizer/Utils/CFGOptUtils.h"
35
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
36
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
37
#include "llvm/ADT/DenseMap.h"
38
39
namespace swift {
40
namespace autodiff {
41
42
class VJPCloner::Implementation final
43
    : public TypeSubstCloner<VJPCloner::Implementation, SILOptFunctionBuilder> {
44
  friend class VJPCloner;
45
  friend class PullbackCloner;
46
47
  /// The parent VJP cloner.
48
  VJPCloner &cloner;
49
50
  /// The global context.
51
  ADContext &context;
52
53
  /// The original function.
54
  SILFunction *const original;
55
56
  /// The differentiability witness.
57
  SILDifferentiabilityWitness *const witness;
58
59
  /// The VJP function.
60
  SILFunction *const vjp;
61
62
  /// The pullback function.
63
  SILFunction *pullback;
64
65
  /// The differentiation invoker.
66
  DifferentiationInvoker invoker;
67
68
  /// Info from activity analysis on the original function.
69
  const DifferentiableActivityInfo &activityInfo;
70
71
  /// The loop info.
72
  SILLoopInfo *loopInfo;
73
74
  /// The linear map info.
75
  LinearMapInfo pullbackInfo;
76
77
  /// Caches basic blocks whose phi arguments have been remapped (adding a
78
  /// predecessor enum argument).
79
  SmallPtrSet<SILBasicBlock *, 4> remappedBasicBlocks;
80
81
  /// The `AutoDiffLinearMapContext` object. If null, no explicit context is
82
  /// needed (no loops).
83
  SILValue pullbackContextValue;
84
  /// The unique, borrowed context object. This is valid until the exit block.
85
  SILValue borrowedPullbackContextValue;
86
87
  /// The generic signature of the `Builtin.autoDiffAllocateSubcontext(_:_:)`
88
  /// declaration. It is used for creating a builtin call.
89
  GenericSignature builtinAutoDiffAllocateSubcontextGenericSignature;
90
91
  bool errorOccurred = false;
92
93
  /// Mapping from original blocks to pullback values. Used to build pullback
94
  /// struct instances.
95
  llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> pullbackValues;
96
97
12.5k
  ASTContext &getASTContext() const { return vjp->getASTContext(); }
98
28.4k
  SILModule &getModule() const { return vjp->getModule(); }
99
179k
  const AutoDiffConfig &getConfig() const {
100
179k
    return witness->getConfig();
101
179k
  }
102
103
  Implementation(VJPCloner &parent, ADContext &context,
104
                 SILDifferentiabilityWitness *witness, SILFunction *vjp,
105
                 DifferentiationInvoker invoker);
106
107
  /// Creates an empty pullback function, to be filled in by `PullbackCloner`.
108
  SILFunction *createEmptyPullback();
109
110
  /// Run VJP generation. Returns true on error.
111
  bool run();
112
113
  /// Initializes a context object if needed.
114
5.25k
  void emitLinearMapContextInitializationIfNeeded() {
115
5.25k
    if (!pullbackInfo.hasHeapAllocatedContext())
116
5.14k
      return;
117
 
118
    // Get linear map struct size.
119
108
    auto *returnBB = &*original->findReturnBB();
120
108
    auto pullbackTupleType =
121
108
      remapASTType(pullbackInfo.getLinearMapTupleType(returnBB)->getCanonicalType());
122
108
    Builder.setInsertionPoint(vjp->getEntryBlock());
123
124
108
    auto pbTupleMetatypeType =
125
108
        CanMetatypeType::get(pullbackTupleType, MetatypeRepresentation::Thick);
126
108
    auto pbTupleMetatypeSILType =
127
108
        SILType::getPrimitiveObjectType(pbTupleMetatypeType);
128
108
    auto pbTupleMetatype =
129
108
        Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType);
130
131
    // Create an context.
132
108
    pullbackContextValue = Builder.createBuiltin(
133
108
        original->getLocation(),
134
108
        getASTContext().getIdentifier(getBuiltinName(
135
108
            BuiltinValueKind::AutoDiffCreateLinearMapContextWithType)),
136
108
        SILType::getNativeObjectType(getASTContext()), SubstitutionMap(),
137
108
        {pbTupleMetatype});
138
108
    borrowedPullbackContextValue = Builder.createBeginBorrow(
139
108
        original->getLocation(), pullbackContextValue);
140
108
    LLVM_DEBUG(getADDebugStream()
141
108
               << "Context object initialized because there are loops\n"
142
108
               << *vjp->getEntryBlock() << '\n'
143
108
               << "pullback tuple type: " << pullbackTupleType << '\n');
144
108
  }
145
146
  /// Get the lowered SIL type of the given AST type.
147
5.71k
  SILType getLoweredType(Type type) {
148
5.71k
    auto vjpGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature();
149
5.71k
    Lowering::AbstractionPattern pattern(vjpGenSig,
150
5.71k
                                         type->getReducedType(vjpGenSig));
151
5.71k
    return vjp->getLoweredType(pattern, type);
152
5.71k
  }
153
154
0
  GenericSignature getBuiltinAutoDiffAllocateSubcontextDecl() {
155
0
    if (builtinAutoDiffAllocateSubcontextGenericSignature)
156
0
      return builtinAutoDiffAllocateSubcontextGenericSignature;
157
0
    auto &ctx = getASTContext();
158
0
    auto *decl = cast<FuncDecl>(getBuiltinValueDecl(
159
0
        ctx, ctx.getIdentifier(getBuiltinName(
160
0
                 BuiltinValueKind::AutoDiffAllocateSubcontextWithType))));
161
0
    builtinAutoDiffAllocateSubcontextGenericSignature =
162
0
        decl->getGenericSignature();
163
0
    assert(builtinAutoDiffAllocateSubcontextGenericSignature);
164
0
    return builtinAutoDiffAllocateSubcontextGenericSignature;
165
0
  }
166
167
  // Creates a trampoline block for given original terminator instruction, the
168
  // pullback struct value for its parent block, and a successor basic block.
169
  //
170
  // The trampoline block has the same arguments as and branches to the remapped
171
  // successor block, but drops the last predecessor enum argument.
172
  //
173
  // Used for cloning branching terminator instructions with specific
174
  // requirements on successor block arguments, where an additional predecessor
175
  // enum argument is not acceptable.
176
  SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst,
177
                                            TupleInst *pbTupleVal,
178
                                            SILBasicBlock *succBB);
179
180
  /// Build a pullback tuple value for the given original terminator
181
  /// instruction.
182
  TupleInst *buildPullbackValueTupleValue(TermInst *termInst);
183
  llvm::SmallVector<SILValue, 8> getPullbackValues(SILBasicBlock *origBB);
184
185
  /// Build a predecessor enum instance using the given builder for the given
186
  /// original predecessor/successor blocks and pullback struct value.
187
  EnumInst *buildPredecessorEnumValue(SILBuilder &builder,
188
                                      SILBasicBlock *predBB,
189
                                      SILBasicBlock *succBB,
190
                                      SILValue pbTupleVal);
191
192
public:
193
  /// Remap original basic blocks, adding predecessor enum arguments.
194
2.63k
  SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) {
195
2.63k
    auto *vjpBB = BBMap[bb];
196
    // If error has occurred, or if block has already been remapped, return
197
    // remapped, return remapped block.
198
2.63k
    if (errorOccurred || remappedBasicBlocks.count(bb))
199
648
      return vjpBB;
200
    // Add predecessor enum argument to the remapped block.
201
1.98k
    auto *predEnum = pullbackInfo.getBranchingTraceDecl(bb);
202
1.98k
    auto enumTy =
203
1.98k
        getOpASTType(predEnum->getDeclaredInterfaceType()->getCanonicalType());
204
1.98k
    auto enumLoweredTy = context.getTypeConverter().getLoweredType(
205
1.98k
        enumTy, TypeExpansionContext::minimal());
206
1.98k
    vjpBB->createPhiArgument(enumLoweredTy, OwnershipKind::Owned);
207
1.98k
    remappedBasicBlocks.insert(bb);
208
1.98k
    return vjpBB;
209
2.63k
  }
210
211
  /// General visitor for all instructions. If any error is emitted by previous
212
  /// visits, bail out.
213
86.1k
  void visit(SILInstruction *inst) {
214
86.1k
    if (errorOccurred)
215
192
      return;
216
85.9k
    TypeSubstCloner::visit(inst);
217
85.9k
  }
218
219
0
  void visitSILInstruction(SILInstruction *inst) {
220
0
    context.emitNondifferentiabilityError(
221
0
        inst, invoker, diag::autodiff_expression_not_differentiable_note);
222
0
    errorOccurred = true;
223
0
  }
224
225
73.0k
  void postProcess(SILInstruction *orig, SILInstruction *cloned) {
226
73.0k
    if (errorOccurred)
227
0
      return;
228
73.0k
    SILClonerWithScopes::postProcess(orig, cloned);
229
73.0k
  }
230
231
5.23k
  void visitReturnInst(ReturnInst *ri) {
232
5.23k
    Builder.setCurrentDebugScope(getOpScope(ri->getDebugScope()));
233
5.23k
    auto loc = ri->getOperand().getLoc();
234
    // Build pullback tuple value for original block.
235
5.23k
    auto *origExit = ri->getParent();
236
237
    // Get the value in the VJP corresponding to the original result.
238
5.23k
    auto *origRetInst = cast<ReturnInst>(origExit->getTerminator());
239
5.23k
    auto origResult = getOpValue(origRetInst->getOperand());
240
5.23k
    SmallVector<SILValue, 8> origResults;
241
5.23k
    extractAllElements(origResult, Builder, origResults);
242
243
    // Get and partially apply the pullback.
244
5.23k
    auto vjpSubstMap = vjp->getForwardingSubstitutionMap();
245
5.23k
    auto *pullbackRef = Builder.createFunctionRef(loc, pullback);
246
247
    // Prepare partial application arguments.
248
5.23k
    SILValue partialApplyArg;
249
5.23k
    PartialApplyInst *pullbackPartialApply;
250
5.23k
    if (borrowedPullbackContextValue) {
251
104
      auto *pbTupleVal = buildPullbackValueTupleValue(ri);
252
      // Initialize the top-level subcontext buffer with the top-level pullback
253
      // tuple.
254
104
      auto addr = emitProjectTopLevelSubcontext(
255
104
          Builder, loc, borrowedPullbackContextValue, pbTupleVal->getType());
256
104
      Builder.createStore(
257
104
          loc, pbTupleVal, addr,
258
104
          pbTupleVal->getType().isTrivial(*pullback) ?
259
92
              StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init);
260
261
104
      Builder.createEndBorrow(loc, borrowedPullbackContextValue);
262
104
      pullbackPartialApply = Builder.createPartialApply(
263
104
        loc, pullbackRef, vjpSubstMap, {pullbackContextValue},
264
104
        ParameterConvention::Direct_Guaranteed);
265
5.13k
    } else {
266
5.13k
      pullbackPartialApply = Builder.createPartialApply(
267
5.13k
        loc, pullbackRef, vjpSubstMap, getPullbackValues(origExit),
268
5.13k
        ParameterConvention::Direct_Guaranteed);
269
5.13k
    }
270
271
5.23k
    auto pullbackType = vjp->mapTypeIntoContext(
272
5.23k
        vjp->getConventions().getSILType(
273
5.23k
            vjp->getLoweredFunctionType()->getResults().back(),
274
5.23k
            vjp->getTypeExpansionContext()));
275
5.23k
    auto pullbackFnType = pullbackType.castTo<SILFunctionType>();
276
5.23k
    auto pullbackSubstType =
277
5.23k
        pullbackPartialApply->getType().castTo<SILFunctionType>();
278
279
    // If necessary, convert the pullback value to the returned pullback
280
    // function type.
281
5.23k
    SILValue pullbackValue;
282
5.23k
    if (pullbackSubstType == pullbackFnType) {
283
4.36k
      pullbackValue = pullbackPartialApply;
284
4.36k
    } else if (pullbackSubstType->isABICompatibleWith(pullbackFnType, *vjp)
285
872
                   .isCompatible()) {
286
872
      pullbackValue =
287
872
          Builder.createConvertFunction(loc, pullbackPartialApply, pullbackType,
288
872
                                        /*withoutActuallyEscaping*/ false);
289
872
    } else {
290
0
      llvm::report_fatal_error("Pullback value type is not ABI-compatible "
291
0
                               "with the returned pullback type");
292
0
    }
293
294
    // Return a tuple of the original result and pullback.
295
5.23k
    SmallVector<SILValue, 8> directResults;
296
5.23k
    directResults.append(origResults.begin(), origResults.end());
297
5.23k
    directResults.push_back(pullbackValue);
298
5.23k
    Builder.createReturn(ri->getLoc(),
299
5.23k
                         joinElements(directResults, Builder, loc));
300
5.23k
  }
301
302
1.19k
  void visitBranchInst(BranchInst *bi) {
303
1.19k
    Builder.setCurrentDebugScope(getOpScope(bi->getDebugScope()));
304
    // Build pullback struct value for original block.
305
    // Build predecessor enum value for destination block.
306
1.19k
    auto *origBB = bi->getParent();
307
1.19k
    auto *pbTupleVal = buildPullbackValueTupleValue(bi);
308
1.19k
    auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB,
309
1.19k
                                              bi->getDestBB(), pbTupleVal);
310
311
    // Remap arguments, appending the new enum values.
312
1.19k
    SmallVector<SILValue, 8> args;
313
1.19k
    for (auto origArg : bi->getArgs())
314
528
      args.push_back(getOpValue(origArg));
315
1.19k
    args.push_back(enumVal);
316
317
    // Create a new `br` instruction.
318
1.19k
    getBuilder().createBranch(bi->getLoc(), getOpBasicBlock(bi->getDestBB()),
319
1.19k
                              args);
320
1.19k
  }
321
322
224
  void visitCondBranchInst(CondBranchInst *cbi) {
323
224
    Builder.setCurrentDebugScope(getOpScope(cbi->getDebugScope()));
324
    // Build pullback struct value for original block.
325
224
    auto *pbTupleVal = buildPullbackValueTupleValue(cbi);
326
    // Create a new `cond_br` instruction.
327
224
    getBuilder().createCondBranch(
328
224
        cbi->getLoc(), getOpValue(cbi->getCondition()),
329
224
        createTrampolineBasicBlock(cbi, pbTupleVal, cbi->getTrueBB()),
330
224
        createTrampolineBasicBlock(cbi, pbTupleVal, cbi->getFalseBB()));
331
224
  }
332
333
452
  void visitSwitchEnumTermInst(SwitchEnumTermInst inst) {
334
452
    Builder.setCurrentDebugScope(getOpScope(inst->getDebugScope()));
335
    // Build pullback tuple value for original block.
336
452
    auto *pbTupleVal = buildPullbackValueTupleValue(*inst);
337
338
    // Create trampoline successor basic blocks.
339
452
    SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs;
340
876
    for (unsigned i : range(inst.getNumCases())) {
341
876
      auto caseBB = inst.getCase(i);
342
876
      auto *trampolineBB =
343
876
          createTrampolineBasicBlock(inst, pbTupleVal, caseBB.second);
344
876
      caseBBs.push_back({caseBB.first, trampolineBB});
345
876
    }
346
    // Create trampoline default basic block.
347
452
    SILBasicBlock *newDefaultBB = nullptr;
348
452
    if (auto *defaultBB = inst.getDefaultBBOrNull().getPtrOrNull())
349
20
      newDefaultBB = createTrampolineBasicBlock(inst, pbTupleVal, defaultBB);
350
351
    // Create a new `switch_enum` instruction.
352
452
    switch (inst->getKind()) {
353
324
    case SILInstructionKind::SwitchEnumInst:
354
324
      getBuilder().createSwitchEnum(
355
324
          inst->getLoc(), getOpValue(inst.getOperand()), newDefaultBB, caseBBs);
356
324
      break;
357
128
    case SILInstructionKind::SwitchEnumAddrInst:
358
128
      getBuilder().createSwitchEnumAddr(
359
128
          inst->getLoc(), getOpValue(inst.getOperand()), newDefaultBB, caseBBs);
360
128
      break;
361
0
    default:
362
0
      llvm_unreachable("Expected `switch_enum` or `switch_enum_addr`");
363
452
    }
364
452
  }
365
366
324
  void visitSwitchEnumInst(SwitchEnumInst *sei) {
367
324
    visitSwitchEnumTermInst(sei);
368
324
  }
369
370
128
  void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
371
128
    visitSwitchEnumTermInst(seai);
372
128
  }
373
374
4
  void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) {
375
4
    Builder.setCurrentDebugScope(getOpScope(ccbi->getDebugScope()));
376
    // Build pullback struct value for original block.
377
4
    auto *pbTupleVal = buildPullbackValueTupleValue(ccbi);
378
    // Create a new `checked_cast_branch` instruction.
379
4
    getBuilder().createCheckedCastBranch(
380
4
        ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()),
381
4
        getOpASTType(ccbi->getSourceFormalType()),
382
4
        getOpType(ccbi->getTargetLoweredType()),
383
4
        getOpASTType(ccbi->getTargetFormalType()),
384
4
        createTrampolineBasicBlock(ccbi, pbTupleVal, ccbi->getSuccessBB()),
385
4
        createTrampolineBasicBlock(ccbi, pbTupleVal, ccbi->getFailureBB()),
386
4
        ccbi->getTrueBBCount(), ccbi->getFalseBBCount());
387
4
  }
388
389
8
  void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi) {
390
8
    Builder.setCurrentDebugScope(getOpScope(ccabi->getDebugScope()));
391
    // Build pullback struct value for original block.
392
8
    auto *pbTupleVal = buildPullbackValueTupleValue(ccabi);
393
    // Create a new `checked_cast_addr_branch` instruction.
394
8
    getBuilder().createCheckedCastAddrBranch(
395
8
        ccabi->getLoc(), ccabi->getConsumptionKind(),
396
8
        getOpValue(ccabi->getSrc()), getOpASTType(ccabi->getSourceFormalType()),
397
8
        getOpValue(ccabi->getDest()),
398
8
        getOpASTType(ccabi->getTargetFormalType()),
399
8
        createTrampolineBasicBlock(ccabi, pbTupleVal, ccabi->getSuccessBB()),
400
8
        createTrampolineBasicBlock(ccabi, pbTupleVal, ccabi->getFailureBB()),
401
8
        ccabi->getTrueBBCount(), ccabi->getFalseBBCount());
402
8
  }
403
404
  // If an `apply` has active results or active inout arguments, replace it
405
  // with an `apply` of its VJP.
406
9.62k
  void visitApplyInst(ApplyInst *ai) {
407
    // If callee should not be differentiated, do standard cloning.
408
9.62k
    if (!pullbackInfo.shouldDifferentiateApplySite(ai)) {
409
3.32k
      LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
410
3.32k
      TypeSubstCloner::visitApplyInst(ai);
411
3.32k
      return;
412
3.32k
    }
413
    // If callee is `array.uninitialized_intrinsic`, do standard cloning.
414
    // `array.uninitialized_intrinsic` differentiation is handled separately.
415
6.30k
    if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) {
416
212
      LLVM_DEBUG(getADDebugStream()
417
212
                 << "Cloning `array.uninitialized_intrinsic` `apply`:\n"
418
212
                 << *ai << '\n');
419
212
      TypeSubstCloner::visitApplyInst(ai);
420
212
      return;
421
212
    }
422
    // If callee is `array.finalize_intrinsic`, do standard cloning.
423
    // `array.finalize_intrinsic` has special-case pullback generation.
424
6.09k
    if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) {
425
212
      LLVM_DEBUG(getADDebugStream()
426
212
                 << "Cloning `array.finalize_intrinsic` `apply`:\n"
427
212
                 << *ai << '\n');
428
212
      TypeSubstCloner::visitApplyInst(ai);
429
212
      return;
430
212
    }
431
    // If the original function is a semantic member accessor, do standard
432
    // cloning. Semantic member accessors have special pullback generation
433
    // logic, so all `apply` instructions can be directly cloned to the VJP.
434
5.88k
    if (isSemanticMemberAccessor(original)) {
435
152
      LLVM_DEBUG(getADDebugStream()
436
152
                 << "Cloning `apply` in semantic member accessor:\n"
437
152
                 << *ai << '\n');
438
152
      TypeSubstCloner::visitApplyInst(ai);
439
152
      return;
440
152
    }
441
442
5.73k
    Builder.setCurrentDebugScope(getOpScope(ai->getDebugScope()));
443
5.73k
    auto loc = ai->getLoc();
444
5.73k
    auto &builder = getBuilder();
445
5.73k
    auto origCallee = getOpValue(ai->getCallee());
446
5.73k
    auto originalFnTy = origCallee->getType().castTo<SILFunctionType>();
447
448
5.73k
    LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n');
449
450
    // Get the minimal parameter and result indices required for differentiating
451
    // this `apply`.
452
5.73k
    SmallVector<SILValue, 4> allResults;
453
5.73k
    SmallVector<unsigned, 8> activeParamIndices;
454
5.73k
    SmallVector<unsigned, 8> activeResultIndices;
455
5.73k
    collectMinimalIndicesForFunctionCall(ai, getConfig(), activityInfo,
456
5.73k
                                         allResults, activeParamIndices,
457
5.73k
                                         activeResultIndices);
458
5.73k
    assert(!activeParamIndices.empty() && "Parameter indices cannot be empty");
459
0
    assert(!activeResultIndices.empty() && "Result indices cannot be empty");
460
5.73k
    LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=(";
461
5.73k
               llvm::interleave(
462
5.73k
                   activeParamIndices.begin(), activeParamIndices.end(),
463
5.73k
                   [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
464
5.73k
               s << "), results=("; llvm::interleave(
465
5.73k
                   activeResultIndices.begin(), activeResultIndices.end(),
466
5.73k
                   [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
467
5.73k
               s << ")\n";);
468
469
    // Form expected indices.
470
5.73k
    AutoDiffConfig config(
471
5.73k
        IndexSubset::get(getASTContext(),
472
5.73k
                         ai->getArgumentsWithoutIndirectResults().size(),
473
5.73k
                         activeParamIndices),
474
5.73k
        IndexSubset::get(getASTContext(),
475
5.73k
                         ai->getSubstCalleeType()->getNumAutoDiffSemanticResults(),
476
5.73k
                         activeResultIndices));
477
478
    // Emit the VJP.
479
5.73k
    SILValue vjpValue;
480
    // If functionSource is a `@differentiable` function, just extract it.
481
5.73k
    if (originalFnTy->isDifferentiable()) {
482
156
      auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices();
483
204
      for (auto i : config.parameterIndices->getIndices()) {
484
204
        if (!paramIndices->contains(i)) {
485
0
          context.emitNondifferentiabilityError(
486
0
              origCallee, invoker,
487
0
              diag::
488
0
                  autodiff_function_noderivative_parameter_not_differentiable);
489
0
          errorOccurred = true;
490
0
          return;
491
0
        }
492
204
      }
493
156
      builder.emitScopedBorrowOperation(
494
156
          loc, origCallee, [&](SILValue borrowedDiffFunc) {
495
156
            auto origFnType = origCallee->getType().castTo<SILFunctionType>();
496
156
            auto origFnUnsubstType =
497
156
                origFnType->getUnsubstitutedType(getModule());
498
156
            if (origFnType != origFnUnsubstType) {
499
20
              borrowedDiffFunc = builder.createConvertFunction(
500
20
                  loc, borrowedDiffFunc,
501
20
                  SILType::getPrimitiveObjectType(origFnUnsubstType),
502
20
                  /*withoutActuallyEscaping*/ false);
503
20
            }
504
156
            vjpValue = builder.createDifferentiableFunctionExtract(
505
156
                loc, NormalDifferentiableFunctionTypeComponent::VJP,
506
156
                borrowedDiffFunc);
507
156
            vjpValue = builder.emitCopyValueOperation(loc, vjpValue);
508
156
          });
509
156
      auto vjpFnType = vjpValue->getType().castTo<SILFunctionType>();
510
156
      auto vjpFnUnsubstType = vjpFnType->getUnsubstitutedType(getModule());
511
156
      if (vjpFnType != vjpFnUnsubstType) {
512
0
        vjpValue = builder.createConvertFunction(
513
0
            loc, vjpValue, SILType::getPrimitiveObjectType(vjpFnUnsubstType),
514
0
            /*withoutActuallyEscaping*/ false);
515
0
      }
516
156
    }
517
518
    // Check and diagnose non-differentiable original function type.
519
5.73k
    auto diagnoseNondifferentiableOriginalFunctionType =
520
7.82k
        [&](CanSILFunctionType origFnTy) {
521
          // Check and diagnose non-differentiable arguments.
522
11.9k
          for (auto paramIndex : config.parameterIndices->getIndices()) {
523
11.9k
            if (!originalFnTy->getParameters()[paramIndex]
524
11.9k
                     .getSILStorageInterfaceType()
525
11.9k
                     .isDifferentiable(getModule())) {
526
8
              auto arg = ai->getArgumentsWithoutIndirectResults()[paramIndex];
527
              // FIXME: This shouldn't be necessary and might indicate a bug in
528
              // the transformation.
529
8
              RegularLocation nonAutoGenLoc(arg.getLoc());
530
8
              nonAutoGenLoc.markNonAutoGenerated();
531
8
              auto startLoc = nonAutoGenLoc.getStartSourceLoc();
532
8
              auto endLoc = nonAutoGenLoc.getEndSourceLoc();
533
8
              context
534
8
                  .emitNondifferentiabilityError(
535
8
                      arg, invoker, diag::autodiff_nondifferentiable_argument)
536
8
                  .fixItInsert(startLoc, "withoutDerivative(at: ")
537
8
                  .fixItInsertAfter(endLoc, ")");
538
8
              errorOccurred = true;
539
8
              return true;
540
8
            }
541
11.9k
          }
542
          // Check and diagnose non-differentiable results.
543
7.92k
          for (auto resultIndex : config.resultIndices->getIndices()) {
544
7.92k
            SILType remappedResultType;
545
7.92k
            if (resultIndex >= originalFnTy->getNumResults()) {
546
648
              auto semanticResultArgIdx = resultIndex - originalFnTy->getNumResults();
547
648
              auto semanticResultArg =
548
648
                  *std::next(ai->getAutoDiffSemanticResultArguments().begin(),
549
648
                             semanticResultArgIdx);
550
648
              remappedResultType = semanticResultArg->getType();
551
7.27k
            } else {
552
7.27k
              remappedResultType = originalFnTy->getResults()[resultIndex]
553
7.27k
                                       .getSILStorageInterfaceType();
554
7.27k
            }
555
7.92k
            if (!remappedResultType.isDifferentiable(getModule())) {
556
12
              auto startLoc = ai->getLoc().getStartSourceLoc();
557
12
              auto endLoc = ai->getLoc().getEndSourceLoc();
558
12
              context
559
12
                  .emitNondifferentiabilityError(
560
12
                      origCallee, invoker,
561
12
                      diag::autodiff_nondifferentiable_result)
562
12
                  .fixItInsert(startLoc, "withoutDerivative(at: ")
563
12
                  .fixItInsertAfter(endLoc, ")");
564
12
              errorOccurred = true;
565
12
              return true;
566
12
            }
567
7.92k
          }
568
7.80k
          return false;
569
7.81k
        };
570
5.73k
    if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
571
12
      return;
572
573
    // If VJP has not yet been found, emit an `differentiable_function`
574
    // instruction on the remapped original function operand and
575
    // an `differentiable_function_extract` instruction to get the VJP.
576
    // The `differentiable_function` instruction will be canonicalized during
577
    // the transform main loop.
578
5.72k
    if (!vjpValue) {
579
      // FIXME: Handle indirect differentiation invokers. This may require some
580
      // redesign: currently, each original function + witness pair is mapped
581
      // only to one invoker.
582
      /*
583
      DifferentiationInvoker indirect(ai, attr);
584
      auto insertion =
585
          context.getInvokers().try_emplace({original, attr}, indirect);
586
      auto &invoker = insertion.first->getSecond();
587
      invoker = indirect;
588
      */
589
590
      // If the original `apply` instruction has a substitution map, then the
591
      // applied function is specialized.
592
      // In the VJP, specialization is also necessary for parity. The original
593
      // function operand is specialized with a remapped version of same
594
      // substitution map using an argument-less `partial_apply`.
595
5.56k
      if (ai->getSubstitutionMap().empty()) {
596
3.47k
        origCallee = builder.emitCopyValueOperation(loc, origCallee);
597
3.47k
      } else {
598
2.08k
        auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap());
599
2.08k
        auto vjpPartialApply = getBuilder().createPartialApply(
600
2.08k
            ai->getLoc(), origCallee, substMap, {},
601
2.08k
            ParameterConvention::Direct_Guaranteed);
602
2.08k
        origCallee = vjpPartialApply;
603
2.08k
        originalFnTy = origCallee->getType().castTo<SILFunctionType>();
604
        // Diagnose if new original function type is non-differentiable.
605
2.08k
        if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
606
8
          return;
607
2.08k
      }
608
609
5.55k
      auto *diffFuncInst = context.createDifferentiableFunction(
610
5.55k
          getBuilder(), loc, config.parameterIndices, config.resultIndices,
611
5.55k
          origCallee);
612
613
      // Record the `differentiable_function` instruction.
614
5.55k
      context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);
615
616
5.55k
      builder.emitScopedBorrowOperation(
617
5.55k
          loc, diffFuncInst, [&](SILValue borrowedADFunc) {
618
5.55k
            auto extractedVJP =
619
5.55k
                getBuilder().createDifferentiableFunctionExtract(
620
5.55k
                    loc, NormalDifferentiableFunctionTypeComponent::VJP,
621
5.55k
                    borrowedADFunc);
622
5.55k
            vjpValue = builder.emitCopyValueOperation(loc, extractedVJP);
623
5.55k
          });
624
5.55k
      builder.emitDestroyValueOperation(loc, diffFuncInst);
625
5.55k
    }
626
627
    // Record desired/actual VJP indices.
628
    // Temporarily set original pullback type to `None`.
629
5.71k
    NestedApplyInfo info{config, /*originalPullbackType*/ llvm::None};
630
5.71k
    auto insertion = context.getNestedApplyInfo().try_emplace(ai, info);
631
5.71k
    auto &nestedApplyInfo = insertion.first->getSecond();
632
5.71k
    nestedApplyInfo = info;
633
634
    // Call the VJP using the original parameters.
635
5.71k
    SmallVector<SILValue, 8> vjpArgs;
636
5.71k
    auto vjpFnTy = getOpType(vjpValue->getType()).castTo<SILFunctionType>();
637
5.71k
    auto numVJPArgs =
638
5.71k
        vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults();
639
5.71k
    vjpArgs.reserve(numVJPArgs);
640
    // Collect substituted arguments.
641
5.71k
    for (auto origArg : ai->getArguments())
642
15.7k
      vjpArgs.push_back(getOpValue(origArg));
643
5.71k
    assert(vjpArgs.size() == numVJPArgs);
644
    // Apply the VJP.
645
    // The VJP should be specialized, so no substitution map is necessary.
646
0
    auto *vjpCall = getBuilder().createApply(loc, vjpValue, SubstitutionMap(),
647
5.71k
                                             vjpArgs, ai->getApplyOptions());
648
5.71k
    LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall);
649
5.71k
    builder.emitDestroyValueOperation(loc, vjpValue);
650
651
    // Get the VJP results (original results and pullback).
652
5.71k
    SmallVector<SILValue, 8> vjpDirectResults;
653
5.71k
    extractAllElements(vjpCall, getBuilder(), vjpDirectResults);
654
5.71k
    ArrayRef<SILValue> originalDirectResults =
655
5.71k
        ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
656
5.71k
    SILValue originalDirectResult =
657
5.71k
        joinElements(originalDirectResults, getBuilder(), vjpCall->getLoc());
658
5.71k
    SILValue pullback = vjpDirectResults.back();
659
5.71k
    {
660
5.71k
      auto pullbackFnType = pullback->getType().castTo<SILFunctionType>();
661
5.71k
      auto pullbackUnsubstFnType =
662
5.71k
          pullbackFnType->getUnsubstitutedType(getModule());
663
5.71k
      if (pullbackFnType != pullbackUnsubstFnType) {
664
508
        pullback = builder.createConvertFunction(
665
508
            loc, pullback,
666
508
            SILType::getPrimitiveObjectType(pullbackUnsubstFnType),
667
508
            /*withoutActuallyEscaping*/ false);
668
508
      }
669
5.71k
    }
670
671
    // Store the original result to the value map.
672
5.71k
    mapValue(ai, originalDirectResult);
673
674
    // Checkpoint the pullback.
675
5.71k
    auto pullbackType = pullbackInfo.lookUpLinearMapType(ai);
676
677
    // If actual pullback type does not match lowered pullback type, reabstract
678
    // the pullback using a thunk.
679
5.71k
    auto actualPullbackType =
680
5.71k
        getOpType(pullback->getType()).getAs<SILFunctionType>();
681
5.71k
    auto loweredPullbackType =
682
5.71k
        getOpType(getLoweredType(pullbackType)).castTo<SILFunctionType>();
683
5.71k
    if (!loweredPullbackType->isEqual(actualPullbackType)) {
684
      // Set non-reabstracted original pullback type in nested apply info.
685
1.46k
      nestedApplyInfo.originalPullbackType = actualPullbackType;
686
1.46k
      SILOptFunctionBuilder fb(context.getTransform());
687
1.46k
      pullback = reabstractFunction(
688
1.46k
          getBuilder(), fb, ai->getLoc(), pullback, loweredPullbackType,
689
1.46k
          [this](SubstitutionMap subs) -> SubstitutionMap {
690
1.46k
            return this->getOpSubstitutionMap(subs);
691
1.46k
          });
692
1.46k
    }
693
5.71k
    pullbackValues[ai->getParent()].push_back(pullback);
694
695
    // Some instructions that produce the callee may have been cloned.
696
    // If the original callee did not have any users beyond this `apply`,
697
    // recursively kill the cloned callee.
698
5.71k
    if (auto *origCallee = cast_or_null<SingleValueInstruction>(
699
5.71k
            ai->getCallee()->getDefiningInstruction()))
700
5.62k
      if (origCallee->hasOneUse())
701
5.55k
        recursivelyDeleteTriviallyDeadInstructions(
702
5.55k
            getOpValue(origCallee)->getDefiningInstruction());
703
5.71k
  }
704
705
36
  /// Clones a `try_apply` into the VJP. Both the normal and the error
  /// successors are reached through freshly created trampoline blocks, each of
  /// which receives the pullback tuple built for the original block.
  void visitTryApplyInst(TryApplyInst *tai) {
    Builder.setCurrentDebugScope(getOpScope(tai->getDebugScope()));
    // Build the pullback tuple for the original block first: both trampolines
    // consume it.
    TupleInst *pullbackTuple = buildPullbackValueTupleValue(tai);
    auto clonedArgs = getOpValueArray<8>(tai->getArguments());
    SILValue clonedCallee = getOpValue(tai->getCallee());
    SubstitutionMap clonedSubs =
        getOpSubstitutionMap(tai->getSubstitutionMap());
    SILBasicBlock *normalTrampoline =
        createTrampolineBasicBlock(tai, pullbackTuple, tai->getNormalBB());
    SILBasicBlock *errorTrampoline =
        createTrampolineBasicBlock(tai, pullbackTuple, tai->getErrorBB());
    getBuilder().createTryApply(tai->getLoc(), clonedCallee, clonedSubs,
                                clonedArgs, normalTrampoline, errorTrampoline,
                                tai->getApplyOptions());
  }
718
719
96
  /// Clones a `differentiable_function` instruction into the VJP and queues the
  /// clone on the context's `differentiable_function` worklist.
  void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) {
    TypeSubstCloner::visitDifferentiableFunctionInst(dfi);
    auto &worklist = context.getDifferentiableFunctionInstWorklist();
    worklist.push_back(cast<DifferentiableFunctionInst>(getOpValue(dfi)));
  }
726
727
0
  /// Clones a `linear_function` instruction into the VJP and queues the clone
  /// on the context's `linear_function` worklist.
  void visitLinearFunctionInst(LinearFunctionInst *lfi) {
    TypeSubstCloner::visitLinearFunctionInst(lfi);
    auto &worklist = context.getLinearFunctionInstWorklist();
    worklist.push_back(cast<LinearFunctionInst>(getOpValue(lfi)));
  }
734
};
735
736
/// Initialization helper function.
737
///
738
/// Returns the substitution map used for type remapping.
739
static SubstitutionMap getSubstitutionMap(SILFunction *original,
740
5.25k
                                          SILFunction *vjp) {
741
5.25k
  auto substMap = original->getForwardingSubstitutionMap();
742
5.25k
  if (auto *vjpGenEnv = vjp->getGenericEnvironment()) {
743
940
    auto vjpSubstMap = vjpGenEnv->getForwardingSubstitutionMap();
744
940
    substMap = SubstitutionMap::get(
745
940
        vjpGenEnv->getGenericSignature(), QuerySubstitutionMap{vjpSubstMap},
746
940
        LookUpConformanceInSubstitutionMap(vjpSubstMap));
747
940
  }
748
5.25k
  return substMap;
749
5.25k
}
750
751
/// Initialization helper function.
752
///
753
/// Returns the activity info for the given original function, autodiff indices,
754
/// and VJP generic signature.
755
static const DifferentiableActivityInfo &
756
getActivityInfoHelper(ADContext &context, SILFunction *original,
757
5.25k
                      const AutoDiffConfig &config, SILFunction *vjp) {
758
  // Get activity info of the original function.
759
5.25k
  auto &passManager = context.getPassManager();
760
5.25k
  auto *activityAnalysis =
761
5.25k
      passManager.getAnalysis<DifferentiableActivityAnalysis>();
762
5.25k
  auto &activityCollection = *activityAnalysis->get(original);
763
5.25k
  auto &activityInfo = activityCollection.getActivityInfo(
764
5.25k
      vjp->getLoweredFunctionType()->getSubstGenericSignature(),
765
5.25k
      AutoDiffDerivativeFunctionKind::VJP);
766
5.25k
  LLVM_DEBUG(activityInfo.dump(config, getADDebugStream()));
767
5.25k
  return activityInfo;
768
5.25k
}
769
770
/// Sets up cloning from the witness's original function into `vjp`. It
/// computes activity info, loop info, and pullback linear map info for the
/// original, then creates the empty pullback function and records it with the
/// context.
///
/// NOTE(review): `pullbackInfo` consumes `activityInfo` and `loopInfo`, so
/// this list presumably has to stay in member-declaration order; verify
/// against the class definition before reordering.
VJPCloner::Implementation::Implementation(VJPCloner &cloner, ADContext &context,
                                          SILDifferentiabilityWitness *witness,
                                          SILFunction *vjp,
                                          DifferentiationInvoker invoker)
    : TypeSubstCloner(*vjp, *witness->getOriginalFunction(),
                      getSubstitutionMap(witness->getOriginalFunction(), vjp)),
      cloner(cloner), context(context),
      original(witness->getOriginalFunction()), witness(witness), vjp(vjp),
      invoker(invoker),
      activityInfo(
          getActivityInfoHelper(context, original, witness->getConfig(), vjp)),
      loopInfo(context.getPassManager().getAnalysis<SILLoopAnalysis>()->get(
          original)),
      pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp,
                   witness->getConfig(), activityInfo, loopInfo) {
  // Create the pullback as an empty shell now; its body is filled in later.
  pullback = createEmptyPullback();
  context.recordGeneratedFunction(pullback);
}
789
790
/// Creates a VJP cloner. All state lives in an `Implementation` that is
/// heap-allocated here and released by the destructor.
VJPCloner::VJPCloner(ADContext &context, SILDifferentiabilityWitness *witness,
                     SILFunction *vjp, DifferentiationInvoker invoker)
    : impl(*new Implementation(*this, context, witness, vjp, invoker)) {}

VJPCloner::~VJPCloner() {
  // `impl` refers to the object allocated with `new` in the constructor.
  delete &impl;
}
796
797
185k
ADContext &VJPCloner::getContext() const { return impl.context; }
798
0
SILModule &VJPCloner::getModule() const { return impl.getModule(); }
799
180k
SILFunction &VJPCloner::getOriginal() const { return *impl.original; }
800
5.10k
SILFunction &VJPCloner::getVJP() const { return *impl.vjp; }
801
615k
SILFunction &VJPCloner::getPullback() const { return *impl.pullback; }
802
144k
SILDifferentiabilityWitness *VJPCloner::getWitness() const {
803
144k
  return impl.witness;
804
144k
}
805
173k
const AutoDiffConfig &VJPCloner::getConfig() const {
806
173k
  return impl.getConfig();
807
173k
}
808
3.81k
DifferentiationInvoker VJPCloner::getInvoker() const { return impl.invoker; }
809
134k
LinearMapInfo &VJPCloner::getPullbackInfo() const { return impl.pullbackInfo; }
810
2.35k
SILLoopInfo *VJPCloner::getLoopInfo() const { return impl.loopInfo; }
811
152k
const DifferentiableActivityInfo &VJPCloner::getActivityInfo() const {
812
152k
  return impl.activityInfo;
813
152k
}
814
815
5.25k
/// Creates the pullback function for this VJP, with no body, and returns it.
///
/// Pullback parameters, in order:
/// - the tangent vectors of the differentiated original results: formal
///   results, or semantic result parameters (see
///   `isAutoDiffSemanticResult`);
/// - either a single `Builtin.NativeObject` linear map context (when
///   `pullbackInfo.hasHeapAllocatedContext()`), or each element of the
///   original exit block's linear map tuple, passed `@owned`.
/// Pullback results are the tangent vectors of the differentiated original
/// parameters, excluding semantic result parameters.
SILFunction *VJPCloner::Implementation::createEmptyPullback() {
  auto &module = context.getModule();
  auto origTy = original->getLoweredFunctionType();
  // Get witness generic signature for remapping types.
  // Witness generic signature may have more requirements than VJP generic
  // signature: when witness generic signature has same-type requirements
  // binding all generic parameters to concrete types, VJP function type uses
  // all the concrete types and VJP generic signature is null.
  auto witnessCanGenSig =
      witness->getDerivativeGenericSignature().getCanonicalSignature();
  auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());

  // Returns the tangent type of `type`, reduced in the witness generic
  // signature. The tangent space is dereferenced without an existence check.
  auto getReducedTangentType = [&](CanType type) -> CanType {
    return type->getAutoDiffTangentSpace(lookupConformance)
        ->getType()
        ->getReducedType(witnessCanGenSig);
  };

  // Given the tangent type of an original result and that result's
  // convention, returns the SIL parameter info used to pass it to the
  // pullback.
  auto getTangentParameterInfoForOriginalResult =
      [&](CanType tanType, ResultConvention origResConv) -> SILParameterInfo {
    tanType = tanType->getReducedType(witnessCanGenSig);
    Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType);
    auto &tl = context.getTypeConverter().getTypeLowering(
        pattern, tanType, TypeExpansionContext::minimal());
    ParameterConvention conv;
    switch (origResConv) {
    case ResultConvention::Unowned:
    case ResultConvention::UnownedInnerPointer:
    case ResultConvention::Owned:
    case ResultConvention::Autoreleased:
      // Direct result: address-only tangents are passed indirectly; otherwise
      // trivial tangents are unowned and non-trivial ones guaranteed.
      if (tl.isAddressOnly()) {
        conv = ParameterConvention::Indirect_In_Guaranteed;
      } else {
        conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
                              : ParameterConvention::Direct_Guaranteed;
      }
      break;
    case ResultConvention::Indirect:
      conv = ParameterConvention::Indirect_In_Guaranteed;
      break;
    case ResultConvention::Pack:
      conv = ParameterConvention::Pack_Guaranteed;
      break;
    }
    return {tanType, conv};
  };

  // Given the tangent type of an original parameter and that parameter's
  // convention, returns the SIL result info used to return it from the
  // pullback.
  auto getTangentResultInfoForOriginalParameter =
      [&](CanType tanType, ParameterConvention origParamConv) -> SILResultInfo {
    tanType = tanType->getReducedType(witnessCanGenSig);
    Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType);
    auto &tl = context.getTypeConverter().getTypeLowering(
        pattern, tanType, TypeExpansionContext::minimal());
    ResultConvention conv;
    switch (origParamConv) {
    case ParameterConvention::Direct_Owned:
    case ParameterConvention::Direct_Guaranteed:
    case ParameterConvention::Direct_Unowned:
      // Direct parameter: address-only tangents are returned indirectly;
      // otherwise trivial tangents are unowned and non-trivial ones owned.
      if (tl.isAddressOnly()) {
        conv = ResultConvention::Indirect;
      } else {
        conv = tl.isTrivial() ? ResultConvention::Unowned
                              : ResultConvention::Owned;
      }
      break;
    case ParameterConvention::Indirect_In:
    case ParameterConvention::Indirect_Inout:
    case ParameterConvention::Indirect_In_Guaranteed:
    case ParameterConvention::Indirect_InoutAliasable:
      conv = ResultConvention::Indirect;
      break;
    case ParameterConvention::Pack_Guaranteed:
    case ParameterConvention::Pack_Owned:
    case ParameterConvention::Pack_Inout:
      conv = ResultConvention::Pack;
      break;
    }
    return {tanType, conv};
  };

  // Parameters of the pullback are:
  // - the tangent vectors of the original results, and
  // - the linear map context, or the elements of the exit block's linear map
  //   tuple.
  // Results of the pullback are in the tangent space of the original
  // parameters.
  SmallVector<SILParameterInfo, 8> pbParams;
  SmallVector<SILResultInfo, 8> adjResults;
  auto origParams = origTy->getParameters();
  auto config = witness->getConfig();

  // Indices of the original parameters that are semantic results, in
  // parameter order. Result indices past the formal results index into this.
  SmallVector<unsigned, 4> semanticResultParamIndices;
  for (auto i : range(origTy->getNumParameters())) {
    if (origParams[i].isAutoDiffSemanticResult())
      semanticResultParamIndices.push_back(i);
  }

  // Add pullback parameters based on original result indices.
  for (auto resultIndex : config.resultIndices->getIndices()) {
    // Handle formal result.
    if (resultIndex < origTy->getNumResults()) {
      auto origResult = origTy->getResults()[resultIndex];
      origResult = origResult.getWithInterfaceType(
          origResult.getInterfaceType()->getReducedType(witnessCanGenSig));
      pbParams.push_back(getTangentParameterInfoForOriginalResult(
          getReducedTangentType(origResult.getInterfaceType()),
          origResult.getConvention()));
      continue;
    }

    // Handle semantic result parameter.
    unsigned semanticResultIndex = resultIndex - origTy->getNumResults();
    assert(semanticResultIndex < semanticResultParamIndices.size() &&
           "Result index does not name a semantic result parameter");
    unsigned paramIndex = semanticResultParamIndices[semanticResultIndex];
    auto resultParam = origParams[paramIndex];
    auto origResult = resultParam.getWithInterfaceType(
        resultParam.getInterfaceType()->getReducedType(witnessCanGenSig));

    // A semantic result parameter that is not also a differentiability
    // parameter receives its tangent as `@in_guaranteed`.
    auto resultParamTanConvention = resultParam.getConvention();
    if (!config.isWrtParameter(paramIndex))
      resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed;

    pbParams.emplace_back(getReducedTangentType(origResult.getInterfaceType()),
                          resultParamTanConvention);
  }

  if (pullbackInfo.hasHeapAllocatedContext()) {
    // Accept an `AutoDiffLinearMapContext` heap object if there are loops.
    pbParams.push_back({
      getASTContext().TheNativeObjectType,
      ParameterConvention::Direct_Guaranteed
    });
  } else {
    // Accept the elements of the exit block's linear map tuple in the
    // pullback parameter list. This is the returned pullback's closure
    // context.
    auto *origExit = &*original->findReturnBB();
    auto pbTupleType =
        pullbackInfo.getLinearMapTupleLoweredType(origExit).getAs<TupleType>();
    assert(pbTupleType && "Linear map tuple type is expected to be a tuple");
    for (Type eltTy : pbTupleType->getElementTypes())
      pbParams.emplace_back(CanType(eltTy), ParameterConvention::Direct_Owned);
  }

  // Add pullback results for the requested wrt parameters.
  for (auto i : config.parameterIndices->getIndices()) {
    auto origParam = origParams[i];
    // Semantic result parameters were handled above as pullback parameters.
    if (origParam.isAutoDiffSemanticResult())
      continue;
    origParam = origParam.getWithInterfaceType(
        origParam.getInterfaceType()->getReducedType(witnessCanGenSig));
    adjResults.push_back(getTangentResultInfoForOriginalParameter(
        getReducedTangentType(origParam.getInterfaceType()),
        origParam.getConvention()));
  }

  Mangle::DifferentiationMangler mangler;
  auto pbName = mangler.mangleLinearMap(
      original->getName(), AutoDiffLinearMapKind::Pullback, config);
  // Set pullback generic signature equal to VJP generic signature.
  // Do not use witness generic signature, which may have same-type requirements
  // binding all generic parameters to concrete types.
  auto pbGenericSig = vjp->getLoweredFunctionType()->getSubstGenericSignature();
  auto *pbGenericEnv = pbGenericSig.getGenericEnvironment();
  auto pbType = SILFunctionType::get(
      pbGenericSig, SILExtInfo::getThin(), origTy->getCoroutineKind(),
      origTy->getCalleeConvention(), pbParams, {}, adjResults, llvm::None,
      origTy->getPatternSubstitutions(), origTy->getInvocationSubstitutions(),
      original->getASTContext());

  SILOptFunctionBuilder fb(context.getTransform());
  // A serialized VJP gets a public, serialized pullback; otherwise private.
  auto linkage = vjp->isSerialized() ? SILLinkage::Public : SILLinkage::Private;
  auto *pullback = fb.createFunction(
      linkage, context.getASTContext().getIdentifier(pbName).str(), pbType,
      pbGenericEnv, original->getLocation(), original->isBare(),
      IsNotTransparent, vjp->isSerialized(),
      original->isDynamicallyReplaceable(), original->isDistributed(),
      original->isRuntimeAccessible());
  pullback->setDebugScope(new (module)
                              SILDebugScope(original->getLocation(), pullback));

  return pullback;
}
1012
1013
/// Creates a trampoline block in the VJP, placed just before the clone of
/// `succBB`.
///
/// The trampoline declares every argument of the cloned successor except the
/// last one (the predecessor enum argument). It builds the predecessor enum
/// value for the edge `termInst->getParent()` -> `succBB`, with payload
/// `pbTupleVal`. It then branches to the cloned successor, forwarding its own
/// arguments followed by that enum.
SILBasicBlock *VJPCloner::Implementation::createTrampolineBasicBlock(
    TermInst *termInst, TupleInst *pbTupleVal, SILBasicBlock *succBB) {
  auto successors = termInst->getSuccessorBlocks();
  assert(llvm::find(successors, succBB) != successors.end() &&
         "Basic block is not a successor of terminator instruction");

  SILBasicBlock *clonedSucc = getOpBasicBlock(succBB);
  SILBasicBlock *trampoline = vjp->createBasicBlockBefore(clonedSucc);
  for (auto *succArg : clonedSucc->getArguments().drop_back())
    trampoline->createPhiArgument(succArg->getType(),
                                  succArg->getOwnershipKind());

  SILBuilder trampolineBuilder(trampoline);
  trampolineBuilder.setCurrentDebugScope(getOpScope(termInst->getDebugScope()));
  EnumInst *predEnum = buildPredecessorEnumValue(
      trampolineBuilder, termInst->getParent(), succBB, pbTupleVal);

  SmallVector<SILValue, 4> branchArgs;
  for (auto *trampolineArg : trampoline->getArguments())
    branchArgs.push_back(trampolineArg);
  branchArgs.push_back(predEnum);
  trampolineBuilder.createBranch(termInst->getLoc(), clonedSucc, branchArgs);
  return trampoline;
}
1037
1038
/// Returns the values that make up `origBB`'s pullback tuple. For a non-entry
/// block, the cloned VJP block's last argument (the predecessor enum) comes
/// first. The pullback values recorded for `origBB` follow.
llvm::SmallVector<SILValue, 8>
VJPCloner::Implementation::getPullbackValues(SILBasicBlock *origBB) {
  auto *vjpBB = BBMap[origBB];
  auto &recorded = pullbackValues[origBB];
  llvm::SmallVector<SILValue, 8> values;
  if (!origBB->isEntry())
    values.push_back(vjpBB->getArguments().back());
  values.append(recorded.begin(), recorded.end());
  return values;
}
1049
1050
/// Emits, through `getBuilder()`, a tuple of the pullback values for the
/// original block that contains `termInst`. The tuple is typed as that block's
/// remapped linear map tuple type. `termInst` must belong to the original
/// function.
TupleInst *
VJPCloner::Implementation::buildPullbackValueTupleValue(TermInst *termInst) {
  assert(termInst->getFunction() == original);
  auto autoGenLoc = RegularLocation::getAutoGeneratedLocation();
  SILBasicBlock *origBB = termInst->getParent();
  SILType tupleTy =
      remapType(pullbackInfo.getLinearMapTupleLoweredType(origBB));
  auto elements = getPullbackValues(origBB);
  return getBuilder().createTuple(autoGenLoc, tupleTy, elements);
}
1060
1061
/// Emits, using `builder`, the branching-trace enum value that records the
/// CFG edge `predBB` -> `succBB`.
///
/// If `predBB` is not in a loop, the enum payload is `pbTupleVal` itself.
/// If `predBB` is in a loop, a subcontext buffer is allocated in the linear
/// map context (`borrowedPullbackContextValue`) for the tuple's type.
/// `pbTupleVal` is stored into that buffer, and the buffer's
/// `Builtin.RawPointer` becomes the payload.
EnumInst *VJPCloner::Implementation::buildPredecessorEnumValue(
    SILBuilder &builder, SILBasicBlock *predBB, SILBasicBlock *succBB,
    SILValue pbTupleVal) {
  auto loc = RegularLocation::getAutoGeneratedLocation();
  auto enumLoweredTy =
      remapType(pullbackInfo.getBranchingTraceEnumLoweredType(succBB));
  auto *enumEltDecl =
      pullbackInfo.lookUpBranchingTraceEnumElement(predBB, succBB);
  auto enumEltType = getOpType(enumLoweredTy.getEnumElementType(
      enumEltDecl, getModule(), TypeExpansionContext::minimal()));
  // If the predecessor block is in a loop, its predecessor enum payload is a
  // `Builtin.RawPointer`.
  if (loopInfo->getLoopFor(predBB)) {
    auto rawPtrType = SILType::getRawPointerType(getASTContext());
    assert(enumEltType == rawPtrType);
    auto pbTupleType =
      remapASTType(pullbackInfo.getLinearMapTupleType(predBB)->getCanonicalType());

    auto pbTupleMetatypeType =
        CanMetatypeType::get(pbTupleType, MetatypeRepresentation::Thick);
    auto pbTupleMetatypeSILType =
        SILType::getPrimitiveObjectType(pbTupleMetatypeType);
    // Emit the metatype with the same builder as the rest of this sequence,
    // so it sits next to its only use rather than at the cloner's unrelated
    // insertion point.
    auto pbTupleMetatype =
        builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType);

    auto rawBufferValue = builder.createBuiltin(
        loc,
        getASTContext().getIdentifier(getBuiltinName(
            BuiltinValueKind::AutoDiffAllocateSubcontextWithType)),
        rawPtrType, SubstitutionMap(),
        {borrowedPullbackContextValue, pbTupleMetatype});

    auto typedBufferValue =
      builder.createPointerToAddress(
        loc, rawBufferValue, pbTupleVal->getType().getAddressType(),
        /*isStrict*/ true);
    // The store is emitted into the VJP, so decide its ownership qualifier in
    // the VJP's type expansion context.
    builder.createStore(
        loc, pbTupleVal, typedBufferValue,
        pbTupleVal->getType().isTrivial(*vjp) ?
            StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init);
    return builder.createEnum(loc, rawBufferValue, enumEltDecl, enumLoweredTy);
  }
  return builder.createEnum(loc, pbTupleVal, enumEltDecl, enumLoweredTy);
}
1105
1106
5.25k
bool VJPCloner::Implementation::run() {
1107
5.25k
  PrettyStackTraceSILFunction trace("generating VJP for", original);
1108
5.25k
  LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName()
1109
5.25k
                                << " to vjp @" << vjp->getName() << '\n');
1110
1111
  // Create entry BB and arguments.
1112
5.25k
  auto *entry = vjp->createBasicBlock();
1113
5.25k
  createEntryArguments(vjp);
1114
1115
5.25k
  emitLinearMapContextInitializationIfNeeded();
1116
1117
  // Clone.
1118
5.25k
  SmallVector<SILValue, 4> entryArgs(entry->getArguments().begin(),
1119
5.25k
                                     entry->getArguments().end());
1120
5.25k
  cloneFunctionBody(original, entry, entryArgs);
1121
  // If errors occurred, back out.
1122
5.25k
  if (errorOccurred)
1123
20
    return true;
1124
1125
  // Merge VJP basic blocks. This is significant for control flow
1126
  // differentiation: trampoline destination bbs are merged into trampoline bbs.
1127
  // NOTE(TF-990): Merging basic blocks ensures that `@guaranteed` trampoline
1128
  // bb arguments have a lifetime-ending `end_borrow` use, and is robust when
1129
  // `-enable-strip-ownership-after-serialization` is true.
1130
5.23k
  mergeBasicBlocks(vjp);
1131
1132
5.23k
  LLVM_DEBUG(getADDebugStream()
1133
5.23k
             << "Generated VJP for " << original->getName() << ":\n"
1134
5.23k
             << *vjp);
1135
1136
  // Generate pullback code.
1137
5.23k
  PullbackCloner PullbackCloner(cloner);
1138
5.23k
  if (PullbackCloner.run()) {
1139
132
    errorOccurred = true;
1140
132
    return true;
1141
132
  }
1142
5.10k
  return errorOccurred;
1143
5.23k
}
1144
1145
5.25k
bool VJPCloner::run() {
1146
5.25k
  bool foundError = impl.run();
1147
5.25k
#ifndef NDEBUG
1148
5.25k
  if (!foundError)
1149
5.10k
    getVJP().verify();
1150
5.25k
#endif
1151
5.25k
  return foundError;
1152
5.25k
}
1153
1154
} // end namespace autodiff
1155
} // end namespace swift