Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/JVPCloner.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- JVPCloner.cpp - JVP function generation --------------*- C++ -*---===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// This file defines a helper class for generating JVP functions for automatic
14
// differentiation.
15
//
16
//===----------------------------------------------------------------------===//
17
18
#define DEBUG_TYPE "differentiation"
19
20
#include "swift/SILOptimizer/Differentiation/JVPCloner.h"
21
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
22
#include "swift/SILOptimizer/Differentiation/ADContext.h"
23
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
24
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
25
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
26
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
27
#include "swift/SILOptimizer/Differentiation/Thunk.h"
28
29
#include "swift/SIL/LoopInfo.h"
30
#include "swift/SIL/TypeSubstCloner.h"
31
#include "swift/SILOptimizer/Analysis/LoopAnalysis.h"
32
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
33
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
34
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
35
#include "llvm/ADT/DenseMap.h"
36
37
using namespace swift;
38
using namespace autodiff;
39
40
namespace swift {
41
namespace autodiff {
42
43
class JVPCloner::Implementation final
44
    : public TypeSubstCloner<JVPCloner::Implementation, SILOptFunctionBuilder> {
45
private:
46
  /// The global context.
47
  ADContext &context;
48
49
  /// The original function.
50
  SILFunction *const original;
51
52
  /// The witness.
53
  SILDifferentiabilityWitness *const witness;
54
55
  /// The JVP function.
56
  SILFunction *const jvp;
57
58
  llvm::BumpPtrAllocator allocator;
59
60
  /// The differentiation invoker.
61
  DifferentiationInvoker invoker;
62
63
  /// Info from activity analysis on the original function.
64
  const DifferentiableActivityInfo &activityInfo;
65
66
  /// The loop info.
67
  SILLoopInfo *loopInfo;
68
69
  /// The differential info.
70
  LinearMapInfo differentialInfo;
71
72
  bool errorOccurred = false;
73
74
  //--------------------------------------------------------------------------//
75
  // Differential generation related fields
76
  //--------------------------------------------------------------------------//
77
78
  /// The builder for the differential function.
79
  TangentBuilder differentialBuilder;
80
81
  /// Mapping from original basic blocks to corresponding differential basic
82
  /// blocks.
83
  llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> diffBBMap;
84
85
  /// Mapping from original basic blocks and original values to corresponding
86
  /// tangent values.
87
  llvm::DenseMap<SILValue, AdjointValue> tangentValueMap;
88
89
  /// Mapping from original basic blocks and original buffers to corresponding
90
  /// tangent buffers.
91
  llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap;
92
93
  /// Mapping from differential basic blocks to differential struct arguments.
94
  llvm::DenseMap<SILBasicBlock *, SILArgument *> differentialStructArguments;
95
96
  /// Mapping from differential struct field declarations to differential struct
97
  /// elements destructured from the linear map basic block argument. In the
98
  /// beginning of each differential basic block, the block's differential
99
  /// struct is destructured into the individual elements stored here.
100
  llvm::DenseMap<SILBasicBlock *, SILInstructionResultArray> differentialTupleElements;
101
102
  /// An auxiliary differential local allocation builder.
103
  TangentBuilder diffLocalAllocBuilder;
104
105
  /// Stack buffers allocated for storing local tangent values.
106
  SmallVector<SILValue, 8> differentialLocalAllocations;
107
108
  /// Mapping from original blocks to differential values. Used to build
109
  /// differential struct instances.
110
  llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> differentialValues;
111
112
  //--------------------------------------------------------------------------//
113
  // Getters
114
  //--------------------------------------------------------------------------//
115
116
3.20k
  ASTContext &getASTContext() const { return jvp->getASTContext(); }
117
13.5k
  SILModule &getModule() const { return jvp->getModule(); }
118
9.85k
  const AutoDiffConfig getConfig() const { return witness->getConfig(); }
119
19.6k
  TangentBuilder &getDifferentialBuilder() { return differentialBuilder; }
120
24.4k
  SILFunction &getDifferential() { return differentialBuilder.getFunction(); }
121
0
  SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) {
122
0
    return differentialStructArguments[origBB];
123
0
  }
124
125
  //--------------------------------------------------------------------------//
126
  // Differential tuple mapping
127
  //--------------------------------------------------------------------------//
128
129
  void initializeDifferentialTupleElements(SILBasicBlock *origBB,
130
                                           SILInstructionResultArray values);
131
132
  SILValue getDifferentialTupleElement(ApplyInst *ai);
133
134
  //--------------------------------------------------------------------------//
135
  // General utilities
136
  //--------------------------------------------------------------------------//
137
138
  /// Get the lowered SIL type of the given AST type.
139
1.60k
  SILType getLoweredType(Type type) {
140
1.60k
    auto jvpGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature();
141
1.60k
    Lowering::AbstractionPattern pattern(jvpGenSig,
142
1.60k
                                         type->getReducedType(jvpGenSig));
143
1.60k
    return jvp->getLoweredType(pattern, type);
144
1.60k
  }
145
146
  /// Get the lowered SIL type of the given nominal type declaration.
147
0
  SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) {
148
0
    auto nominalType =
149
0
        getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType());
150
0
    return getLoweredType(nominalType);
151
0
  }
152
153
  /// Build a differential struct value for the original block corresponding to
154
  /// the given terminator.
155
1.33k
  TupleInst *buildDifferentialValueStructValue(TermInst *termInst) {
156
1.33k
    assert(termInst->getFunction() == original);
157
0
    auto loc = termInst->getFunction()->getLocation();
158
1.33k
    auto *origBB = termInst->getParent();
159
1.33k
    auto *jvpBB = BBMap[origBB];
160
1.33k
    assert(jvpBB && "Basic block mapping should exist");
161
0
    auto tupleLoweredTy =
162
1.33k
      remapType(differentialInfo.getLinearMapTupleLoweredType(origBB));
163
1.33k
    auto bbDifferentialValues = differentialValues[origBB];
164
1.33k
    if (!origBB->isEntry()) {
165
0
      auto *enumArg = jvpBB->getArguments().back();
166
0
      bbDifferentialValues.insert(bbDifferentialValues.begin(), enumArg);
167
0
    }
168
1.33k
    return getBuilder().createTuple(loc, tupleLoweredTy,
169
1.33k
                                    bbDifferentialValues);
170
1.33k
  }
171
172
  //--------------------------------------------------------------------------//
173
  // Tangent value factory methods
174
  //--------------------------------------------------------------------------//
175
176
3.95k
  AdjointValue makeZeroTangentValue(SILType type) {
177
3.95k
    return AdjointValue::createZero(allocator,
178
3.95k
                                    remapSILTypeInDifferential(type));
179
3.95k
  }
180
181
3.64k
  AdjointValue makeConcreteTangentValue(SILValue value) {
182
3.64k
    return AdjointValue::createConcrete(allocator, value);
183
3.64k
  }
184
185
  //--------------------------------------------------------------------------//
186
  // Tangent materialization
187
  //--------------------------------------------------------------------------//
188
189
92
  void emitZeroIndirect(CanType type, SILValue buffer, SILLocation loc) {
190
92
    auto builder = getDifferentialBuilder();
191
92
    auto tangentSpace = getTangentSpace(type);
192
92
    assert(tangentSpace && "No tangent space for this type");
193
0
    switch (tangentSpace->getKind()) {
194
92
    case TangentSpace::Kind::TangentVector:
195
92
      builder.emitZeroIntoBuffer(loc, buffer, IsInitialization);
196
92
      return;
197
0
    case TangentSpace::Kind::Tuple: {
198
0
      auto tupleType = tangentSpace->getTuple();
199
0
      SmallVector<SILValue, 8> zeroElements;
200
0
      for (unsigned i : range(tupleType->getNumElements())) {
201
0
        auto eltAddr = builder.createTupleElementAddr(loc, buffer, i);
202
0
        emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(),
203
0
                         eltAddr, loc);
204
0
      }
205
0
      return;
206
0
    }
207
92
    }
208
92
  }
209
210
64
  SILValue emitZeroDirect(CanType type, SILLocation loc) {
211
64
    auto diffBuilder = getDifferentialBuilder();
212
64
    auto silType = getModule().Types.getLoweredLoadableType(
213
64
        type, TypeExpansionContext::minimal(), getModule());
214
64
    auto *buffer = diffBuilder.createAllocStack(loc, silType);
215
64
    emitZeroIndirect(type, buffer, loc);
216
64
    auto loaded = diffBuilder.emitLoadValueOperation(
217
64
        loc, buffer, LoadOwnershipQualifier::Take);
218
64
    diffBuilder.createDeallocStack(loc, buffer);
219
64
    return loaded;
220
64
  }
221
222
60
  SILValue materializeTangentDirect(AdjointValue val, SILLocation loc) {
223
60
    assert(val.getType().isObject());
224
60
    LLVM_DEBUG(getADDebugStream()
225
60
               << "Materializing tangents for " << val << '\n');
226
60
    switch (val.getKind()) {
227
60
    case AdjointValueKind::Zero: {
228
60
      auto zeroVal = emitZeroDirect(val.getSwiftType(), loc);
229
60
      return zeroVal;
230
0
    }
231
0
    case AdjointValueKind::Concrete:
232
0
      return val.getConcreteValue();
233
0
    case AdjointValueKind::Aggregate:
234
0
    case AdjointValueKind::AddElement:
235
0
      llvm_unreachable(
236
60
          "Tuples and structs are not supported in forward mode yet.");
237
60
    }
238
0
    llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
239
0
  }
240
241
3.92k
  SILValue materializeTangent(AdjointValue val, SILLocation loc) {
242
3.92k
    if (val.isConcrete()) {
243
3.86k
      LLVM_DEBUG(getADDebugStream()
244
3.86k
                 << "Materializing tangent: Value is concrete.\n");
245
3.86k
      return val.getConcreteValue();
246
3.86k
    }
247
60
    LLVM_DEBUG(getADDebugStream() << "Materializing tangent: Value is "
248
60
                                     "non-concrete. Materializing directly.\n");
249
60
    return materializeTangentDirect(val, loc);
250
3.92k
  }
251
252
  //--------------------------------------------------------------------------//
253
  // Tangent value mapping
254
  //--------------------------------------------------------------------------//
255
256
  /// Get the tangent for an original value. The given value must be in the
257
  /// original function.
258
  ///
259
  /// This method first tries to find an entry in `tangentValueMap`. If an entry
260
  /// doesn't exist, create a zero tangent.
261
3.95k
  AdjointValue getTangentValue(SILValue originalValue) {
262
3.95k
    assert(originalValue->getType().isObject());
263
0
    assert(originalValue->getFunction() == original);
264
0
    auto insertion = tangentValueMap.try_emplace(
265
3.95k
        originalValue,
266
3.95k
        makeZeroTangentValue(getRemappedTangentType(originalValue->getType())));
267
3.95k
    return insertion.first->getSecond();
268
3.95k
  }
269
270
  /// Map the tangent value to the given original value.
271
  void setTangentValue(SILBasicBlock *origBB, SILValue originalValue,
272
3.64k
                       AdjointValue newTangentValue) {
273
3.64k
#ifndef NDEBUG
274
3.64k
    if (auto *defInst = originalValue->getDefiningInstruction()) {
275
1.88k
      bool isTupleTypedApplyResult =
276
1.88k
          isa<ApplyInst>(defInst) && originalValue->getType().is<TupleType>();
277
1.88k
      assert(!isTupleTypedApplyResult &&
278
1.88k
             "Should not set tangent value for tuple-typed result from `apply` "
279
1.88k
             "instruction; use `destructure_tuple` on `apply` result and set "
280
1.88k
             "tangent value for `destructure_tuple` results instead.");
281
1.88k
    }
282
0
#endif
283
0
    assert(originalValue->getType().isObject());
284
0
    assert(newTangentValue.getType().isObject());
285
0
    assert(originalValue->getFunction() == original);
286
3.64k
    LLVM_DEBUG(getADDebugStream()
287
3.64k
               << "Setting tangent value for " << originalValue);
288
    // The tangent value must be in the tangent space.
289
3.64k
    assert(newTangentValue.getType() ==
290
3.64k
           getRemappedTangentType(originalValue->getType()));
291
0
    auto insertion =
292
3.64k
        tangentValueMap.try_emplace(originalValue, newTangentValue);
293
3.64k
    (void)insertion;
294
3.64k
    assert(insertion.second && "The tangent value should not already exist.");
295
3.64k
  }
296
297
  //--------------------------------------------------------------------------//
298
  // Tangent buffer mapping
299
  //--------------------------------------------------------------------------//
300
301
  /// Sets the tangent buffer for the original buffer. Asserts that the
302
  /// original buffer does not already have a tangent buffer.
303
  void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
304
2.59k
                        SILValue tangentBuffer) {
305
2.59k
    assert(originalBuffer->getType().isAddress());
306
0
    auto insertion =
307
2.59k
        bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer);
308
2.59k
    assert(insertion.second && "Tangent buffer already exists");
309
0
    (void)insertion;
310
2.59k
  }
311
312
  /// Returns the tangent buffer for the original buffer. Asserts that the
313
  /// original buffer has a tangent buffer.
314
6.06k
  SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) {
315
6.06k
    assert(originalBuffer->getType().isAddress());
316
0
    assert(originalBuffer->getFunction() == original);
317
0
    auto it = bufferMap.find({origBB, originalBuffer});
318
6.06k
    assert(it != bufferMap.end() && "Tangent buffer should already exist");
319
0
    return it->getSecond();
320
6.06k
  }
321
322
  //--------------------------------------------------------------------------//
323
  // Differential type calculations
324
  //--------------------------------------------------------------------------//
325
326
  /// Substitutes all replacement types of the given substitution map using the
327
  /// tangent function's substitution map.
328
0
  SubstitutionMap remapSubstitutionMapInDifferential(SubstitutionMap substMap) {
329
0
    return substMap.subst(getDifferential().getForwardingSubstitutionMap());
330
0
  }
331
332
  /// Remap any archetypes into the differential function's context.
333
0
  Type remapTypeInDifferential(Type ty) {
334
0
    if (ty->hasArchetype())
335
0
      return getDifferential().mapTypeIntoContext(ty->mapTypeOutOfContext());
336
0
    return getDifferential().mapTypeIntoContext(ty);
337
0
  }
338
339
  /// Remap any archetypes into the differential function's context.
340
16.3k
  SILType remapSILTypeInDifferential(SILType ty) {
341
16.3k
    if (ty.hasArchetype())
342
592
      return getDifferential().mapTypeIntoContext(ty.mapTypeOutOfContext());
343
15.7k
    return getDifferential().mapTypeIntoContext(ty);
344
16.3k
  }
345
346
  /// Find the tangent space of a given canonical type.
347
9.33k
  llvm::Optional<TangentSpace> getTangentSpace(CanType type) {
348
    // Use witness generic signature to remap types.
349
9.33k
    type = witness->getDerivativeGenericSignature().getReducedType(
350
9.33k
        type);
351
9.33k
    return type->getAutoDiffTangentSpace(
352
9.33k
        LookUpConformanceInModule(getModule().getSwiftModule()));
353
9.33k
  }
354
355
  /// Assuming the given type conforms to `Differentiable` after remapping,
356
  /// returns the associated tangent space SIL type.
357
9.07k
  SILType getRemappedTangentType(SILType type) {
358
9.07k
    return SILType::getPrimitiveType(
359
9.07k
        getTangentSpace(remapSILTypeInDifferential(type).getASTType())
360
9.07k
            ->getCanonicalType(),
361
9.07k
        type.getCategory());
362
9.07k
  }
363
364
  /// Set up the differential function. This includes:
365
  /// - Creating all differential blocks.
366
  /// - Creating differential entry block arguments based on the function type.
367
  /// - Creating tangent value mapping for original/differential parameters.
368
  /// - Checking for unvaried result and emitting related warnings.
369
  void prepareForDifferentialGeneration();
370
371
public:
372
  explicit Implementation(ADContext &context,
373
                          SILDifferentiabilityWitness *witness,
374
                          SILFunction *jvp, DifferentiationInvoker invoker);
375
376
  static SILFunction *
377
  createEmptyDifferential(ADContext &context,
378
                          SILDifferentiabilityWitness *witness,
379
                          LinearMapInfo *linearMapInfo);
380
381
  /// Run JVP generation. Returns true on error.
382
  bool run();
383
384
1.33k
  SILFunction &getJVP() const { return *jvp; }
385
386
14.2k
  void postProcess(SILInstruction *orig, SILInstruction *cloned) {
387
14.2k
    if (errorOccurred)
388
0
      return;
389
14.2k
    SILClonerWithScopes::postProcess(orig, cloned);
390
14.2k
  }
391
392
  /// Remap original basic blocks.
393
0
  SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) {
394
0
    auto *jvpBB = BBMap[bb];
395
0
    return jvpBB;
396
0
  }
397
398
  /// General visitor for all instructions. If any error is emitted by previous
399
  /// visits, bail out.
400
17.2k
  void visit(SILInstruction *inst) {
401
17.2k
    if (errorOccurred)
402
60
      return;
403
17.2k
    if (differentialInfo.shouldDifferentiateInstruction(inst)) {
404
7.58k
      LLVM_DEBUG(getADDebugStream() << "JVPCloner visited:\n[ORIG]" << *inst);
405
7.58k
#ifndef NDEBUG
406
7.58k
      auto diffBuilder = getDifferentialBuilder();
407
7.58k
      auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint());
408
7.58k
#endif
409
7.58k
      TypeSubstCloner::visit(inst);
410
7.58k
      LLVM_DEBUG({
411
7.58k
        auto &s = llvm::dbgs() << "[TAN] Emitted in differential:\n";
412
7.58k
        auto afterInsertion = diffBuilder.getInsertionPoint();
413
7.58k
        for (auto it = ++beforeInsertion; it != afterInsertion; ++it)
414
7.58k
          s << *it;
415
7.58k
      });
416
9.65k
    } else {
417
9.65k
      TypeSubstCloner::visit(inst);
418
9.65k
    }
419
17.2k
  }
420
421
0
  void visitSILInstruction(SILInstruction *inst) {
422
0
    context.emitNondifferentiabilityError(
423
0
        inst, invoker, diag::autodiff_expression_not_differentiable_note);
424
0
    errorOccurred = true;
425
0
  }
426
427
1.35k
  void visitInstructionsInBlock(SILBasicBlock *bb) {
428
    // Destructure the differential struct to get the elements.
429
1.35k
    auto &diffBuilder = getDifferentialBuilder();
430
1.35k
    auto diffLoc = getDifferential().getLocation();
431
1.35k
    auto *diffBB = diffBBMap.lookup(bb);
432
1.35k
    auto *mainDifferentialStruct = diffBB->getArguments().back();
433
1.35k
    diffBuilder.setInsertionPoint(diffBB);
434
1.35k
    auto *dsi =
435
1.35k
        diffBuilder.createDestructureTuple(diffLoc, mainDifferentialStruct);
436
1.35k
    initializeDifferentialTupleElements(bb, dsi->getResults());
437
1.35k
    TypeSubstCloner::visitInstructionsInBlock(bb);
438
1.35k
  }
439
440
  // If an `apply` has active results or active inout parameters, replace it
441
  // with an `apply` of its JVP.
442
2.06k
  void visitApplyInst(ApplyInst *ai) {
443
2.06k
    bool shouldDifferentiate =
444
2.06k
        differentialInfo.shouldDifferentiateApplySite(ai);
445
    // If the function has no active arguments or results, zero-initialize the
446
    // tangent buffers of the active indirect results.
447
2.06k
    if (!shouldDifferentiate) {
448
460
      for (auto indResult : ai->getIndirectSILResults())
449
44
        if (activityInfo.isActive(indResult, getConfig())) {
450
20
          auto &tanBuf = getTangentBuffer(ai->getParent(), indResult);
451
20
          emitZeroIndirect(tanBuf->getType().getASTType(), tanBuf,
452
20
                           tanBuf.getLoc());
453
20
        }
454
460
    }
455
    // If the function should not be differentiated or its the array literal
456
    // initialization intrinsic, just do standard cloning.
457
2.06k
    if (!shouldDifferentiate ||
458
2.06k
        ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) {
459
460
      LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
460
460
      TypeSubstCloner::visitApplyInst(ai);
461
460
      return;
462
460
    }
463
464
1.60k
    auto loc = ai->getLoc();
465
1.60k
    auto &builder = getBuilder();
466
1.60k
    auto origCallee = getOpValue(ai->getCallee());
467
1.60k
    auto originalFnTy = origCallee->getType().castTo<SILFunctionType>();
468
469
1.60k
    LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n');
470
471
    // Get the minimal parameter and result indices required for differentiating
472
    // this `apply`.
473
1.60k
    SmallVector<SILValue, 4> allResults;
474
1.60k
    SmallVector<unsigned, 8> activeParamIndices;
475
1.60k
    SmallVector<unsigned, 8> activeResultIndices;
476
1.60k
    collectMinimalIndicesForFunctionCall(ai, getConfig(), activityInfo,
477
1.60k
                                         allResults, activeParamIndices,
478
1.60k
                                         activeResultIndices);
479
1.60k
    assert(!activeParamIndices.empty() && "Parameter indices cannot be empty");
480
0
    assert(!activeResultIndices.empty() && "Result indices cannot be empty");
481
1.60k
    LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={";
482
1.60k
               llvm::interleave(
483
1.60k
                   activeParamIndices.begin(), activeParamIndices.end(),
484
1.60k
                   [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
485
1.60k
               s << "}, results={"; llvm::interleave(
486
1.60k
                   activeResultIndices.begin(), activeResultIndices.end(),
487
1.60k
                   [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
488
1.60k
               s << "}\n";);
489
490
    // Form expected indices.
491
1.60k
    auto numResults =
492
1.60k
        ai->getSubstCalleeType()->getNumResults() +
493
1.60k
        ai->getSubstCalleeType()->getNumIndirectMutatingParameters();
494
1.60k
    AutoDiffConfig config(
495
1.60k
        IndexSubset::get(getASTContext(),
496
1.60k
                         ai->getArgumentsWithoutIndirectResults().size(),
497
1.60k
                         activeParamIndices),
498
1.60k
        IndexSubset::get(getASTContext(), numResults, activeResultIndices));
499
500
    // Emit the JVP.
501
1.60k
    SILValue jvpValue;
502
    // If functionSource is a `@differentiable` function, just extract it.
503
1.60k
    if (originalFnTy->isDifferentiable()) {
504
24
      auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices();
505
36
      for (auto i : config.parameterIndices->getIndices()) {
506
36
        if (!paramIndices->contains(i)) {
507
0
          context.emitNondifferentiabilityError(
508
0
              origCallee, invoker,
509
0
              diag::
510
0
                  autodiff_function_noderivative_parameter_not_differentiable);
511
0
          errorOccurred = true;
512
0
          return;
513
0
        }
514
36
      }
515
24
      builder.emitScopedBorrowOperation(
516
24
          loc, origCallee, [&](SILValue borrowedDiffFunc) {
517
24
            jvpValue = builder.createDifferentiableFunctionExtract(
518
24
                loc, NormalDifferentiableFunctionTypeComponent::JVP,
519
24
                borrowedDiffFunc);
520
24
            jvpValue = builder.emitCopyValueOperation(loc, jvpValue);
521
24
          });
522
24
    }
523
524
    // If JVP has not yet been found, emit an `differentiable_function`
525
    // instruction on the remapped  function operand and
526
    // an `differentiable_function_extract` instruction to get the JVP.
527
    // The `differentiable_function` instruction will be canonicalized during
528
    // the transform main loop.
529
1.60k
    if (!jvpValue) {
530
      // FIXME: Handle indirect differentiation invokers. This may require some
531
      // redesign: currently, each original function + witness pair is mapped
532
      // only to one invoker.
533
      /*
534
       DifferentiationInvoker indirect(ai, attr);
535
       auto insertion =
536
           context.getInvokers().try_emplace({original, attr}, indirect);
537
       auto &invoker = insertion.first->getSecond();
538
       invoker = indirect;
539
       */
540
541
      // If the original `apply` instruction has a substitution map, then the
542
      // applied function is specialized.
543
      // In the JVP, specialization is also necessary for parity. The original
544
      // function operand is specialized with a remapped version of same
545
      // substitution map using an argument-less `partial_apply`.
546
1.58k
      if (ai->getSubstitutionMap().empty()) {
547
984
        origCallee = builder.emitCopyValueOperation(loc, origCallee);
548
984
      } else {
549
596
        auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap());
550
596
        auto jvpPartialApply = getBuilder().createPartialApply(
551
596
            ai->getLoc(), origCallee, substMap, {},
552
596
            ParameterConvention::Direct_Guaranteed);
553
596
        origCallee = jvpPartialApply;
554
596
      }
555
556
      // Check and diagnose non-differentiable original function type.
557
1.58k
      auto diagnoseNondifferentiableOriginalFunctionType =
558
1.58k
          [&](CanSILFunctionType origFnTy) {
559
            // Check and diagnose non-differentiable arguments.
560
2.52k
            for (auto paramIndex : config.parameterIndices->getIndices()) {
561
2.52k
              if (!originalFnTy->getParameters()[paramIndex]
562
2.52k
                       .getSILStorageInterfaceType()
563
2.52k
                       .isDifferentiable(getModule())) {
564
0
                auto arg = ai->getArgumentsWithoutIndirectResults()[paramIndex];
565
0
                auto startLoc = arg.getLoc().getStartSourceLoc();
566
0
                auto endLoc = arg.getLoc().getEndSourceLoc();
567
0
                context
568
0
                    .emitNondifferentiabilityError(
569
0
                        arg, invoker, diag::autodiff_nondifferentiable_argument)
570
0
                    .fixItInsert(startLoc, "withoutDerivative(at: ")
571
0
                    .fixItInsertAfter(endLoc, ")");
572
0
                errorOccurred = true;
573
0
                return true;
574
0
              }
575
2.52k
            }
576
            // Check and diagnose non-differentiable results.
577
1.59k
            for (auto resultIndex : config.resultIndices->getIndices()) {
578
1.59k
              SILType remappedResultType;
579
1.59k
              if (resultIndex >= originalFnTy->getNumResults()) {
580
92
                auto inoutArgIdx = resultIndex - originalFnTy->getNumResults();
581
92
                auto inoutArg =
582
92
                    *std::next(ai->getInoutArguments().begin(), inoutArgIdx);
583
92
                remappedResultType = inoutArg->getType();
584
1.50k
              } else {
585
1.50k
                remappedResultType = originalFnTy->getResults()[resultIndex]
586
1.50k
                                         .getSILStorageInterfaceType();
587
1.50k
              }
588
1.59k
              if (!remappedResultType.isDifferentiable(getModule())) {
589
0
                auto startLoc = ai->getLoc().getStartSourceLoc();
590
0
                auto endLoc = ai->getLoc().getEndSourceLoc();
591
0
                context
592
0
                    .emitNondifferentiabilityError(
593
0
                        origCallee, invoker,
594
0
                        diag::autodiff_nondifferentiable_result)
595
0
                    .fixItInsert(startLoc, "withoutDerivative(at: ")
596
0
                    .fixItInsertAfter(endLoc, ")");
597
0
                errorOccurred = true;
598
0
                return true;
599
0
              }
600
1.59k
            }
601
1.58k
            return false;
602
1.58k
          };
603
1.58k
      if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
604
0
        return;
605
606
1.58k
      auto *diffFuncInst = context.createDifferentiableFunction(
607
1.58k
          builder, loc, config.parameterIndices, config.resultIndices,
608
1.58k
          origCallee);
609
610
      // Record the `differentiable_function` instruction.
611
1.58k
      context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);
612
613
1.58k
      builder.emitScopedBorrowOperation(
614
1.58k
          loc, diffFuncInst, [&](SILValue borrowedADFunc) {
615
1.58k
            auto extractedJVP = builder.createDifferentiableFunctionExtract(
616
1.58k
                loc, NormalDifferentiableFunctionTypeComponent::JVP,
617
1.58k
                borrowedADFunc);
618
1.58k
            jvpValue = builder.emitCopyValueOperation(loc, extractedJVP);
619
1.58k
          });
620
1.58k
      builder.emitDestroyValueOperation(loc, diffFuncInst);
621
1.58k
    }
622
623
    // Call the JVP using the original parameters.
624
1.60k
    SmallVector<SILValue, 8> jvpArgs;
625
1.60k
    auto jvpFnTy = getOpType(jvpValue->getType()).castTo<SILFunctionType>();
626
1.60k
    auto numJVPArgs =
627
1.60k
        jvpFnTy->getNumParameters() + jvpFnTy->getNumIndirectFormalResults();
628
1.60k
    jvpArgs.reserve(numJVPArgs);
629
    // Collect substituted arguments.
630
1.60k
    for (auto origArg : ai->getArguments())
631
4.43k
      jvpArgs.push_back(getOpValue(origArg));
632
1.60k
    assert(jvpArgs.size() == numJVPArgs);
633
    // Apply the JVP.
634
    // The JVP should be specialized, so no substitution map is necessary.
635
0
    auto *jvpCall = getBuilder().createApply(loc, jvpValue, SubstitutionMap(),
636
1.60k
                                             jvpArgs, ai->getApplyOptions());
637
1.60k
    LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall);
638
639
    // Release the differentiable function.
640
1.60k
    builder.emitDestroyValueOperation(loc, jvpValue);
641
642
    // Get the JVP results (original results and differential).
643
1.60k
    SmallVector<SILValue, 8> jvpDirectResults;
644
1.60k
    extractAllElements(jvpCall, builder, jvpDirectResults);
645
1.60k
    auto originalDirectResults =
646
1.60k
        ArrayRef<SILValue>(jvpDirectResults).drop_back(1);
647
1.60k
    auto originalDirectResult =
648
1.60k
        joinElements(originalDirectResults, getBuilder(), jvpCall->getLoc());
649
650
1.60k
    mapValue(ai, originalDirectResult);
651
652
    // Some instructions that produce the callee may have been cloned.
653
    // If the original callee did not have any users beyond this `apply`,
654
    // recursively kill the cloned callee.
655
1.60k
    if (auto *origCallee = cast_or_null<SingleValueInstruction>(
656
1.60k
            ai->getCallee()->getDefiningInstruction()))
657
1.58k
      if (origCallee->hasOneUse())
658
1.58k
        recursivelyDeleteTriviallyDeadInstructions(
659
1.58k
            getOpValue(origCallee)->getDefiningInstruction());
660
661
    // Add the differential function for when we create the struct we partially
662
    // apply to the differential we are generating.
663
1.60k
    auto differential = jvpDirectResults.back();
664
1.60k
    auto differentialType = differentialInfo.lookUpLinearMapType(ai);
665
1.60k
    auto originalDifferentialType =
666
1.60k
        getOpType(differential->getType()).getAs<SILFunctionType>();
667
1.60k
    auto loweredDifferentialType =
668
1.60k
        getOpType(getLoweredType(differentialType)).castTo<SILFunctionType>();
669
    // If actual differential type does not match lowered differential type,
670
    // reabstract the differential using a thunk.
671
1.60k
    if (!loweredDifferentialType->isEqual(originalDifferentialType)) {
672
388
      SILOptFunctionBuilder fb(context.getTransform());
673
388
      differential = reabstractFunction(
674
388
          builder, fb, loc, differential, loweredDifferentialType,
675
388
          [this](SubstitutionMap subs) -> SubstitutionMap {
676
388
            return this->getOpSubstitutionMap(subs);
677
388
          });
678
388
    }
679
1.60k
    differentialValues[ai->getParent()].push_back(differential);
680
681
    // Differential emission.
682
1.60k
    emitTangentForApplyInst(ai, config, originalDifferentialType);
683
1.60k
  }
684
685
1.33k
  void visitReturnInst(ReturnInst *ri) {
686
1.33k
    auto loc = ri->getOperand().getLoc();
687
1.33k
    auto *origExit = ri->getParent();
688
1.33k
    auto &builder = getBuilder();
689
1.33k
    auto *diffStructVal = buildDifferentialValueStructValue(ri);
690
691
    // Get the JVP value corresponding to the original functions's return value.
692
1.33k
    auto *origRetInst = cast<ReturnInst>(origExit->getTerminator());
693
1.33k
    auto origResult = getOpValue(origRetInst->getOperand());
694
1.33k
    SmallVector<SILValue, 8> origResults;
695
1.33k
    extractAllElements(origResult, builder, origResults);
696
697
    // Get and partially apply the differential.
698
1.33k
    auto jvpSubstMap = jvp->getForwardingSubstitutionMap();
699
1.33k
    auto *differentialRef = builder.createFunctionRef(loc, &getDifferential());
700
1.33k
    auto *differentialPartialApply = builder.createPartialApply(
701
1.33k
        loc, differentialRef, jvpSubstMap, {diffStructVal},
702
1.33k
        ParameterConvention::Direct_Guaranteed);
703
704
1.33k
    auto differentialType = jvp->mapTypeIntoContext(
705
1.33k
        jvp->getConventions().getSILType(
706
1.33k
            jvp->getLoweredFunctionType()->getResults().back(),
707
1.33k
            jvp->getTypeExpansionContext()));
708
1.33k
    auto differentialFnType = differentialType.castTo<SILFunctionType>();
709
1.33k
    auto differentialSubstType =
710
1.33k
        differentialPartialApply->getType().castTo<SILFunctionType>();
711
712
    // If necessary, convert the differential value to the returned differential
713
    // function type.
714
1.33k
    SILValue differentialValue;
715
1.33k
    if (differentialSubstType == differentialFnType) {
716
1.20k
      differentialValue = differentialPartialApply;
717
1.20k
    } else if (differentialSubstType
718
132
                   ->isABICompatibleWith(differentialFnType, *jvp)
719
132
                   .isCompatible()) {
720
132
      differentialValue = builder.createConvertFunction(
721
132
          loc, differentialPartialApply, differentialType,
722
132
          /*withoutActuallyEscaping*/ false);
723
132
    } else {
724
0
      llvm::report_fatal_error("Differential value type is not ABI-compatible "
725
0
                               "with the returned differential type");
726
0
    }
727
728
    // Return a tuple of the original result and differential.
729
1.33k
    SmallVector<SILValue, 8> directResults;
730
1.33k
    directResults.append(origResults.begin(), origResults.end());
731
1.33k
    directResults.push_back(differentialValue);
732
1.33k
    builder.createReturn(ri->getLoc(),
733
1.33k
                         joinElements(directResults, builder, loc));
734
1.33k
  }
735
736
0
  void visitBranchInst(BranchInst *bi) {
737
0
    llvm_unreachable("Unsupported SIL instruction.");
738
0
  }
739
740
0
  void visitCondBranchInst(CondBranchInst *cbi) {
741
0
    llvm_unreachable("Unsupported SIL instruction.");
742
0
  }
743
744
0
  void visitSwitchEnumInst(SwitchEnumInst *sei) {
745
0
    llvm_unreachable("Unsupported SIL instruction.");
746
0
  }
747
748
56
  void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) {
749
    // Clone `differentiable_function` from original to JVP, then add the cloned
750
    // instruction to the `differentiable_function` worklist.
751
56
    TypeSubstCloner::visitDifferentiableFunctionInst(dfi);
752
56
    auto *newDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi));
753
56
    context.getDifferentiableFunctionInstWorklist().push_back(newDFI);
754
56
  }
755
756
0
  void visitLinearFunctionInst(LinearFunctionInst *lfi) {
757
    // Clone `linear_function` from original to JVP, then add the cloned
758
    // instruction to the `linear_function` worklist.
759
0
    TypeSubstCloner::visitLinearFunctionInst(lfi);
760
0
    auto *newLFI = cast<LinearFunctionInst>(getOpValue(lfi));
761
0
    context.getLinearFunctionInstWorklist().push_back(newLFI);
762
0
  }
763
764
  //--------------------------------------------------------------------------//
765
  // Tangent emission helpers
766
  //--------------------------------------------------------------------------//
767
768
#define CLONE_AND_EMIT_TANGENT(INST, ID)                                       \
769
5.21k
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
5.21k
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
5.21k
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
5.21k
      emitTangentFor##INST##Inst(inst);                                        \
773
5.21k
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation19visitAllocStackInstEPNS_14AllocStackInstE
Line
Count
Source
769
1.26k
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
1.26k
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
1.26k
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
1.26k
      emitTangentFor##INST##Inst(inst);                                        \
773
1.26k
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation18visitCopyValueInstEPNS_13CopyValueInstE
Line
Count
Source
769
112
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
112
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
112
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
112
      emitTangentFor##INST##Inst(inst);                                        \
773
112
  }                                                                            \
Unexecuted instantiation: _ZN5swift8autodiff9JVPCloner14Implementation19visitLoadBorrowInstEPNS_14LoadBorrowInstE
_ZN5swift8autodiff9JVPCloner14Implementation20visitBeginBorrowInstEPNS_15BeginBorrowInstE
Line
Count
Source
769
28
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
28
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
28
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
28
      emitTangentFor##INST##Inst(inst);                                        \
773
28
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation20visitBeginAccessInstEPNS_15BeginAccessInstE
Line
Count
Source
769
604
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
604
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
604
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
604
      emitTangentFor##INST##Inst(inst);                                        \
773
604
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation14visitTupleInstEPNS_9TupleInstE
Line
Count
Source
769
196
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
196
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
196
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
196
      emitTangentFor##INST##Inst(inst);                                        \
773
196
  }                                                                            \
Unexecuted instantiation: _ZN5swift8autodiff9JVPCloner14Implementation21visitTupleExtractInstEPNS_16TupleExtractInstE
_ZN5swift8autodiff9JVPCloner14Implementation25visitTupleElementAddrInstEPNS_20TupleElementAddrInstE
Line
Count
Source
769
328
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
328
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
328
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
328
      emitTangentFor##INST##Inst(inst);                                        \
773
328
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation15visitStructInstEPNS_10StructInstE
Line
Count
Source
769
24
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
24
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
24
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
24
      emitTangentFor##INST##Inst(inst);                                        \
773
24
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation22visitStructExtractInstEPNS_17StructExtractInstE
Line
Count
Source
769
220
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
220
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
220
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
220
      emitTangentFor##INST##Inst(inst);                                        \
773
220
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation26visitStructElementAddrInstEPNS_21StructElementAddrInstE
Line
Count
Source
769
184
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
184
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
184
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
184
      emitTangentFor##INST##Inst(inst);                                        \
773
184
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation21visitDeallocStackInstEPNS_16DeallocStackInstE
Line
Count
Source
769
1.24k
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
1.24k
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
1.24k
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
1.24k
      emitTangentFor##INST##Inst(inst);                                        \
773
1.24k
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation21visitDestroyValueInstEPNS_16DestroyValueInstE
Line
Count
Source
769
188
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
188
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
188
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
188
      emitTangentFor##INST##Inst(inst);                                        \
773
188
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation18visitEndBorrowInstEPNS_13EndBorrowInstE
Line
Count
Source
769
28
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
28
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
28
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
28
      emitTangentFor##INST##Inst(inst);                                        \
773
28
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation18visitEndAccessInstEPNS_13EndAccessInstE
Line
Count
Source
769
592
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
592
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
592
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
592
      emitTangentFor##INST##Inst(inst);                                        \
773
592
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation20visitDestroyAddrInstEPNS_15DestroyAddrInstE
Line
Count
Source
769
176
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
176
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
176
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
176
      emitTangentFor##INST##Inst(inst);                                        \
773
176
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation37visitUnconditionalCheckedCastAddrInstEPNS_32UnconditionalCheckedCastAddrInstE
Line
Count
Source
769
4
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
4
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
4
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
4
      emitTangentFor##INST##Inst(inst);                                        \
773
4
  }                                                                            \
_ZN5swift8autodiff9JVPCloner14Implementation25visitDestructureTupleInstEPNS_20DestructureTupleInstE
Line
Count
Source
769
24
  void visit##INST##Inst(INST##Inst *inst) {                                   \
770
24
    TypeSubstCloner::visit##INST##Inst(inst);                                  \
771
24
    if (differentialInfo.shouldDifferentiateInstruction(inst))                 \
772
24
      emitTangentFor##INST##Inst(inst);                                        \
773
24
  }                                                                            \
774
  void emitTangentFor##INST##Inst(INST##Inst *(ID))
775
776
0
  CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) {
777
0
    auto &diffBuilder = getDifferentialBuilder();
778
0
    auto loc = bbi->getLoc();
779
0
    auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc);
780
0
    auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal);
781
0
    setTangentValue(bbi->getParent(), bbi,
782
0
                    makeConcreteTangentValue(tanValBorrow));
783
0
  }
784
785
0
  CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) {
786
0
    auto &diffBuilder = getDifferentialBuilder();
787
0
    auto loc = ebi->getLoc();
788
0
    auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc);
789
0
    diffBuilder.emitEndBorrowOperation(loc, tanVal);
790
0
  }
791
792
16
  CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) {
793
16
    auto &diffBuilder = getDifferentialBuilder();
794
16
    auto loc = dvi->getLoc();
795
16
    auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc);
796
16
    diffBuilder.emitDestroyValueOperation(loc, tanVal);
797
16
  }
798
799
12
  CLONE_AND_EMIT_TANGENT(CopyValue, cvi) {
800
12
    auto &diffBuilder = getDifferentialBuilder();
801
12
    auto tan = getTangentValue(cvi->getOperand());
802
12
    auto tanVal = materializeTangent(tan, cvi->getLoc());
803
12
    auto tanValCopy = diffBuilder.emitCopyValueOperation(cvi->getLoc(), tanVal);
804
12
    setTangentValue(cvi->getParent(), cvi,
805
12
                    makeConcreteTangentValue(tanValCopy));
806
12
  }
807
808
  /// Handle `load` instruction.
809
  ///   Original: y = load x
810
  ///    Tangent: tan[y] = load tan[x]
811
564
  void visitLoadInst(LoadInst *li) {
812
564
    TypeSubstCloner::visitLoadInst(li);
813
    // If an active buffer is loaded with take to a non-active value, destroy
814
    // the active buffer's tangent buffer.
815
564
    if (!differentialInfo.shouldDifferentiateInstruction(li)) {
816
12
      auto isTake =
817
12
          (li->getOwnershipQualifier() == LoadOwnershipQualifier::Take);
818
12
      if (isTake && activityInfo.isActive(li->getOperand(), getConfig())) {
819
0
        auto &tanBuf = getTangentBuffer(li->getParent(), li->getOperand());
820
0
        getDifferentialBuilder().emitDestroyOperation(tanBuf.getLoc(), tanBuf);
821
0
      }
822
12
      return;
823
12
    }
824
    // Otherwise, do standard differential cloning.
825
552
    auto &diffBuilder = getDifferentialBuilder();
826
552
    auto *bb = li->getParent();
827
552
    auto loc = li->getLoc();
828
552
    auto tanBuf = getTangentBuffer(bb, li->getOperand());
829
552
    auto tanVal = diffBuilder.emitLoadValueOperation(
830
552
        loc, tanBuf, li->getOwnershipQualifier());
831
552
    setTangentValue(bb, li, makeConcreteTangentValue(tanVal));
832
552
  }
833
834
  /// Handle `load_borrow` instruction.
835
  ///   Original: y = load_borrow x
836
  ///    Tangent: tan[y] = load_borrow tan[x]
837
0
  CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi) {
838
0
    auto &diffBuilder = getDifferentialBuilder();
839
0
    auto *bb = lbi->getParent();
840
0
    auto loc = lbi->getLoc();
841
0
    auto tanBuf = getTangentBuffer(bb, lbi->getOperand());
842
0
    auto tanVal = diffBuilder.emitLoadBorrowOperation(loc, tanBuf);
843
0
    setTangentValue(bb, lbi, makeConcreteTangentValue(tanVal));
844
0
  }
845
846
  /// Handle `store` instruction in the differential.
847
  ///   Original: store x to y
848
  ///     Tangent: store tan[x] to tan[y]
849
884
  void visitStoreInst(StoreInst *si) {
850
884
    TypeSubstCloner::visitStoreInst(si);
851
    // If a non-active value is stored into an active buffer, zero-initialize
852
    // the active buffer's tangent buffer.
853
884
    if (!differentialInfo.shouldDifferentiateInstruction(si)) {
854
116
      if (activityInfo.isActive(si->getDest(), getConfig())) {
855
0
        auto &tanBufDest = getTangentBuffer(si->getParent(), si->getDest());
856
0
        emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
857
0
                         tanBufDest.getLoc());
858
0
      }
859
116
      return;
860
116
    }
861
    // Otherwise, do standard differential cloning.
862
768
    auto &diffBuilder = getDifferentialBuilder();
863
768
    auto loc = si->getLoc();
864
768
    auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc);
865
768
    auto &tanValDest = getTangentBuffer(si->getParent(), si->getDest());
866
768
    diffBuilder.emitStoreValueOperation(loc, tanValSrc, tanValDest,
867
768
                                        si->getOwnershipQualifier());
868
768
  }
869
870
  /// Handle `store_borrow` instruction in the differential.
871
  ///   Original: store_borrow x to y
872
  ///    Tangent: store_borrow tan[x] to tan[y]
873
0
  void visitStoreBorrowInst(StoreBorrowInst *sbi) {
874
0
    TypeSubstCloner::visitStoreBorrowInst(sbi);
875
    // If a non-active value is stored into an active buffer, zero-initialize
876
    // the active buffer's tangent buffer.
877
0
    if (!differentialInfo.shouldDifferentiateInstruction(sbi)) {
878
0
      if (activityInfo.isActive(sbi->getDest(), getConfig())) {
879
0
        auto &tanBufDest = getTangentBuffer(sbi->getParent(), sbi->getDest());
880
0
        emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
881
0
                         tanBufDest.getLoc());
882
0
      }
883
0
      return;
884
0
    }
885
    // Otherwise, do standard differential cloning.
886
0
    auto &diffBuilder = getDifferentialBuilder();
887
0
    auto loc = sbi->getLoc();
888
0
    auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc);
889
0
    auto &tanValDest = getTangentBuffer(sbi->getParent(), sbi->getDest());
890
0
    diffBuilder.createStoreBorrow(loc, tanValSrc, tanValDest);
891
0
  }
892
893
  /// Handle `copy_addr` instruction.
894
  ///   Original: copy_addr x to y
895
  ///    Tangent: copy_addr tan[x] to tan[y]
896
248
  void visitCopyAddrInst(CopyAddrInst *cai) {
897
248
    TypeSubstCloner::visitCopyAddrInst(cai);
898
    // If a non-active buffer is copied into an active buffer, zero-initialize
899
    // the destination buffer's tangent buffer.
900
    // If an active buffer is copied with take into a non-active buffer, destroy
901
    // the source buffer's tangent buffer.
902
248
    if (!differentialInfo.shouldDifferentiateInstruction(cai)) {
903
4
      if (activityInfo.isActive(cai->getDest(), getConfig())) {
904
0
        auto &tanBufDest = getTangentBuffer(cai->getParent(), cai->getDest());
905
0
        emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
906
0
                         tanBufDest.getLoc());
907
0
      }
908
4
      if (cai->isTakeOfSrc() &&
909
4
          activityInfo.isActive(cai->getSrc(), getConfig())) {
910
0
        auto &tanBufSrc = getTangentBuffer(cai->getParent(), cai->getSrc());
911
0
        getDifferentialBuilder().emitDestroyOperation(tanBufSrc.getLoc(),
912
0
                                                      tanBufSrc);
913
0
      }
914
4
      return;
915
4
    }
916
    // Otherwise, do standard differential cloning.
917
244
    auto diffBuilder = getDifferentialBuilder();
918
244
    auto loc = cai->getLoc();
919
244
    auto *bb = cai->getParent();
920
244
    auto &tanSrc = getTangentBuffer(bb, cai->getSrc());
921
244
    auto tanDest = getTangentBuffer(bb, cai->getDest());
922
244
    diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(),
923
244
                               cai->isInitializationOfDest());
924
244
  }
925
926
  /// Handle `unconditional_checked_cast_addr` instruction.
927
  ///   Original: unconditional_checked_cast_addr $X in x to $Y in y
928
  ///    Tangent: unconditional_checked_cast_addr $X.Tan in tan[x]
929
  ///                                          to $Y.Tan in tan[y]
930
4
  CLONE_AND_EMIT_TANGENT(UnconditionalCheckedCastAddr, uccai) {
931
4
    auto diffBuilder = getDifferentialBuilder();
932
4
    auto loc = uccai->getLoc();
933
4
    auto *bb = uccai->getParent();
934
4
    auto &tanSrc = getTangentBuffer(bb, uccai->getSrc());
935
4
    auto tanDest = getTangentBuffer(bb, uccai->getDest());
936
937
4
    diffBuilder.createUnconditionalCheckedCastAddr(
938
4
        loc, tanSrc, tanSrc->getType().getASTType(), tanDest,
939
4
        tanDest->getType().getASTType());
940
4
  }
941
942
  /// Handle `begin_access` instruction (and do differentiability checks).
943
  ///   Original: y = begin_access x
944
  ///    Tangent: tan[y] = begin_access tan[x]
945
588
  CLONE_AND_EMIT_TANGENT(BeginAccess, bai) {
946
    // Check for non-differentiable writes.
947
588
    if (bai->getAccessKind() == SILAccessKind::Modify) {
948
260
      if (auto *gai = dyn_cast<GlobalAddrInst>(bai->getSource())) {
949
0
        context.emitNondifferentiabilityError(
950
0
            bai, invoker,
951
0
            diag::autodiff_cannot_differentiate_writes_to_global_variables);
952
0
        errorOccurred = true;
953
0
        return;
954
0
      }
955
260
      if (auto *pbi = dyn_cast<ProjectBoxInst>(bai->getSource())) {
956
0
        context.emitNondifferentiabilityError(
957
0
            bai, invoker,
958
0
            diag::autodiff_cannot_differentiate_writes_to_mutable_captures);
959
0
        errorOccurred = true;
960
0
        return;
961
0
      }
962
260
    }
963
964
588
    auto &diffBuilder = getDifferentialBuilder();
965
588
    auto *bb = bai->getParent();
966
967
588
    auto tanSrc = getTangentBuffer(bb, bai->getSource());
968
588
    auto *tanDest = diffBuilder.createBeginAccess(
969
588
        bai->getLoc(), tanSrc, bai->getAccessKind(), bai->getEnforcement(),
970
588
        bai->hasNoNestedConflict(), bai->isFromBuiltin());
971
588
    setTangentBuffer(bb, bai, tanDest);
972
588
  }
973
974
  /// Handle `end_access` instruction.
975
  ///   Original: begin_access x
976
  ///    Tangent: end_access tan[x]
977
576
  CLONE_AND_EMIT_TANGENT(EndAccess, eai) {
978
576
    auto &diffBuilder = getDifferentialBuilder();
979
576
    auto *bb = eai->getParent();
980
576
    auto loc = eai->getLoc();
981
576
    auto tanOperand = getTangentBuffer(bb, eai->getOperand());
982
576
    diffBuilder.createEndAccess(loc, tanOperand, eai->isAborting());
983
576
  }
984
985
  /// Handle `alloc_stack` instruction.
986
  ///   Original: y = alloc_stack $T
987
  ///    Tangent: tan[y] = alloc_stack $T.Tangent
988
1.18k
  CLONE_AND_EMIT_TANGENT(AllocStack, asi) {
989
1.18k
    auto &diffBuilder = getDifferentialBuilder();
990
1.18k
    auto *mappedAllocStackInst = diffBuilder.createAllocStack(
991
1.18k
        asi->getLoc(), getRemappedTangentType(asi->getElementType()),
992
1.18k
        asi->getVarInfo());
993
1.18k
    setTangentBuffer(asi->getParent(), asi, mappedAllocStackInst);
994
1.18k
  }
995
996
  /// Handle `dealloc_stack` instruction.
997
  ///   Original: dealloc_stack x
998
  ///    Tangent: dealloc_stack tan[x]
999
1.16k
  CLONE_AND_EMIT_TANGENT(DeallocStack, dsi) {
1000
1.16k
    auto &diffBuilder = getDifferentialBuilder();
1001
1.16k
    auto tanBuf = getTangentBuffer(dsi->getParent(), dsi->getOperand());
1002
1.16k
    diffBuilder.createDeallocStack(dsi->getLoc(), tanBuf);
1003
1.16k
  }
1004
1005
  /// Handle `destroy_addr` instruction.
1006
  ///   Original: destroy_addr x
1007
  ///    Tangent: destroy_addr tan[x]
1008
164
  CLONE_AND_EMIT_TANGENT(DestroyAddr, dai) {
1009
164
    auto &diffBuilder = getDifferentialBuilder();
1010
164
    auto tanBuf = getTangentBuffer(dai->getParent(), dai->getOperand());
1011
164
    diffBuilder.createDestroyAddr(dai->getLoc(), tanBuf);
1012
164
  }
1013
1014
  /// Handle `struct` instruction.
1015
  ///   Original: y = struct $T (x0, x1, x2, ...)
1016
  ///    Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...)
1017
24
  CLONE_AND_EMIT_TANGENT(Struct, si) {
1018
24
    auto &diffBuilder = getDifferentialBuilder();
1019
24
    SmallVector<SILValue, 4> tangentElements;
1020
24
    for (auto elem : si->getElements())
1021
32
      tangentElements.push_back(getTangentValue(elem).getConcreteValue());
1022
24
    auto tanExtract = diffBuilder.createStruct(
1023
24
        si->getLoc(), getRemappedTangentType(si->getType()), tangentElements);
1024
24
    setTangentValue(si->getParent(), si, makeConcreteTangentValue(tanExtract));
1025
24
  }
1026
1027
  /// Handle `struct_extract` instruction.
1028
  ///   Original: y = struct_extract x, #field
1029
  ///    Tangent: tan[y] = struct_extract tan[x], #field'
1030
  ///                                             ^~~~~~~
1031
  ///                          field in tangent space corresponding to #field
1032
204
  CLONE_AND_EMIT_TANGENT(StructExtract, sei) {
1033
204
    assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
1034
204
           "`struct_extract` with `@noDerivative` field should not be "
1035
204
           "differentiated; activity analysis should not marked as varied.");
1036
0
    auto diffBuilder = getDifferentialBuilder();
1037
204
    auto loc = getValidLocation(sei);
1038
    // Find the corresponding field in the tangent space.
1039
204
    auto structType =
1040
204
        remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType();
1041
204
    auto *tanField =
1042
204
      getTangentStoredProperty(context, sei, structType, invoker);
1043
204
    if (!tanField) {
1044
8
      errorOccurred = true;
1045
8
      return;
1046
8
    }
1047
    // Emit tangent `struct_extract`.
1048
196
    auto tanStruct =
1049
196
        materializeTangent(getTangentValue(sei->getOperand()), loc);
1050
196
    auto tangentInst =
1051
196
        diffBuilder.createStructExtract(loc, tanStruct, tanField);
1052
    // Update tangent value mapping for `struct_extract` result.
1053
196
    auto tangentResult = makeConcreteTangentValue(tangentInst);
1054
196
    setTangentValue(sei->getParent(), sei, tangentResult);
1055
196
  }
1056
1057
  /// Handle `struct_element_addr` instruction.
1058
  ///   Original: y = struct_element_addr x, #field
1059
  ///    Tangent: tan[y] = struct_element_addr tan[x], #field'
1060
  ///                                                  ^~~~~~~
1061
  ///                          field in tangent space corresponding to #field
1062
180
  CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) {
1063
180
    assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
1064
180
           "`struct_element_addr` with `@noDerivative` field should not be "
1065
180
           "differentiated; activity analysis should not marked as varied.");
1066
0
    auto diffBuilder = getDifferentialBuilder();
1067
180
    auto *bb = seai->getParent();
1068
180
    auto loc = getValidLocation(seai);
1069
    // Find the corresponding field in the tangent space.
1070
180
    auto structType =
1071
180
        remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType();
1072
180
    auto *tanField =
1073
180
      getTangentStoredProperty(context, seai, structType, invoker);
1074
180
    if (!tanField) {
1075
12
      errorOccurred = true;
1076
12
      return;
1077
12
    }
1078
    // Emit tangent `struct_element_addr`.
1079
168
    auto tanOperand = getTangentBuffer(bb, seai->getOperand());
1080
168
    auto tangentInst =
1081
168
        diffBuilder.createStructElementAddr(loc, tanOperand, tanField);
1082
    // Update tangent buffer map for `struct_element_addr`.
1083
168
    setTangentBuffer(bb, seai, tangentInst);
1084
168
  }
1085
1086
  /// Handle `tuple` instruction.
1087
  ///   Original: y = tuple (x0, x1, x2, ...)
1088
  ///    Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...)
1089
  ///                                                        ^~~
1090
  ///                                      excluding non-differentiable elements
1091
8
  CLONE_AND_EMIT_TANGENT(Tuple, ti) {
1092
8
    auto diffBuilder = getDifferentialBuilder();
1093
    // Get the tangents of all the tuple elements.
1094
8
    SmallVector<SILValue, 8> tangentTupleElements;
1095
16
    for (auto elem : ti->getElements()) {
1096
16
      if (!getTangentSpace(elem->getType().getASTType()))
1097
0
        continue;
1098
16
      tangentTupleElements.push_back(
1099
16
          materializeTangent(getTangentValue(elem), ti->getLoc()));
1100
16
    }
1101
    // Emit the instruction and add the tangent mapping.
1102
8
    auto tanTuple =
1103
8
        joinElements(tangentTupleElements, diffBuilder, ti->getLoc());
1104
8
    setTangentValue(ti->getParent(), ti, makeConcreteTangentValue(tanTuple));
1105
8
  }
1106
1107
  /// Handle `tuple_extract` instruction.
1108
  ///   Original: y = tuple_extract x, <n>
1109
  ///    Tangent: tan[y] = tuple_extract tan[x], <n'>
1110
  ///                                            ^~~~
1111
  ///                         tuple tangent space index corresponding to n
1112
0
  CLONE_AND_EMIT_TANGENT(TupleExtract, tei) {
1113
0
    auto &diffBuilder = getDifferentialBuilder();
1114
0
    auto loc = tei->getLoc();
1115
0
    auto origTupleTy = tei->getOperand()->getType().castTo<TupleType>();
1116
0
    unsigned tanIndex = 0;
1117
0
    for (unsigned i : range(tei->getFieldIndex())) {
1118
0
      if (getTangentSpace(
1119
0
              origTupleTy->getElement(i).getType()->getCanonicalType()))
1120
0
        ++tanIndex;
1121
0
    }
1122
0
    auto tanType = getRemappedTangentType(tei->getType());
1123
0
    auto tanSource =
1124
0
        materializeTangent(getTangentValue(tei->getOperand()), loc);
1125
    // If the tangent value of the source does not have a tuple type, then
1126
    // it must represent a "single element tuple type". Use it directly.
1127
0
    if (!tanSource->getType().is<TupleType>()) {
1128
0
      setTangentValue(tei->getParent(), tei,
1129
0
                      makeConcreteTangentValue(tanSource));
1130
0
    } else {
1131
0
      auto tanElt =
1132
0
          diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType);
1133
0
      setTangentValue(tei->getParent(), tei, makeConcreteTangentValue(tanElt));
1134
0
    }
1135
0
  }
1136
1137
  /// Handle `tuple_element_addr` instruction.
1138
  ///   Original: y = tuple_element_addr x, <n>
1139
  ///    Tangent: tan[y] = tuple_element_addr tan[x], <n'>
1140
  ///                                                ^~~~
1141
  ///                            tuple tangent space index corresponding to n
1142
272
  CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai) {
1143
272
    auto &diffBuilder = getDifferentialBuilder();
1144
272
    auto origTupleTy = teai->getOperand()->getType().castTo<TupleType>();
1145
272
    unsigned tanIndex = 0;
1146
272
    for (unsigned i : range(teai->getFieldIndex())) {
1147
128
      if (getTangentSpace(
1148
128
              origTupleTy->getElement(i).getType()->getCanonicalType()))
1149
64
        ++tanIndex;
1150
128
    }
1151
272
    auto tanType = getRemappedTangentType(teai->getType());
1152
272
    auto tanSource = getTangentBuffer(teai->getParent(), teai->getOperand());
1153
272
    SILValue tanBuf;
1154
    // If the tangent buffer of the source does not have a tuple type, then
1155
    // it must represent a "single element tuple type". Use it directly.
1156
272
    if (!tanSource->getType().is<TupleType>()) {
1157
52
      tanBuf = tanSource;
1158
220
    } else {
1159
220
      tanBuf = diffBuilder.createTupleElementAddr(teai->getLoc(), tanSource,
1160
220
                                                  tanIndex, tanType);
1161
220
    }
1162
272
    setTangentBuffer(teai->getParent(), teai, tanBuf);
1163
272
  }
1164
1165
  /// Handle `destructure_tuple` instruction.
1166
  ///   Original: (y0, y1, ...)  = destructure_tuple x, <n>
1167
  ///    Tangent: (tan[y0], tan[y1], ...) = destructure_tuple tan[x], <n'>
1168
  ///                                                                 ^~~~
1169
  ///                              tuple tangent space index corresponding to n
1170
12
  CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) {
1171
12
    assert(llvm::any_of(dti->getResults(),
1172
12
                        [&](SILValue elt) {
1173
12
                          return activityInfo.isActive(elt, getConfig());
1174
12
                        }) &&
1175
12
           "`destructure_tuple` should have at least one active result");
1176
1177
0
    auto &diffBuilder = getDifferentialBuilder();
1178
12
    auto *bb = dti->getParent();
1179
12
    auto loc = dti->getLoc();
1180
1181
12
    auto tanTuple = materializeTangent(getTangentValue(dti->getOperand()), loc);
1182
12
    SmallVector<SILValue, 4> tanElts;
1183
12
    if (tanTuple->getType().is<TupleType>()) {
1184
12
      auto *tanDti = diffBuilder.createDestructureTuple(loc, tanTuple);
1185
12
      tanElts.append(tanDti->getResults().begin(), tanDti->getResults().end());
1186
12
    } else {
1187
0
      tanElts.push_back(tanTuple);
1188
0
    }
1189
12
    unsigned tanIdx = 0;
1190
24
    for (auto i : range(dti->getNumResults())) {
1191
24
      auto origElt = dti->getResult(i);
1192
24
      if (!getTangentSpace(origElt->getType().getASTType()))
1193
0
        continue;
1194
24
      setTangentValue(bb, origElt, makeConcreteTangentValue(tanElts[tanIdx++]));
1195
24
    }
1196
12
  }
1197
1198
#undef CLONE_AND_EMIT_TANGENT
1199
1200
  /// Handle `apply` instruction, given:
1201
  /// - The minimal indices for differentiating the `apply`.
1202
  /// - The original non-reabstracted differential type.
1203
  ///
1204
  ///   Original: y = apply f(x0, x1, ...)
1205
  ///    Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...)
1206
  void emitTangentForApplyInst(ApplyInst *ai, const AutoDiffConfig &applyConfig,
1207
1.60k
                               CanSILFunctionType originalDifferentialType) {
1208
1.60k
    assert(differentialInfo.shouldDifferentiateApplySite(ai));
1209
0
    auto *bb = ai->getParent();
1210
1.60k
    auto loc = ai->getLoc();
1211
1.60k
    auto &diffBuilder = getDifferentialBuilder();
1212
1213
    // Get the differential value.
1214
1.60k
    SILValue differential = getDifferentialTupleElement(ai);
1215
1.60k
    auto differentialType = remapSILTypeInDifferential(differential->getType())
1216
1.60k
                                .castTo<SILFunctionType>();
1217
1218
    // Get the differential arguments.
1219
1.60k
    SmallVector<SILValue, 8> diffArgs;
1220
1221
1.60k
    for (auto indRes : ai->getIndirectSILResults())
1222
452
      diffArgs.push_back(getTangentBuffer(bb, indRes));
1223
1224
1.60k
    auto origArgs = ai->getArgumentsWithoutIndirectResults();
1225
    // Get the tangent value of the original arguments.
1226
3.98k
    for (auto i : indices(origArgs)) {
1227
3.98k
      auto origArg = origArgs[i];
1228
      // If the argument is not active:
1229
      // - Skip the element, if it is not differentiable.
1230
      // - Otherwise, add a zero value to that location.
1231
3.98k
      if (!activityInfo.isActive(origArg, getConfig())) {
1232
1.42k
        auto origCalleeType = ai->getSubstCalleeType();
1233
1.42k
        if (!origCalleeType->isDifferentiable())
1234
1.41k
          continue;
1235
8
        auto actualOrigCalleeIndices =
1236
8
            origCalleeType->getDifferentiabilityParameterIndices();
1237
8
        if (actualOrigCalleeIndices->contains(i)) {
1238
4
          SILValue tanParam;
1239
4
          if (origArg->getType().isObject()) {
1240
4
            tanParam = emitZeroDirect(
1241
4
                getRemappedTangentType(origArg->getType()).getASTType(), loc);
1242
4
            diffArgs.push_back(tanParam);
1243
4
          } else {
1244
0
            tanParam = diffBuilder.createAllocStack(
1245
0
                loc, getRemappedTangentType(origArg->getType()));
1246
0
            emitZeroIndirect(
1247
0
                getRemappedTangentType(origArg->getType()).getASTType(),
1248
0
                tanParam, loc);
1249
0
          }
1250
4
        }
1251
8
      }
1252
      // Otherwise, if the argument is active, handle the argument normally by
1253
      // getting its tangent value.
1254
2.56k
      else {
1255
2.56k
        SILValue tanParam;
1256
2.56k
        if (origArg->getType().isObject()) {
1257
1.71k
          tanParam = materializeTangent(getTangentValue(origArg), loc);
1258
1.71k
        } else {
1259
844
          tanParam = getTangentBuffer(ai->getParent(), origArg);
1260
844
        }
1261
2.56k
        diffArgs.push_back(tanParam);
1262
2.56k
        if (errorOccurred)
1263
0
          return;
1264
2.56k
      }
1265
3.98k
    }
1266
1267
    // If callee differential was reabstracted in JVP, reabstract the callee
1268
    // differential.
1269
1.60k
    if (!differentialType->isEqual(originalDifferentialType)) {
1270
388
      SILOptFunctionBuilder fb(context.getTransform());
1271
388
      differential = reabstractFunction(
1272
388
          diffBuilder, fb, loc, differential, originalDifferentialType,
1273
388
          [this](SubstitutionMap subs) -> SubstitutionMap {
1274
388
            return this->getOpSubstitutionMap(subs);
1275
388
          });
1276
388
    }
1277
1278
    // Call the differential.
1279
1.60k
    auto *differentialCall =
1280
1.60k
        diffBuilder.createApply(loc, differential, SubstitutionMap(), diffArgs);
1281
1.60k
    diffBuilder.emitDestroyValueOperation(loc, differential);
1282
1283
    // Get the original `apply` results.
1284
1.60k
    SmallVector<SILValue, 8> origDirectResults;
1285
1.60k
    forEachApplyDirectResult(ai, [&](SILValue directResult) {
1286
1.08k
      origDirectResults.push_back(directResult);
1287
1.08k
    });
1288
1.60k
    SmallVector<SILValue, 8> origAllResults;
1289
1.60k
    collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults);
1290
1291
    // Get the callee differential `apply` results.
1292
1.60k
    SmallVector<SILValue, 8> differentialDirectResults;
1293
1.60k
    extractAllElements(differentialCall, getDifferentialBuilder(),
1294
1.60k
                       differentialDirectResults);
1295
1.60k
    SmallVector<SILValue, 8> differentialAllResults;
1296
1.60k
    collectAllActualResultsInTypeOrder(
1297
1.60k
        differentialCall, differentialDirectResults, differentialAllResults);
1298
1.60k
    for (auto inoutArg : ai->getInoutArguments())
1299
92
      origAllResults.push_back(inoutArg);
1300
1.60k
    for (auto inoutArg : differentialCall->getInoutArguments())
1301
92
      differentialAllResults.push_back(inoutArg);
1302
1.60k
    assert(applyConfig.resultIndices->getNumIndices() ==
1303
1.60k
           differentialAllResults.size());
1304
1305
    // Set tangent values for original `apply` results.
1306
0
    unsigned differentialResultIndex = 0;
1307
1.61k
    for (auto resultIndex : applyConfig.resultIndices->getIndices()) {
1308
1.61k
      auto origResult = origAllResults[resultIndex];
1309
1.61k
      auto differentialResult =
1310
1.61k
          differentialAllResults[differentialResultIndex++];
1311
1.61k
      if (origResult->getType().isObject()) {
1312
1.07k
        if (!origResult->getType().is<TupleType>()) {
1313
1.07k
          setTangentValue(bb, origResult,
1314
1.07k
                          makeConcreteTangentValue(differentialResult));
1315
1.07k
        } else if (auto *dti = getSingleDestructureTupleUser(ai)) {
1316
0
          bool notSetValue = true;
1317
0
          for (auto result : dti->getResults()) {
1318
0
            if (activityInfo.isActive(result, getConfig())) {
1319
0
              assert(notSetValue &&
1320
0
                     "This was incorrectly set, should only have one active "
1321
0
                     "result from the tuple.");
1322
0
              notSetValue = false;
1323
0
              setTangentValue(bb, result,
1324
0
                              makeConcreteTangentValue(differentialResult));
1325
0
            }
1326
0
          }
1327
0
        }
1328
1.07k
      }
1329
1.61k
    }
1330
1.60k
  }
1331
1332
  /// Generate a `return` instruction in the current differential basic block.
1333
1.35k
  void emitReturnInstForDifferential() {
1334
1.35k
    auto &differential = getDifferential();
1335
1.35k
    auto diffLoc = differential.getLocation();
1336
1.35k
    auto &diffBuilder = getDifferentialBuilder();
1337
1338
    // Collect original results.
1339
1.35k
    SmallVector<SILValue, 2> originalResults;
1340
1.35k
    collectAllDirectResultsInTypeOrder(*original, originalResults);
1341
    // Collect differential direct results.
1342
1.35k
    SmallVector<SILValue, 8> retElts;
1343
1.35k
    for (auto i : range(originalResults.size())) {
1344
1.19k
      auto origResult = originalResults[i];
1345
1.19k
      if (!getConfig().resultIndices->contains(i))
1346
8
        continue;
1347
1.18k
      auto tanVal = materializeTangent(getTangentValue(origResult), diffLoc);
1348
1.18k
      retElts.push_back(tanVal);
1349
1.18k
    }
1350
1351
1.35k
    diffBuilder.createReturn(diffLoc,
1352
1.35k
                             joinElements(retElts, diffBuilder, diffLoc));
1353
1.35k
  }
1354
};
1355
1356
//--------------------------------------------------------------------------//
1357
// Initialization
1358
//--------------------------------------------------------------------------//
1359
1360
/// Initialization helper function.
1361
///
1362
/// Returns the substitution map used for type remapping.
1363
static SubstitutionMap getSubstitutionMap(SILFunction *original,
1364
1.35k
                                          SILFunction *jvp) {
1365
1.35k
  auto substMap = original->getForwardingSubstitutionMap();
1366
1.35k
  if (auto *jvpGenEnv = jvp->getGenericEnvironment()) {
1367
160
    auto jvpSubstMap = jvpGenEnv->getForwardingSubstitutionMap();
1368
160
    substMap = SubstitutionMap::get(
1369
160
        jvpGenEnv->getGenericSignature(), QuerySubstitutionMap{jvpSubstMap},
1370
160
        LookUpConformanceInSubstitutionMap(jvpSubstMap));
1371
160
  }
1372
1.35k
  return substMap;
1373
1.35k
}
1374
1375
/// Initialization helper function.
1376
///
1377
/// Returns the activity info for the given original function, autodiff indices,
1378
/// and JVP generic signature.
1379
static const DifferentiableActivityInfo &
1380
getActivityInfo(ADContext &context, SILFunction *original,
1381
1.35k
                const AutoDiffConfig &config, SILFunction *jvp) {
1382
  // Get activity info of the original function.
1383
1.35k
  auto &passManager = context.getPassManager();
1384
1.35k
  auto *activityAnalysis =
1385
1.35k
      passManager.getAnalysis<DifferentiableActivityAnalysis>();
1386
1.35k
  auto &activityCollection = *activityAnalysis->get(original);
1387
1.35k
  auto &activityInfo = activityCollection.getActivityInfo(
1388
1.35k
      jvp->getLoweredFunctionType()->getSubstGenericSignature(),
1389
1.35k
      AutoDiffDerivativeFunctionKind::JVP);
1390
1.35k
  LLVM_DEBUG(activityInfo.dump(config, getADDebugStream()));
1391
1.35k
  return activityInfo;
1392
1.35k
}
1393
1394
/// Constructs the JVP cloner implementation for `witness`, cloning its
/// original function into `jvp`.
///
/// NOTE: later initializers read members initialized earlier (`original`,
/// `activityInfo`, `loopInfo`, `differentialInfo`). This relies on those
/// members being declared in this order in the class; presumably they are,
/// but confirm against the class definition when reordering members.
JVPCloner::Implementation::Implementation(ADContext &context,
                                          SILDifferentiabilityWitness *witness,
                                          SILFunction *jvp,
                                          DifferentiationInvoker invoker)
    : TypeSubstCloner(*jvp, *witness->getOriginalFunction(),
                      getSubstitutionMap(witness->getOriginalFunction(), jvp)),
      context(context), original(witness->getOriginalFunction()),
      witness(witness), jvp(jvp), invoker(invoker),
      activityInfo(
          getActivityInfo(context, original, witness->getConfig(), jvp)),
      loopInfo(context.getPassManager().getAnalysis<SILLoopAnalysis>()
                   ->get(original)),
      differentialInfo(context, AutoDiffLinearMapKind::Differential, original,
                       jvp, witness->getConfig(), activityInfo, loopInfo),
      differentialBuilder(TangentBuilder(
          *createEmptyDifferential(context, witness, &differentialInfo),
          context)),
      diffLocalAllocBuilder(getDifferential(), context) {
  // The empty differential function was already created by
  // `createEmptyDifferential` while initializing `differentialBuilder`; here
  // it is only registered with the AD context.
  context.recordGeneratedFunction(&getDifferential());
}
1415
1416
/// Creates a JVP cloner. All state lives in a heap-allocated `Implementation`
/// that `~JVPCloner` deletes.
JVPCloner::JVPCloner(ADContext &context,
                     SILDifferentiabilityWitness *witness, SILFunction *jvp,
                     DifferentiationInvoker invoker)
    : impl(*new Implementation(context, witness, jvp, invoker)) {
}
1419
1420
1.35k
/// Releases the `Implementation` allocated by the constructor.
JVPCloner::~JVPCloner() {
  delete &impl;
}
1421
1422
//--------------------------------------------------------------------------//
1423
// Differential struct mapping
1424
//--------------------------------------------------------------------------//
1425
1426
void JVPCloner::Implementation::initializeDifferentialTupleElements(
1427
1.35k
  SILBasicBlock *origBB, SILInstructionResultArray values) {
1428
1.35k
  auto *diffTupleTyple = differentialInfo.getLinearMapTupleType(origBB);
1429
1.35k
  assert(diffTupleTyple->getNumElements() == values.size() &&
1430
1.35k
         "The number of differential tuple fields must equal the number of "
1431
1.35k
         "differential struct element values");
1432
0
  auto res = differentialTupleElements.insert({origBB, values});
1433
1.35k
  (void)res;
1434
1.35k
  assert(res.second && "A pullback struct element already exists!");
1435
1.35k
}
1436
1437
/// Returns the differential tuple element value corresponding to the given
1438
/// original block and apply inst.
1439
1.60k
SILValue JVPCloner::Implementation::getDifferentialTupleElement(ApplyInst *ai) {
1440
1.60k
  unsigned idx = differentialInfo.lookUpLinearMapIndex(ai);
1441
1.60k
    assert((idx > 0 || (idx == 0 && ai->getParentBlock()->isEntry())) &&
1442
1.60k
           "impossible linear map index");
1443
0
  auto values = differentialTupleElements.lookup(ai->getParentBlock());
1444
1.60k
  assert(idx < values.size() &&
1445
1.60k
         "differential tuple element for this apply does not exist!");
1446
0
  return values[idx];
1447
1.60k
}
1448
1449
//--------------------------------------------------------------------------//
1450
// Tangent emission helpers
1451
//--------------------------------------------------------------------------//
1452
1453
1.35k
void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1454
  // Create differential blocks and arguments.
1455
1.35k
  auto &differential = getDifferential();
1456
1.35k
  auto diffLoc = differential.getLocation();
1457
1.35k
  auto *origEntry = original->getEntryBlock();
1458
1.35k
  auto origFnTy = original->getLoweredFunctionType();
1459
1460
1.35k
  for (auto &origBB : *original) {
1461
1.35k
    auto *diffBB = differential.createBasicBlock();
1462
1.35k
    diffBBMap.insert({&origBB, diffBB});
1463
    // If the BB is the original entry, then the differential block that we
1464
    // just created must be the differential function's entry. Create
1465
    // differential entry arguments and continue.
1466
1.35k
    if (&origBB == origEntry) {
1467
1.35k
      assert(diffBB->isEntry());
1468
0
      createEntryArguments(&differential);
1469
1.35k
      auto *lastArg = diffBB->getArguments().back();
1470
1.35k
#ifndef NDEBUG
1471
1.35k
      auto diffTupleLoweredType = remapSILTypeInDifferential(
1472
1.35k
          differentialInfo.getLinearMapTupleLoweredType(&origBB));
1473
1.35k
      assert(lastArg->getType() == diffTupleLoweredType);
1474
0
#endif
1475
0
      differentialStructArguments[&origBB] = lastArg;
1476
1.35k
    }
1477
1478
1.35k
    LLVM_DEBUG({
1479
1.35k
      auto &s = getADDebugStream()
1480
1.35k
                << "Original bb" + std::to_string(origBB.getDebugID())
1481
1.35k
                << ": To differentiate or not to differentiate?\n";
1482
1.35k
      for (auto &inst : origBB) {
1483
1.35k
        s << (differentialInfo.shouldDifferentiateInstruction(&inst) ? "[x] "
1484
1.35k
                                                                     : "[ ] ")
1485
1.35k
          << inst;
1486
1.35k
      }
1487
1.35k
    });
1488
1.35k
  }
1489
1490
1.35k
  assert(diffBBMap.size() == 1 &&
1491
1.35k
         "Can only currently handle single basic block functions");
1492
1493
  // The differential function has type:
1494
  // (arg0', ..., argn', entry_df_struct) -> result'.
1495
0
  auto diffParamArgs =
1496
1.35k
      differential.getArgumentsWithoutIndirectResults().drop_back();
1497
1.35k
  assert(diffParamArgs.size() ==
1498
1.35k
         witness->getConfig().parameterIndices->getNumIndices());
1499
0
  auto origParamArgs = original->getArgumentsWithoutIndirectResults();
1500
1501
  // TODO(TF-788): Re-enable non-varied result warning.
1502
  /*
1503
  // Check if result is not varied.
1504
  SmallVector<SILValue, 8> origFormalResults;
1505
  collectAllFormalResultsInTypeOrder(*original, origFormalResults);
1506
   std::get<0>(pair);
1507
  for (auto resultIndex : getConfig().results->getIndices()) {
1508
    auto origResult = origFormalResults[resultIndex];
1509
    // Emit warning if original result is not varied, because it will always
1510
    // have a zero derivative.
1511
    if (!activityInfo.isVaried(origResult, getConfig().parameters)) {
1512
      // Emit fixit if original result has a valid source location.
1513
      auto startLoc = origResult.getLoc().getStartSourceLoc();
1514
      auto endLoc = origResult.getLoc().getEndSourceLoc();
1515
      if (startLoc.isValid() && endLoc.isValid()) {
1516
        context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
1517
            .fixItInsert(startLoc, "withoutDerivative(at:")
1518
            .fixItInsertAfter(endLoc, ")");
1519
      }
1520
    }
1521
  }
1522
  */
1523
1524
  // Initialize tangent mapping for parameters.
1525
1.35k
  auto diffParamsIt = getConfig().parameterIndices->begin();
1526
2.00k
  for (auto index : range(diffParamArgs.size())) {
1527
2.00k
    auto *diffArg = diffParamArgs[index];
1528
2.00k
    auto *origArg = origParamArgs[*diffParamsIt];
1529
2.00k
    ++diffParamsIt;
1530
2.00k
    if (diffArg->getType().isAddress()) {
1531
248
      setTangentBuffer(origEntry, origArg, diffArg);
1532
1.75k
    } else {
1533
1.75k
      setTangentValue(origEntry, origArg, makeConcreteTangentValue(diffArg));
1534
1.75k
    }
1535
2.00k
    LLVM_DEBUG(getADDebugStream()
1536
2.00k
               << "Assigned parameter " << *diffArg
1537
2.00k
               << " as the tangent of original result " << *origArg);
1538
2.00k
  }
1539
1540
  // Initialize tangent mapping for original indirect results and non-wrt
1541
  // `inout` parameters. The tangent buffers of these address values are
1542
  // differential indirect results.
1543
1544
  // Collect original results.
1545
1.35k
  SmallVector<SILValue, 2> originalResults;
1546
1.35k
  collectAllFormalResultsInTypeOrder(*original, originalResults);
1547
1548
  // Iterate over differentiability results.
1549
1.35k
  differentialBuilder.setInsertionPoint(differential.getEntryBlock());
1550
1.35k
  auto diffIndResults = differential.getIndirectResults();
1551
1.35k
  unsigned differentialIndirectResultIndex = 0;
1552
1.36k
  for (auto resultIndex : getConfig().resultIndices->getIndices()) {
1553
1.36k
    auto origResult = originalResults[resultIndex];
1554
    // Handle original formal indirect result.
1555
1.36k
    if (resultIndex < origFnTy->getNumResults()) {
1556
      // Skip original direct results.
1557
1.32k
      if (origResult->getType().isObject())
1558
1.18k
        continue;
1559
140
      auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1560
140
      setTangentBuffer(origEntry, origResult, diffIndResult);
1561
      // If original indirect result is non-varied, zero-initialize its tangent
1562
      // buffer.
1563
140
      if (!activityInfo.isVaried(origResult, getConfig().parameterIndices))
1564
8
        emitZeroIndirect(diffIndResult->getType().getASTType(), diffIndResult,
1565
8
                         diffLoc);
1566
140
      continue;
1567
1.32k
    }
1568
    // Handle original non-wrt `inout` parameter.
1569
    // Only original *non-wrt* `inout` parameters have corresponding
1570
    // differential indirect results.
1571
40
    auto inoutParamIndex = resultIndex - origFnTy->getNumResults();
1572
40
    auto inoutParamIt = std::next(
1573
40
        origFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
1574
40
    auto paramIndex =
1575
40
        std::distance(origFnTy->getParameters().begin(), &*inoutParamIt);
1576
40
    if (getConfig().parameterIndices->contains(paramIndex))
1577
40
      continue;
1578
0
    auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1579
0
    setTangentBuffer(origEntry, origResult, diffIndResult);
1580
    // Original `inout` parameters are initialized, so their tangent buffers
1581
    // must also be initialized.
1582
0
    emitZeroIndirect(diffIndResult->getType().getASTType(), diffIndResult,
1583
0
                     diffLoc);
1584
0
  }
1585
1.35k
}
1586
1587
/// Creates the (still empty) differential function for `witness`'s JVP.
///
/// Differential results, one per differentiability result of the original:
/// - a formal original result yields its tangent, with the original
///   result convention;
/// - a non-wrt `inout` parameter yields its tangent as an indirect result;
/// - a wrt `inout` parameter yields nothing, because it already has a
///   differential parameter.
/// Differential parameters:
/// - the tangent of each differentiability parameter, with the original
///   parameter convention;
/// - followed by the original entry block's linear map tuple, passed as
///   `Direct_Owned` (the differential's closure context).
/// Types are reduced with the witness's derivative generic signature. The
/// function itself uses the JVP's generic signature.
/*static*/ SILFunction *JVPCloner::Implementation::createEmptyDifferential(
    ADContext &context, SILDifferentiabilityWitness *witness,
    LinearMapInfo *linearMapInfo) {
  auto &module = context.getModule();
  auto *original = witness->getOriginalFunction();
  auto *jvp = witness->getJVP();
  auto origTy = original->getLoweredFunctionType();
  auto config = witness->getConfig();

  // The witness generic signature may carry more requirements than the JVP
  // generic signature. When same-type requirements bind every generic
  // parameter to a concrete type, the JVP function type uses those concrete
  // types and the JVP generic signature is null.
  auto witnessCanGenSig =
      witness->getDerivativeGenericSignature().getCanonicalSignature();
  auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());

  auto reduceType = [&](Type type) -> CanType {
    return type->getReducedType(witnessCanGenSig);
  };
  auto reducedTangentType = [&](Type type) -> CanType {
    return reduceType(
        type->getAutoDiffTangentSpace(lookupConformance)->getType());
  };

  // Differential results.
  SmallVector<SILResultInfo, 8> dfResults;
  auto numOrigResults = origTy->getNumResults();
  for (auto resultIndex : config.resultIndices->getIndices()) {
    if (resultIndex < numOrigResults) {
      // Formal original result.
      auto origResult = origTy->getResults()[resultIndex];
      dfResults.push_back(SILResultInfo(
          reducedTangentType(reduceType(origResult.getInterfaceType())),
          origResult.getConvention()));
      continue;
    }
    // Original `inout` parameter.
    auto mutatingParamIt =
        std::next(origTy->getIndirectMutatingParameters().begin(),
                  resultIndex - numOrigResults);
    auto paramIndex =
        std::distance(origTy->getParameters().begin(), &*mutatingParamIt);
    // A wrt `inout` parameter already has a differential parameter.
    if (config.parameterIndices->contains(paramIndex))
      continue;
    auto paramTan = origTy->getParameters()[paramIndex]
                        .getInterfaceType()
                        ->getAutoDiffTangentSpace(lookupConformance);
    assert(paramTan && "Parameter type does not have a tangent space?");
    dfResults.push_back(
        {paramTan->getCanonicalType(), ResultConvention::Indirect});
  }

  // Differential parameters: tangents of the wrt parameters, then the
  // linear map tuple of the original entry block.
  SmallVector<SILParameterInfo, 8> dfParams;
  auto origParams = origTy->getParameters();
  for (auto i : config.parameterIndices->getIndices()) {
    auto origParam = origParams[i];
    dfParams.push_back(SILParameterInfo(
        reducedTangentType(reduceType(origParam.getInterfaceType())),
        origParam.getConvention()));
  }
  auto dfTupleType =
      linearMapInfo->getLinearMapTupleLoweredType(original->getEntryBlock())
          .getASTType();
  dfParams.push_back({dfTupleType, ParameterConvention::Direct_Owned});

  Mangle::DifferentiationMangler mangler;
  auto diffName = mangler.mangleLinearMap(
      original->getName(), AutoDiffLinearMapKind::Differential, config);

  // Use the JVP generic signature, not the witness one, which may bind all
  // generic parameters to concrete types via same-type requirements.
  auto diffGenericSig =
      jvp->getLoweredFunctionType()->getSubstGenericSignature();
  auto *diffGenericEnv = diffGenericSig.getGenericEnvironment();
  auto diffType = SILFunctionType::get(
      diffGenericSig, SILExtInfo::getThin(), origTy->getCoroutineKind(),
      origTy->getCalleeConvention(), dfParams, {}, dfResults, llvm::None,
      origTy->getPatternSubstitutions(), origTy->getInvocationSubstitutions(),
      original->getASTContext());

  SILOptFunctionBuilder fb(context.getTransform());
  auto linkage = jvp->isSerialized() ? SILLinkage::Public : SILLinkage::Private;
  auto *differential = fb.createFunction(
      linkage, context.getASTContext().getIdentifier(diffName).str(), diffType,
      diffGenericEnv, original->getLocation(), original->isBare(),
      IsNotTransparent, jvp->isSerialized(),
      original->isDynamicallyReplaceable(), original->isDistributed(),
      original->isRuntimeAccessible());
  differential->setDebugScope(
      new (module) SILDebugScope(original->getLocation(), differential));
  return differential;
}
1695
1696
1.35k
bool JVPCloner::Implementation::run() {
1697
1.35k
  PrettyStackTraceSILFunction trace("generating JVP and differential for",
1698
1.35k
                                    original);
1699
1.35k
  LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName()
1700
1.35k
                                << " to jvp @" << jvp->getName() << '\n');
1701
  // Create JVP and differential entry and arguments.
1702
1.35k
  auto *entry = jvp->createBasicBlock();
1703
1.35k
  createEntryArguments(jvp);
1704
1.35k
  prepareForDifferentialGeneration();
1705
  // Clone.
1706
1.35k
  SmallVector<SILValue, 4> entryArgs(entry->getArguments().begin(),
1707
1.35k
                                     entry->getArguments().end());
1708
1.35k
  cloneFunctionBody(original, entry, entryArgs);
1709
1.35k
  emitReturnInstForDifferential();
1710
  // If errors occurred, back out.
1711
1.35k
  if (errorOccurred)
1712
20
    return true;
1713
1.33k
  LLVM_DEBUG(getADDebugStream()
1714
1.33k
             << "Generated JVP for " << original->getName() << ":\n"
1715
1.33k
             << *jvp);
1716
1.33k
  LLVM_DEBUG(getADDebugStream()
1717
1.33k
             << "Generated differential for " << original->getName() << ":\n"
1718
1.33k
             << getDifferential());
1719
1.33k
  return errorOccurred;
1720
1.35k
}
1721
1722
} // end namespace autodiff
1723
} // end namespace swift
1724
1725
1.35k
bool JVPCloner::run() {
1726
1.35k
  bool foundError = impl.run();
1727
1.35k
#ifndef NDEBUG
1728
1.35k
  if (!foundError)
1729
1.33k
    getJVP().verify();
1730
1.35k
#endif
1731
1.35k
  return foundError;
1732
1.35k
}
1733
1734
1.33k
/// Returns the JVP function being generated.
SILFunction &JVPCloner::getJVP() const {
  return impl.getJVP();
}