Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/SILOptimizer/Mandatory/Differentiation.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- Differentiation.cpp - SIL Automatic Differentiation --*- C++ -*---===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2018 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// This file implements automatic differentiation.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#define DEBUG_TYPE "differentiation"
18
19
#include "swift/AST/ASTMangler.h"
20
#include "swift/AST/ASTPrinter.h"
21
#include "swift/AST/AnyFunctionRef.h"
22
#include "swift/AST/AutoDiff.h"
23
#include "swift/AST/Builtins.h"
24
#include "swift/AST/DeclContext.h"
25
#include "swift/AST/DiagnosticsSIL.h"
26
#include "swift/AST/Expr.h"
27
#include "swift/AST/GenericEnvironment.h"
28
#include "swift/AST/LazyResolver.h"
29
#include "swift/AST/ParameterList.h"
30
#include "swift/AST/SourceFile.h"
31
#include "swift/AST/SubstitutionMap.h"
32
#include "swift/AST/TypeCheckRequests.h"
33
#include "swift/SIL/FormalLinkage.h"
34
#include "swift/SIL/PrettyStackTrace.h"
35
#include "swift/SIL/SILBuilder.h"
36
#include "swift/SIL/TypeSubstCloner.h"
37
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
38
#include "swift/SILOptimizer/Differentiation/ADContext.h"
39
#include "swift/SILOptimizer/Differentiation/JVPCloner.h"
40
#include "swift/SILOptimizer/Differentiation/Thunk.h"
41
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
42
#include "swift/SILOptimizer/PassManager/Passes.h"
43
#include "swift/SILOptimizer/PassManager/Transforms.h"
44
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
45
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
46
#include "llvm/ADT/APSInt.h"
47
#include "llvm/ADT/BreadthFirstIterator.h"
48
#include "llvm/ADT/DenseSet.h"
49
#include "llvm/ADT/SmallSet.h"
50
#include "llvm/Support/CommandLine.h"
51
52
using namespace swift;
53
using namespace swift::autodiff;
54
using llvm::DenseMap;
55
using llvm::SmallDenseMap;
56
using llvm::SmallDenseSet;
57
using llvm::SmallMapVector;
58
using llvm::SmallSet;
59
60
/// This flag enables experimental `@differentiable(_linear)` function
61
/// transposition.
62
static llvm::cl::opt<bool> EnableExperimentalLinearMapTransposition(
63
    "enable-experimental-linear-map-transposition", llvm::cl::init(false));
64
65
//===----------------------------------------------------------------------===//
66
// Helpers
67
//===----------------------------------------------------------------------===//
68
69
/// Given a dumpable value, dumps it to `llvm::dbgs()`.
70
24.3k
template <typename T> static inline void debugDump(T &v) {
71
24.3k
  LLVM_DEBUG(llvm::dbgs() << "\n==== BEGIN DEBUG DUMP ====\n"
72
24.3k
                          << v << "\n==== END DEBUG DUMP ====\n");
73
24.3k
}
74
75
namespace {
76
77
class DifferentiationTransformer {
78
private:
79
  /// Reference to the main transform.
80
  SILModuleTransform &transform;
81
82
  /// Context necessary for performing the transformations.
83
  ADContext context;
84
85
  /// Promotes the given `differentiable_function` instruction to a valid
86
  /// `@differentiable` function-typed value.
87
  SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst,
88
                                           SILBuilder &builder, SILLocation loc,
89
                                           DifferentiationInvoker invoker);
90
91
  /// Given a `linear_function` instruction that is missing a transpose operand,
92
  /// return a new `linear_function` instruction with the transpose filled in.
93
  SILValue promoteToLinearFunction(LinearFunctionInst *inst,
94
                                   SILBuilder &builder, SILLocation loc,
95
                                   DifferentiationInvoker invoker);
96
97
public:
98
  /// Construct an `DifferentiationTransformer` for the given module.
99
  explicit DifferentiationTransformer(SILModuleTransform &transform)
100
24.3k
      : transform(transform), context(transform) {}
101
102
4
  SILModuleTransform &getTransform() { return transform; }
103
104
59.2k
  ADContext &getContext() { return context; }
105
106
  /// Canonicalize the given witness, filling in derivative functions if
107
  /// missing.
108
  ///
109
  /// Generated derivative functions have the same linkage as the witness.
110
  ///
111
  /// \param serializeFunctions specifies whether generated functions should be
112
  ///        serialized.
113
  bool canonicalizeDifferentiabilityWitness(
114
      SILDifferentiabilityWitness *witness, DifferentiationInvoker invoker,
115
      IsSerialized_t serializeFunctions);
116
117
  /// Process the given `differentiable_function` instruction, filling in
118
  /// missing derivative functions if necessary.
119
  bool processDifferentiableFunctionInst(DifferentiableFunctionInst *dfi);
120
121
  /// Process the given `linear_function` instruction, filling in the missing
122
  /// transpose function if necessary.
123
  bool processLinearFunctionInst(LinearFunctionInst *lfi);
124
};
125
126
} // end anonymous namespace
127
128
/// If the original function doesn't have a return, it cannot be differentiated.
129
/// Returns true if error is emitted.
130
static bool diagnoseNoReturn(ADContext &context, SILFunction *original,
131
6.62k
                             DifferentiationInvoker invoker) {
132
6.62k
  if (original->findReturnBB() != original->end())
133
6.62k
    return false;
134
4
  context.emitNondifferentiabilityError(
135
4
      original->getLocation().getEndSourceLoc(), invoker,
136
4
      diag::autodiff_missing_return);
137
4
  return true;
138
6.62k
}
139
140
/// If the original function contains unsupported control flow, emit a "control
141
/// flow unsupported" error at appropriate source locations. Returns true if
142
/// error is emitted.
143
///
144
/// Update as control flow support is added.
145
static bool diagnoseUnsupportedControlFlow(ADContext &context,
146
                                           SILFunction *original,
147
6.62k
                                           DifferentiationInvoker invoker) {
148
6.62k
  if (original->size() <= 1)
149
6.10k
    return false;
150
  // Diagnose unsupported branching terminators.
151
2.54k
  for (auto &bb : *original) {
152
2.54k
    auto *term = bb.getTerminator();
153
    // Check supported branching terminators.
154
2.54k
    if (isa<BranchInst>(term) || isa<CondBranchInst>(term) ||
155
2.54k
        isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term) ||
156
2.54k
        isa<CheckedCastBranchInst>(term) ||
157
2.54k
        isa<CheckedCastAddrBranchInst>(term) || isa<TryApplyInst>(term))
158
1.95k
      continue;
159
    // If terminator is an unsupported branching terminator, emit an error.
160
588
    if (term->isBranch()) {
161
0
      context.emitNondifferentiabilityError(
162
0
          term, invoker, diag::autodiff_control_flow_not_supported);
163
0
      return true;
164
0
    }
165
588
  }
166
520
  return false;
167
520
}
168
169
/// Check whether the given requirements are satisfied, with the given
170
/// derivative generic signature (containing requirements), and substitution
171
/// map. Returns true if error is emitted.
172
static bool diagnoseUnsatisfiedRequirements(ADContext &context,
173
                                            CanSILFunctionType origFnTy,
174
                                            GenericSignature derivativeGenSig,
175
                                            SubstitutionMap substMap,
176
                                            DifferentiationInvoker invoker,
177
22.5k
                                            SourceLoc loc) {
178
  // If the original function is polymorphic and its generic signature is the
179
  // same as the derivative generic signature, then the requirements are
180
  // satisfied. This check is necessary because the subsequent logic does not
181
  // correctly handle polymorphic original functions.
182
  // TODO(TF-1055): Can be removed after we have a robust solution for TF-1055.
183
22.5k
  if (origFnTy->getInvocationGenericSignature() && derivativeGenSig &&
184
22.5k
      origFnTy->getInvocationGenericSignature()->isEqual(derivativeGenSig))
185
192
    return false;
186
187
  // If there are no derivative requirements, return false.
188
22.3k
  auto requirements = derivativeGenSig.getRequirements();
189
22.3k
  if (requirements.empty())
190
16.4k
    return false;
191
  // Iterate through all requirements and check whether they are satisfied.
192
5.93k
  auto *swiftModule = context.getModule().getSwiftModule();
193
5.93k
  SmallVector<Requirement, 2> unsatisfiedRequirements;
194
13.8k
  for (auto req : requirements) {
195
13.8k
    auto firstType = req.getFirstType();
196
13.8k
    Type secondType;
197
    // Substitute first and second types using the given substitution map,
198
    // looking up conformances in the current module, if possible.
199
13.8k
    if (auto substFirstType =
200
13.8k
            firstType.subst(QuerySubstitutionMap{substMap},
201
13.8k
                            LookUpConformanceInModule(swiftModule))) {
202
13.8k
      firstType = substFirstType;
203
13.8k
    }
204
13.8k
    if (req.getKind() != RequirementKind::Layout) {
205
13.7k
      secondType = req.getSecondType();
206
13.7k
      if (auto substSecondType =
207
13.7k
              secondType.subst(QuerySubstitutionMap{substMap},
208
13.7k
                               LookUpConformanceInModule(swiftModule))) {
209
13.7k
        secondType = substSecondType;
210
13.7k
      }
211
13.7k
    }
212
13.8k
    switch (req.getKind()) {
213
0
    case RequirementKind::SameShape:
214
0
      llvm_unreachable("Same-shape requirement not supported here");
215
216
    // Check layout requirements.
217
16
    case RequirementKind::Layout: {
218
16
      auto layout = req.getLayoutConstraint();
219
16
      switch (layout->getKind()) {
220
16
      case LayoutConstraintKind::Class:
221
16
        if (!firstType->satisfiesClassConstraint())
222
0
          unsatisfiedRequirements.push_back(req);
223
16
        continue;
224
0
      default:
225
        // TODO: Check other layout requirements. Note that `@differentiable`
226
        // attribute type-checking does not yet support layout requirements in
227
        // where clauses; layout requirements in derivative generic signatures
228
        // can be formed only from `differentiable_function` instructions whose
229
        // original function operand is generic with layout requirements.
230
0
        break;
231
16
      }
232
0
      continue;
233
16
    }
234
    // Check same type requirements.
235
3.48k
    case RequirementKind::SameType:
236
      // If the first type does not equal the second type, then record the
237
      // unsatisfied requirement.
238
3.48k
      if (!firstType->isEqual(secondType))
239
0
        unsatisfiedRequirements.push_back(req);
240
3.48k
      continue;
241
    // Check superclass requirements.
242
48
    case RequirementKind::Superclass: {
243
      // If the second type is not an exact superclass of second type, then
244
      // record the unsatisfied requirement.
245
48
      if (!secondType->isExactSuperclassOf(firstType))
246
0
        unsatisfiedRequirements.push_back(req);
247
48
      continue;
248
16
    }
249
    // Check conformance requirements.
250
10.2k
    case RequirementKind::Conformance: {
251
10.2k
      auto *protocol = req.getProtocolDecl();
252
10.2k
      assert(protocol && "Expected protocol in generic signature requirement");
253
      // If the first type does not conform to the second type in the current
254
      // module, then record the unsatisfied requirement.
255
10.2k
      if (!swiftModule->lookupConformance(firstType, protocol))
256
8
        unsatisfiedRequirements.push_back(req);
257
10.2k
      continue;
258
16
    }
259
13.8k
    }
260
13.8k
  }
261
5.93k
  if (unsatisfiedRequirements.empty())
262
5.92k
    return false;
263
  // Diagnose unsatisfied requirements.
264
4
  std::string reqText;
265
4
  llvm::raw_string_ostream stream(reqText);
266
4
  interleave(
267
4
      unsatisfiedRequirements,
268
8
      [&](Requirement req) { req.print(stream, PrintOptions()); },
269
4
      [&] { stream << ", "; });
270
4
  context.emitNondifferentiabilityError(
271
4
      loc, invoker, diag::autodiff_function_assoc_func_unmet_requirements,
272
4
      stream.str());
273
4
  return true;
274
5.93k
}
275
276
//===----------------------------------------------------------------------===//
277
// Code emission utilities
278
//===----------------------------------------------------------------------===//
279
280
/// Given an apply site, emit copies of all parameters and place them in
281
/// `copiedArgs`. Any buffers that need to be destroyed will be added to
282
/// `newArgsToDestroy`. Any new buffers that need to be deallocated will be
283
/// added to `newBuffersToDealloc`. This helper is used for duplicating an
284
/// apply site.
285
static void copyParameterArgumentsForApply(
286
    ApplySite applySite, SmallVectorImpl<SILValue> &copiedArgs,
287
    SmallVectorImpl<SILValue> &newArgsToDestroy,
288
6.97k
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
289
6.97k
  LLVM_DEBUG({
290
6.97k
    auto &s = getADDebugStream() << "Copying arguments from apply site: ";
291
6.97k
    applySite.getInstruction()->print(s);
292
6.97k
  });
293
6.97k
  auto loc = applySite.getLoc();
294
6.97k
  copiedArgs.reserve(applySite.getNumArguments());
295
6.97k
  SILBuilderWithScope copyBuilder(applySite.getInstruction());
296
6.97k
  for (auto &argOperand : applySite.getArgumentOperands()) {
297
756
    auto arg = argOperand.get();
298
756
    auto argConv = applySite.getArgumentConvention(argOperand);
299
756
    auto collectNewArg = [&](SILValue newArg) {
300
756
      copiedArgs.push_back(newArg);
301
756
      if (argConv.isGuaranteedConvention() &&
302
756
          argConv != SILArgumentConvention::Indirect_InoutAliasable)
303
624
        newArgsToDestroy.push_back(newArg);
304
756
    };
305
    // Copy the argument if it's to be owned by the newly created closure.
306
    // Objects are to be retained.
307
756
    if (arg->getType().isObject()) {
308
660
      auto newArg = arg;
309
660
      if (newArg->getOwnershipKind() != OwnershipKind::None)
310
520
        newArg = copyBuilder.emitCopyValueOperation(loc, arg);
311
660
      collectNewArg(newArg);
312
660
      continue;
313
660
    }
314
    // Addresses depend on argument conventions.
315
    // If the argument is an aliasable inout reference, do not copy the
316
    // argument since it's a `@noescape` capture.
317
96
    if (argConv == SILArgumentConvention::Indirect_InoutAliasable) {
318
0
      collectNewArg(arg);
319
0
      continue;
320
0
    }
321
    // Otherwise, it must be address-only. Create a new buffer and perform
322
    // `copy_addr`.
323
96
    auto *argCopy = copyBuilder.createAllocStack(loc, arg->getType());
324
96
    newBuffersToDealloc.push_back(argCopy);
325
96
    copyBuilder.createCopyAddr(loc, arg, argCopy, IsNotTake, IsInitialization);
326
96
    collectNewArg(argCopy);
327
96
  }
328
6.97k
}
329
330
/// When a function value is used in an instruction (usually `apply`), there may
331
/// be conversion instructions in between, e.g. `thin_to_thick_function`. Given
332
/// a new function value and an old function value, this helper function
333
/// recursively converts the new function just like how the old function is
334
/// converted.
335
///
336
/// If the new function's generic signature is specified, it is used
337
/// to create substitution maps for reapplied `partial_apply` instructions.
338
static SILValue reapplyFunctionConversion(
339
    ADContext &context, SILValue newFunc, SILValue oldFunc,
340
    SILValue oldConvertedFunc, SILBuilder &builder, SILLocation loc,
341
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc,
342
    IndexSubset *parameterIndices, IndexSubset *resultIndices,
343
37.3k
    GenericSignature newFuncGenSig = GenericSignature()) {
344
  // If the old func is the new func, then there's no conversion.
345
37.3k
  if (oldFunc == oldConvertedFunc)
346
23.1k
    return newFunc;
347
  // Handle a few instruction cases.
348
  // copy_value
349
14.1k
  if (auto *cvi = dyn_cast<CopyValueInst>(oldConvertedFunc)) {
350
    // Note: no `copy_value` is needed for the re-converted function because the
351
    // caller of `reapplyFunctionConversion` should consume the re-converted
352
    // function.
353
704
    return reapplyFunctionConversion(
354
704
        context, newFunc, oldFunc, cvi->getOperand(), builder, loc,
355
704
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
356
704
  }
357
  // begin_borrow
358
13.4k
  if (auto *bbi = dyn_cast<BeginBorrowInst>(oldConvertedFunc)) {
359
    // Note: no `begin_borrow` is needed for the re-converted function because
360
    // the caller of `reapplyFunctionConversion` should consume the re-converted
361
    // function.
362
0
    return reapplyFunctionConversion(
363
0
        context, newFunc, oldFunc, bbi->getOperand(), builder, loc,
364
0
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
365
0
  }
366
  // convert_function
367
13.4k
  if (auto *cfi = dyn_cast<ConvertFunctionInst>(oldConvertedFunc)) {
368
208
    return reapplyFunctionConversion(
369
208
        context, newFunc, oldFunc, cfi->getOperand(), builder, loc,
370
208
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
371
208
  }
372
  // thin_to_thick_function
373
13.2k
  if (auto *tttfi = dyn_cast<ThinToThickFunctionInst>(oldConvertedFunc)) {
374
6.23k
    auto innerNewFunc = reapplyFunctionConversion(
375
6.23k
        context, newFunc, oldFunc, tttfi->getOperand(), builder, loc,
376
6.23k
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
377
6.23k
    auto operandFnTy = innerNewFunc->getType().castTo<SILFunctionType>();
378
6.23k
    auto thickTy = operandFnTy->getWithRepresentation(
379
6.23k
        SILFunctionTypeRepresentation::Thick);
380
6.23k
    auto silTy = SILType::getPrimitiveObjectType(thickTy);
381
6.23k
    return builder.createThinToThickFunction(loc, innerNewFunc, silTy);
382
6.23k
  }
383
  // partial_apply
384
6.96k
  if (auto *pai = dyn_cast<PartialApplyInst>(oldConvertedFunc)) {
385
6.96k
    SmallVector<SILValue, 8> newArgs;
386
6.96k
    newArgs.reserve(pai->getNumArguments());
387
6.96k
    SmallVector<SILValue, 1> newArgsToDestroy;
388
6.96k
    copyParameterArgumentsForApply(pai, newArgs, newArgsToDestroy,
389
6.96k
                                   newBuffersToDealloc);
390
6.96k
    auto innerNewFunc = reapplyFunctionConversion(
391
6.96k
        context, newFunc, oldFunc, pai->getCallee(), builder, loc,
392
6.96k
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
393
    // Reabstraction thunk `partial_apply` reapplications require special
394
    // support. Reabstraction thunk JVP/VJP expects a `@differentiable`
395
    // function-typed argument to avoid opaque function non-differentiability
396
    // errors. Thus, `partial_apply` reapplications must first form a
397
    // `differentiable_function` of the function-typed thunk argument.
398
6.96k
    auto isReabstractionThunkCallee = [&]() -> bool {
399
6.96k
      auto *fri = dyn_cast<FunctionRefInst>(oldFunc);
400
6.96k
      return fri && fri->getReferencedFunction()->isThunk() ==
401
6.57k
                        IsReabstractionThunk;
402
6.96k
    };
403
6.96k
    if (isReabstractionThunkCallee()) {
404
352
      assert(newArgs.size() == 1 &&
405
352
             "Expected reabstraction thunk to be partially applied with only "
406
352
             "one argument");
407
0
      auto *dfi = context.createDifferentiableFunction(
408
352
          builder, loc, parameterIndices, resultIndices, newArgs.back());
409
352
      context.getDifferentiableFunctionInstWorklist().push_back(dfi);
410
352
      newArgs.back() = dfi;
411
352
    }
412
    // Compute substitution map for reapplying `partial_apply`.
413
    // - If reapplied function is not polymorphic, use empty substitution map
414
    //   regardless of the original `partial_apply`'s substitution map.
415
    //   - This case is triggered for reapplying `partial_apply` where `newFunc`
416
    //     is a `differentiability_witness_function` where the witness generic
417
    //     signature has all concrete parameters while the original function's
418
    //     generic signature does not. In this case, the original function type
419
    //     is polymorphic while derivative function types are not (specialized
420
    //     with concrete types from same-type requirements).
421
    // - Otherwise, if `newFuncGenSig` is not specified, use the original
422
    //   `partial_apply`'s substitution map.
423
    // - Otherwise, if `newFuncGenSig` is specified, combine it with the
424
    //   original `partial_apply`'s substitution map.
425
0
    SubstitutionMap substMap;
426
6.96k
    if (innerNewFunc->getType().castTo<SILFunctionType>()->isPolymorphic()) {
427
6.35k
      if (!newFuncGenSig) {
428
384
        substMap = pai->getSubstitutionMap();
429
5.96k
      } else {
430
5.96k
        substMap = SubstitutionMap::get(
431
5.96k
            newFuncGenSig, QuerySubstitutionMap{pai->getSubstitutionMap()},
432
5.96k
            LookUpConformanceInModule(builder.getModule().getSwiftModule()));
433
5.96k
      }
434
6.35k
    }
435
6.96k
    return builder.createPartialApply(loc, innerNewFunc, substMap, newArgs,
436
6.96k
                                      ParameterConvention::Direct_Guaranteed);
437
6.96k
  }
438
0
  llvm_unreachable("Unhandled function conversion instruction");
439
0
}
440
441
/// Emits a reference to a derivative function of `original`, differentiated
442
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
443
/// the derivative function and the actual indices that the derivative function
444
/// is with respect to.
445
///
446
/// Returns `None` on failure, signifying that a diagnostic has been emitted
447
/// using `invoker`.
448
static llvm::Optional<std::pair<SILValue, AutoDiffConfig>>
449
emitDerivativeFunctionReference(
450
    DifferentiationTransformer &transformer, SILBuilder &builder,
451
    const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
452
    SILValue original, DifferentiationInvoker invoker,
453
23.2k
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
454
23.2k
  ADContext &context = transformer.getContext();
455
456
  // If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
457
  // matches the given kind and desired differentiation parameter indices,
458
  // simply extract the derivative function of its function operand, retain the
459
  // derivative function, and return it.
460
23.2k
  if (auto *inst = original->getDefiningInstruction())
461
23.2k
    if (auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(inst))
462
8
      if (dfei->getExtractee() ==
463
8
          NormalDifferentiableFunctionTypeComponent::Original)
464
8
        original = dfei->getOperand();
465
466
  // If `original` is a `@differentiable` function, just extract the
467
  // derivative function.
468
23.2k
  if (auto diffableFnType = original->getType().castTo<SILFunctionType>()) {
469
23.2k
    if (diffableFnType->isDifferentiable()) {
470
8
      auto paramIndices =
471
8
          diffableFnType->getDifferentiabilityParameterIndices();
472
8
      for (auto i : desiredConfig.parameterIndices->getIndices()) {
473
8
        if (!paramIndices->contains(i)) {
474
0
          context.emitNondifferentiabilityError(
475
0
              original, invoker,
476
0
              diag::
477
0
                  autodiff_function_noderivative_parameter_not_differentiable);
478
0
          return llvm::None;
479
0
        }
480
8
      }
481
8
      auto borrowedDiffFunc =
482
8
          builder.emitBeginBorrowOperation(original.getLoc(), original);
483
8
      SILValue derivativeFn = builder.createDifferentiableFunctionExtract(
484
8
          borrowedDiffFunc.getLoc(), kind, borrowedDiffFunc);
485
8
      if (derivativeFn->getOwnershipKind() != OwnershipKind::None)
486
0
        derivativeFn =
487
0
            builder.emitCopyValueOperation(original.getLoc(), derivativeFn);
488
8
      builder.emitEndBorrowOperation(original.getLoc(), borrowedDiffFunc);
489
8
      return std::make_pair(derivativeFn, desiredConfig);
490
8
    }
491
23.2k
  }
492
493
  // Handle `function_ref` original function.
494
23.2k
  if (auto *originalFRI =
495
23.2k
          peerThroughFunctionConversions<FunctionRefInst>(original)) {
496
22.5k
    auto loc = originalFRI->getLoc();
497
22.5k
    auto *originalFn = originalFRI->getReferencedFunction();
498
22.5k
    auto originalFnTy = originalFn->getLoweredFunctionType();
499
22.5k
    auto *desiredParameterIndices = desiredConfig.parameterIndices;
500
22.5k
    auto *desiredResultIndices = desiredConfig.resultIndices;
501
    // NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has
502
    // parameters corresponding to captured variables.
503
    // TODO: If possible, change `autodiff::getLoweredParameterIndices` to
504
    // take `CaptureInfo` into account.
505
22.5k
    if (originalFnTy->getNumParameters() >
506
22.5k
        desiredParameterIndices->getCapacity()) {
507
712
      desiredParameterIndices = desiredParameterIndices->extendingCapacity(
508
712
          context.getASTContext(), originalFnTy->getNumParameters());
509
712
    }
510
    // Look up a differentiability witness with the exact configuration.
511
22.5k
    auto *minimalWitness = getExactDifferentiabilityWitness(
512
22.5k
        context.getModule(), originalFn, desiredParameterIndices,
513
22.5k
        desiredResultIndices);
514
    // Otherwise, look up a differentiability witness with a minimal superset
515
    // configuration.
516
22.5k
    if (!minimalWitness)
517
6.34k
      minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness(
518
6.34k
          context.getModule(), originalFn, DifferentiabilityKind::Reverse,
519
6.34k
          desiredParameterIndices, desiredResultIndices);
520
    // If no minimal witness exists, check non-differentiable cases before
521
    // creating a new private differentiability witness.
522
22.5k
    if (!minimalWitness) {
523
      // If the function is intentionally marked as being opaque to
524
      // differentiation, then we should not create a task for it.
525
3.40k
      if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
526
0
        context.emitNondifferentiabilityError(
527
0
            original, invoker,
528
0
            diag::autodiff_opaque_function_not_differentiable);
529
0
        return llvm::None;
530
0
      }
531
      // Check and diagnose non-differentiable arguments.
532
3.40k
      auto originalFnTy = originalFn->getLoweredFunctionType();
533
5.63k
      for (unsigned paramIndex : range(originalFnTy->getNumParameters())) {
534
5.63k
        if (desiredConfig.isWrtParameter(paramIndex) &&
535
5.63k
            !originalFnTy->getParameters()[paramIndex]
536
4.90k
                 .getSILStorageInterfaceType()
537
4.90k
                 .isDifferentiable(context.getModule())) {
538
0
          auto diag = context.emitNondifferentiabilityError(
539
0
              original, invoker, diag::autodiff_nondifferentiable_argument);
540
0
          return llvm::None;
541
0
        }
542
5.63k
      }
543
      // Check and diagnose non-differentiable results.
544
3.46k
      for (auto resultIndex : desiredResultIndices->getIndices()) {
545
3.46k
        SILType resultType;
546
3.46k
        if (resultIndex >= originalFnTy->getNumResults()) {
547
200
          auto semanticResultParamIdx = resultIndex - originalFnTy->getNumResults();
548
200
          auto semanticResultParam =
549
200
              *std::next(originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
550
200
                         semanticResultParamIdx);
551
200
          resultType = semanticResultParam.getSILStorageInterfaceType();
552
3.26k
        } else {
553
3.26k
          resultType = originalFnTy->getResults()[resultIndex]
554
3.26k
                           .getSILStorageInterfaceType();
555
3.26k
        }
556
3.46k
        if (!resultType.isDifferentiable(context.getModule())) {
557
0
          context.emitNondifferentiabilityError(
558
0
              original, invoker, diag::autodiff_nondifferentiable_result);
559
0
          return llvm::None;
560
0
        }
561
3.46k
      }
562
      // Check and diagnose external declarations.
563
3.40k
      if (originalFn->isExternalDeclaration()) {
564
0
        context.emitNondifferentiabilityError(
565
0
            original, invoker,
566
0
            diag::autodiff_external_nondifferentiable_function);
567
0
        return llvm::None;
568
0
      }
569
      // Sanity check passed. Create a new differentiability witness and
570
      // canonicalize it.
571
3.40k
      GenericSignature contextualDerivativeGenSig = GenericSignature();
572
3.40k
      if (invoker.getKind() ==
573
3.40k
          DifferentiationInvoker::Kind::IndirectDifferentiation)
574
0
        contextualDerivativeGenSig =
575
0
            invoker.getIndirectDifferentiation()
576
0
                .second->getDerivativeGenericSignature();
577
3.40k
      auto derivativeConstrainedGenSig =
578
3.40k
          autodiff::getConstrainedDerivativeGenericSignature(
579
3.40k
              originalFn->getLoweredFunctionType(),
580
3.40k
              desiredParameterIndices, desiredResultIndices,
581
3.40k
              contextualDerivativeGenSig,
582
3.40k
              LookUpConformanceInModule(context.getModule().getSwiftModule()));
583
3.40k
      minimalWitness = SILDifferentiabilityWitness::createDefinition(
584
3.40k
          context.getModule(), SILLinkage::Private, originalFn,
585
3.40k
          DifferentiabilityKind::Reverse, desiredParameterIndices,
586
3.40k
          desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
587
3.40k
          /*vjp*/ nullptr, /*isSerialized*/ false);
588
3.40k
      if (transformer.canonicalizeDifferentiabilityWitness(
589
3.40k
              minimalWitness, invoker, IsNotSerialized))
590
12
        return llvm::None;
591
3.40k
    }
592
22.5k
    assert(minimalWitness);
593
22.5k
    if (original->getFunction()->isSerialized() &&
594
22.5k
        !hasPublicVisibility(minimalWitness->getLinkage())) {
595
16
      enum { Inlinable = 0, DefaultArgument = 1 };
596
16
      unsigned fragileKind = Inlinable;
597
      // FIXME: This is not a very robust way of determining if the function is
598
      // a default argument. Also, we have not exhaustively listed all the kinds
599
      // of fragility.
600
16
      if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
601
8
        fragileKind = DefaultArgument;
602
16
      context.emitNondifferentiabilityError(
603
16
          original, invoker, diag::autodiff_private_derivative_from_fragile,
604
16
          fragileKind,
605
16
          isa_and_nonnull<AbstractClosureExpr>(
606
16
              originalFRI->getLoc().getAsASTNode<Expr>()));
607
16
      return llvm::None;
608
16
    }
609
    // TODO(TF-482): Move generic requirement checking logic to
610
    // `getExactDifferentiabilityWitness` and
611
    // `getOrCreateMinimalASTDifferentiabilityWitness`.
612
    // Get the substitution map for checking unmet generic requirements.
613
    // By default, use the forwarding substitution map of the original function.
614
    // If the original callee is a `partial_apply` or `apply` instruction, use
615
    // its substitution map instead.
616
22.5k
    auto substMap = original->getFunction()->getForwardingSubstitutionMap();
617
22.5k
    if (auto *pai =
618
22.5k
            peerThroughFunctionConversions<PartialApplyInst>(original)) {
619
6.58k
      substMap = pai->getSubstitutionMap();
620
15.9k
    } else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original)) {
621
0
      substMap = ai->getSubstitutionMap();
622
0
    }
623
22.5k
    if (diagnoseUnsatisfiedRequirements(
624
22.5k
            context, original->getType().castTo<SILFunctionType>(),
625
22.5k
            minimalWitness->getDerivativeGenericSignature(), substMap, invoker,
626
22.5k
            original.getLoc().getSourceLoc()))
627
4
      return llvm::None;
628
22.5k
    DifferentiabilityWitnessFunctionKind witnessKind;
629
22.5k
    switch (kind) {
630
11.2k
    case AutoDiffDerivativeFunctionKind::JVP:
631
11.2k
      witnessKind = DifferentiabilityWitnessFunctionKind::JVP;
632
11.2k
      break;
633
11.2k
    case AutoDiffDerivativeFunctionKind::VJP:
634
11.2k
      witnessKind = DifferentiabilityWitnessFunctionKind::VJP;
635
11.2k
      break;
636
22.5k
    }
637
22.5k
    auto *derivativeFnRef = builder.createDifferentiabilityWitnessFunction(
638
22.5k
        loc, witnessKind, minimalWitness);
639
22.5k
    auto convertedRef = reapplyFunctionConversion(
640
22.5k
        context, derivativeFnRef, originalFRI, original, builder, loc,
641
22.5k
        newBuffersToDealloc, desiredConfig.parameterIndices,
642
22.5k
        desiredConfig.resultIndices,
643
22.5k
        derivativeFnRef->getType()
644
22.5k
            .getASTType()
645
22.5k
            ->castTo<SILFunctionType>()
646
22.5k
            ->getSubstGenericSignature());
647
22.5k
    return std::make_pair(convertedRef, minimalWitness->getConfig());
648
22.5k
  }
649
650
  // Handle `witness_method`.
651
684
  if (auto *witnessMethod =
652
684
          peerThroughFunctionConversions<WitnessMethodInst>(original)) {
653
372
    auto loc = witnessMethod->getLoc();
654
372
    auto requirementDeclRef = witnessMethod->getMember();
655
372
    auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
656
    // If requirement declaration does not have any derivative function
657
    // configurations, produce an error.
658
372
    if (requirementDecl->getDerivativeFunctionConfigurations().empty()) {
659
4
      context.emitNondifferentiabilityError(
660
4
          original, invoker, diag::autodiff_protocol_member_not_differentiable);
661
4
      return llvm::None;
662
4
    }
663
    // Find the minimal derivative configuration: minimal parameter indices and
664
    // corresponding derivative generic signature. If it does not exist, produce
665
    // an error.
666
368
    IndexSubset *minimalASTParamIndices = nullptr;
667
368
    auto minimalConfig = findMinimalDerivativeConfiguration(
668
368
        requirementDecl, desiredConfig.parameterIndices,
669
368
        minimalASTParamIndices);
670
368
    if (!minimalConfig) {
671
0
      context.emitNondifferentiabilityError(
672
0
          original, invoker,
673
0
          diag::autodiff_member_subset_indices_not_differentiable);
674
0
      return llvm::None;
675
0
    }
676
    // Emit a `witness_method` instruction for the derivative function.
677
368
    auto originalType = witnessMethod->getType().castTo<SILFunctionType>();
678
368
    auto assocType = originalType->getAutoDiffDerivativeFunctionType(
679
368
        minimalConfig->parameterIndices, minimalConfig->resultIndices, kind,
680
368
        context.getTypeConverter(),
681
368
        LookUpConformanceInModule(builder.getModule().getSwiftModule()));
682
368
    auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
683
368
        kind, minimalASTParamIndices, minimalConfig->derivativeGenericSignature,
684
368
        context.getASTContext());
685
368
    auto *ref = builder.createWitnessMethod(
686
368
        loc, witnessMethod->getLookupType(), witnessMethod->getConformance(),
687
368
        requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
688
368
        SILType::getPrimitiveObjectType(assocType));
689
368
    auto convertedRef = reapplyFunctionConversion(
690
368
        context, ref, witnessMethod, original, builder, loc,
691
368
        newBuffersToDealloc, desiredConfig.parameterIndices,
692
368
        desiredConfig.resultIndices);
693
368
    return std::make_pair(convertedRef, *minimalConfig);
694
368
  }
695
696
  // Handle `class_method`.
697
312
  if (auto *classMethod =
698
312
          peerThroughFunctionConversions<ClassMethodInst>(original)) {
699
312
    auto loc = classMethod->getLoc();
700
312
    auto methodDeclRef = classMethod->getMember();
701
312
    auto *methodDecl = methodDeclRef.getAbstractFunctionDecl();
702
    // If method declaration does not have any derivative function
703
    // configurations, produce an error.
704
312
    if (methodDecl->getDerivativeFunctionConfigurations().empty()) {
705
8
      context.emitNondifferentiabilityError(
706
8
          original, invoker, diag::autodiff_class_member_not_differentiable);
707
8
      return llvm::None;
708
8
    }
709
    // Find the minimal derivative configuration: minimal parameter indices and
710
    // corresponding derivative generic signature. If it does not exist, produce
711
    // an error.
712
304
    IndexSubset *minimalASTParamIndices = nullptr;
713
304
    auto minimalConfig = findMinimalDerivativeConfiguration(
714
304
        methodDecl, desiredConfig.parameterIndices, minimalASTParamIndices);
715
304
    if (!minimalConfig) {
716
0
      context.emitNondifferentiabilityError(
717
0
          original, invoker,
718
0
          diag::autodiff_member_subset_indices_not_differentiable);
719
0
      return llvm::None;
720
0
    }
721
    // Emit a `class_method` instruction for the derivative function.
722
304
    auto originalType = classMethod->getType().castTo<SILFunctionType>();
723
304
    auto assocType = originalType->getAutoDiffDerivativeFunctionType(
724
304
        minimalConfig->parameterIndices, minimalConfig->resultIndices, kind,
725
304
        context.getTypeConverter(),
726
304
        LookUpConformanceInModule(builder.getModule().getSwiftModule()));
727
304
    auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
728
304
        kind, minimalASTParamIndices, minimalConfig->derivativeGenericSignature,
729
304
        context.getASTContext());
730
304
    auto *ref = builder.createClassMethod(
731
304
        loc, classMethod->getOperand(),
732
304
        methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
733
304
        SILType::getPrimitiveObjectType(assocType));
734
304
    auto convertedRef = reapplyFunctionConversion(
735
304
        context, ref, classMethod, original, builder, loc, newBuffersToDealloc,
736
304
        desiredConfig.parameterIndices, desiredConfig.resultIndices);
737
304
    return std::make_pair(convertedRef, *minimalConfig);
738
304
  }
739
740
  // Emit the general opaque function error.
741
0
  context.emitNondifferentiabilityError(
742
0
      original, invoker, diag::autodiff_opaque_function_not_differentiable);
743
0
  return llvm::None;
744
312
}
745
746
//===----------------------------------------------------------------------===//
747
// `SILDifferentiabilityWitness` processing
748
//===----------------------------------------------------------------------===//
749
750
/// Creates an empty (body-less) derivative function of the given `kind` for
/// `witness`.
///
/// The new function:
/// - is named by mangling the original function's name with `kind` and the
///   witness configuration;
/// - has the witness's linkage and the derivative type computed from the
///   original lowered type, the witness configuration, and the witness's
///   canonical derivative generic signature;
/// - gets a generic environment only when that signature has at least one
///   non-concrete generic parameter;
/// - copies bareness, dynamic replaceability, distributedness, and runtime
///   accessibility from the original function, and is never transparent.
///
/// This is the shared implementation of `createEmptyVJP` and
/// `createEmptyJVP`, which differ only in `kind`.
static SILFunction *
createEmptyDerivativeFunction(ADContext &context,
                              SILDifferentiabilityWitness *witness,
                              IsSerialized_t isSerialized,
                              AutoDiffDerivativeFunctionKind kind) {
  auto *original = witness->getOriginalFunction();
  auto config = witness->getConfig();
  LLVM_DEBUG({
    auto &s = getADDebugStream();
    s << "Creating "
      << (kind == AutoDiffDerivativeFunctionKind::JVP ? "JVP" : "VJP")
      << " for " << original->getName() << ":\n\t";
    s << "Original type: " << original->getLoweredFunctionType() << "\n\t";
    s << "Config: " << config << "\n\t";
  });

  auto &module = context.getModule();
  auto originalTy = original->getLoweredFunctionType();

  // === Create an empty derivative function. ===
  Mangle::DifferentiationMangler mangler;
  auto derivativeName =
      mangler.mangleDerivativeFunction(original->getName(), kind, config);
  auto canGenSig =
      witness->getDerivativeGenericSignature().getCanonicalSignature();
  // Only signatures with at least one non-concrete parameter get a generic
  // environment.
  GenericEnvironment *genericEnv = nullptr;
  if (canGenSig && !canGenSig->areAllParamsConcrete())
    genericEnv = canGenSig.getGenericEnvironment();
  auto derivativeType = originalTy->getAutoDiffDerivativeFunctionType(
      config.parameterIndices, config.resultIndices, kind, module.Types,
      LookUpConformanceInModule(module.getSwiftModule()), canGenSig,
      /*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk);

  SILOptFunctionBuilder fb(context.getTransform());
  auto *derivative = fb.createFunction(
      witness->getLinkage(),
      context.getASTContext().getIdentifier(derivativeName).str(),
      derivativeType, genericEnv, original->getLocation(), original->isBare(),
      IsNotTransparent, isSerialized, original->isDynamicallyReplaceable(),
      original->isDistributed(), original->isRuntimeAccessible());
  derivative->setDebugScope(
      new (module) SILDebugScope(original->getLocation(), derivative));

  LLVM_DEBUG(llvm::dbgs()
             << (kind == AutoDiffDerivativeFunctionKind::JVP ? "JVP" : "VJP")
             << " type: " << derivative->getLoweredFunctionType() << "\n");
  return derivative;
}

/// Creates an empty VJP for `witness` with the given serialization kind.
static SILFunction *createEmptyVJP(ADContext &context,
                                   SILDifferentiabilityWitness *witness,
                                   IsSerialized_t isSerialized) {
  return createEmptyDerivativeFunction(context, witness, isSerialized,
                                       AutoDiffDerivativeFunctionKind::VJP);
}

/// Creates an empty JVP for `witness` with the given serialization kind.
static SILFunction *createEmptyJVP(ADContext &context,
                                   SILDifferentiabilityWitness *witness,
                                   IsSerialized_t isSerialized) {
  return createEmptyDerivativeFunction(context, witness, isSerialized,
                                       AutoDiffDerivativeFunctionKind::JVP);
}
838
839
/// Apply the fatal error function with the given name of type
840
/// `@convention(thin) () -> Never` in `f`.
841
static void emitFatalError(ADContext &context, SILFunction *f,
842
4.14k
                           StringRef fatalErrorFuncName) {
843
4.14k
  auto *entry = f->createBasicBlock();
844
4.14k
  createEntryArguments(f);
845
4.14k
  SILBuilder builder(entry);
846
4.14k
  auto loc = f->getLocation();
847
  // Destroy all owned arguments to pass ownership verification.
848
4.14k
  for (auto *arg : entry->getArguments())
849
8.05k
    if (arg->getOwnershipKind() == OwnershipKind::Owned)
850
96
      builder.emitDestroyOperation(loc, arg);
851
  // Fatal error with a nice message.
852
4.14k
  auto neverTy =
853
4.14k
      context.getModule().getASTContext().getNeverType()->getCanonicalType();
854
4.14k
  auto neverResultInfo = SILResultInfo(neverTy, ResultConvention::Unowned);
855
  // Fatal error function must have type `@convention(thin) () -> Never`.
856
4.14k
  auto fatalErrorFnType = SILFunctionType::get(
857
4.14k
      /*genericSig*/ nullptr, SILFunctionType::ExtInfo::getThin(),
858
4.14k
      SILCoroutineKind::None, ParameterConvention::Direct_Unowned, {},
859
4.14k
      /*interfaceYields*/ {}, neverResultInfo,
860
4.14k
      /*interfaceErrorResults*/ llvm::None, {}, {}, context.getASTContext());
861
4.14k
  auto fnBuilder = SILOptFunctionBuilder(context.getTransform());
862
4.14k
  auto *fatalErrorFn = fnBuilder.getOrCreateFunction(
863
4.14k
      loc, fatalErrorFuncName, SILLinkage::PublicExternal, fatalErrorFnType,
864
4.14k
      IsNotBare, IsNotTransparent, IsNotSerialized, IsNotDynamic,
865
4.14k
      IsNotDistributed, IsNotRuntimeAccessible, ProfileCounter(), IsNotThunk);
866
4.14k
  auto *fatalErrorFnRef = builder.createFunctionRef(loc, fatalErrorFn);
867
4.14k
  builder.createApply(loc, fatalErrorFnRef, SubstitutionMap(), {});
868
4.14k
  builder.createUnreachable(loc);
869
4.14k
}
870
871
/// Returns true on error.
872
bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
873
    SILDifferentiabilityWitness *witness, DifferentiationInvoker invoker,
874
5.66k
    IsSerialized_t serializeFunctions) {
875
5.66k
  std::string traceMessage;
876
5.66k
  llvm::raw_string_ostream OS(traceMessage);
877
5.66k
  OS << "processing ";
878
5.66k
  witness->print(OS);
879
5.66k
  OS << " on";
880
5.66k
  OS.flush();
881
5.66k
  PrettyStackTraceSILFunction trace(
882
5.66k
      traceMessage.c_str(), witness->getOriginalFunction());
883
884
5.66k
  assert(witness->isDefinition());
885
886
  // If the JVP doesn't exist, need to synthesize it.
887
5.66k
  if (!witness->getJVP()) {
888
    // Diagnose:
889
    // - Functions with no return.
890
    // - Functions with unsupported control flow.
891
5.50k
    if (context.getASTContext()
892
5.50k
            .LangOpts.hasFeature(Feature::ForwardModeDifferentiation) &&
893
5.50k
        (diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) ||
894
1.36k
         diagnoseUnsupportedControlFlow(
895
1.36k
             context, witness->getOriginalFunction(), invoker)))
896
0
      return true;
897
898
    // Create empty JVP.
899
5.50k
    auto *jvp = createEmptyJVP(context, witness, serializeFunctions);
900
5.50k
    witness->setJVP(jvp);
901
5.50k
    context.recordGeneratedFunction(jvp);
902
903
    // For now, only do JVP generation if the flag is enabled and if custom VJP
904
    // does not exist. If custom VJP exists but custom JVP does not, skip JVP
905
    // generation because generated JVP may not match semantics of custom VJP.
906
    // Instead, create an empty JVP.
907
5.50k
    if (context.getASTContext()
908
5.50k
            .LangOpts.hasFeature(Feature::ForwardModeDifferentiation) &&
909
5.50k
        !witness->getVJP()) {
910
      // JVP and differential generation do not currently support functions with
911
      // multiple basic blocks.
912
1.36k
      if (witness->getOriginalFunction()->size() > 1) {
913
8
        context.emitNondifferentiabilityError(
914
8
            witness->getOriginalFunction()->getLocation().getSourceLoc(),
915
8
            invoker, diag::autodiff_jvp_control_flow_not_supported);
916
8
        return true;
917
8
      }
918
      // Emit JVP function.
919
1.35k
      JVPCloner cloner(context, witness, jvp, invoker);
920
1.35k
      if (cloner.run())
921
20
        return true;
922
4.14k
    } else {
923
      // If JVP generation is disabled or a user-defined custom VJP function
924
      // exists, fatal error with a nice message.
925
4.14k
      emitFatalError(context, jvp,
926
4.14k
                     "_fatalErrorForwardModeDifferentiationDisabled");
927
4.14k
      LLVM_DEBUG(getADDebugStream()
928
4.14k
                 << "Generated empty JVP for "
929
4.14k
                 << witness->getOriginalFunction()->getName() << ":\n"
930
4.14k
                 << *jvp);
931
4.14k
    }
932
5.50k
  }
933
934
  // If the VJP doesn't exist, need to synthesize it.
935
5.63k
  if (!witness->getVJP()) {
936
    // Diagnose:
937
    // - Functions with no return.
938
    // - Functions with unsupported control flow.
939
5.26k
    if (diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) ||
940
5.26k
        diagnoseUnsupportedControlFlow(
941
5.25k
            context, witness->getOriginalFunction(), invoker))
942
4
      return true;
943
944
    // Create empty VJP.
945
5.25k
    auto *vjp = createEmptyVJP(context, witness, serializeFunctions);
946
5.25k
    witness->setVJP(vjp);
947
5.25k
    context.recordGeneratedFunction(vjp);
948
    // Emit VJP function.
949
5.25k
    VJPCloner cloner(context, witness, vjp, invoker);
950
5.25k
    return cloner.run();
951
5.26k
  }
952
376
  return false;
953
5.63k
}
954
955
//===----------------------------------------------------------------------===//
956
// Differentiation pass implementation
957
//===----------------------------------------------------------------------===//
958
959
/// The automatic differentiation pass.
960
namespace {
961
class Differentiation : public SILModuleTransform {
962
public:
963
24.3k
  Differentiation() : SILModuleTransform() {}
964
  void run() override;
965
};
966
} // end anonymous namespace
967
968
/// Given a curry thunk application, clone the thunk to return a
969
/// `@differentiable` function-typed value and apply the cloned thunk.
970
///
971
/// Curry thunk type: `(Self) -> (T, ...) -> U`.
972
/// Cloned thunk type: `(Self) -> @differentiable (T, ...) -> U`.
973
static SILValue promoteCurryThunkApplicationToDifferentiableFunction(
974
    DifferentiationTransformer &dt, DifferentiableFunctionInst *dfi,
975
11.6k
    SILBuilder &builder, SILLocation loc, DifferentiationInvoker invoker) {
976
11.6k
  auto origFnOperand = dfi->getOriginalFunction();
977
11.6k
  auto *parameterIndices = dfi->getParameterIndices();
978
11.6k
  auto *resultIndices = dfi->getResultIndices();
979
11.6k
  auto &context = dt.getContext();
980
981
  // Check for curry thunk application:
982
  // - The original function operand must be an `apply` instruction.
983
  // - The `apply` callee must be a `function_ref` instruction.
984
  // - The callee must return a function-typed value.
985
11.6k
  auto *ai = dyn_cast<ApplyInst>(origFnOperand);
986
11.6k
  if (!ai)
987
11.6k
    return nullptr;
988
4
  auto *thunkRef = dyn_cast<FunctionRefInst>(ai->getCallee());
989
4
  if (!thunkRef)
990
0
    return nullptr;
991
4
  auto *thunk = thunkRef->getReferencedFunction();
992
4
  auto thunkTy = thunk->getLoweredFunctionType();
993
4
  auto thunkResult = thunkTy->getSingleResult();
994
4
  auto resultFnTy = thunkResult.getInterfaceType()->getAs<SILFunctionType>();
995
4
  if (!resultFnTy)
996
0
    return nullptr;
997
998
  // Create a new curry thunk.
999
4
  AutoDiffConfig desiredConfig(parameterIndices, resultIndices);
1000
  // TODO(TF-685): Use more principled mangling for thunks.
1001
4
  auto newThunkName = "AD__" + thunk->getName().str() +
1002
4
                      "__differentiable_curry_thunk_" + desiredConfig.mangle();
1003
1004
  // Construct new curry thunk type with `@differentiable` function
1005
  // result.
1006
4
  auto diffResultFnTy = resultFnTy->getWithExtInfo(
1007
4
      resultFnTy->getExtInfo()
1008
4
          .intoBuilder()
1009
4
          .withDifferentiabilityKind(DifferentiabilityKind::Reverse)
1010
4
          .build());
1011
4
  auto newThunkResult = thunkResult.getWithInterfaceType(diffResultFnTy);
1012
4
  auto thunkType = SILFunctionType::get(
1013
4
      thunkTy->getSubstGenericSignature(), thunkTy->getExtInfo(),
1014
4
      thunkTy->getCoroutineKind(), thunkTy->getCalleeConvention(),
1015
4
      thunkTy->getParameters(), {}, {newThunkResult}, {},
1016
4
      thunkTy->getPatternSubstitutions(), thunkTy->getInvocationSubstitutions(),
1017
4
      thunkTy->getASTContext());
1018
1019
  // Construct new curry thunk, returning a `@differentiable` function.
1020
4
  SILOptFunctionBuilder fb(dt.getTransform());
1021
4
  auto *newThunk = fb.getOrCreateFunction(
1022
4
      loc, newThunkName, getSpecializedLinkage(thunk, thunk->getLinkage()),
1023
4
      thunkType, thunk->isBare(), thunk->isTransparent(), thunk->isSerialized(),
1024
4
      thunk->isDynamicallyReplaceable(), thunk->isDistributed(),
1025
4
      thunk->isRuntimeAccessible(),
1026
4
      ProfileCounter(), thunk->isThunk());
1027
  // If new thunk is newly created: clone the old thunk body, wrap the
1028
  // returned function value with an `differentiable_function`
1029
  // instruction, and process the `differentiable_function` instruction.
1030
4
  if (newThunk->empty()) {
1031
4
    newThunk->setGenericEnvironment(thunkType->getSubstGenericSignature().getGenericEnvironment());
1032
1033
4
    BasicTypeSubstCloner cloner(thunk, newThunk);
1034
4
    cloner.cloneFunction();
1035
4
    auto *retInst = cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
1036
4
    auto returnValue = retInst->getOperand();
1037
    // Create `differentiable_function` instruction directly after the
1038
    // defining instruction (e.g. `partial_apply`) of the returned value.
1039
    // Note: `differentiable_function` is not created at the end of the
1040
    // new thunk to avoid `alloc_stack`/`dealloc_stack` ordering issues.
1041
4
    SILBuilderWithScope dfiBuilder(
1042
4
        std::next(returnValue->getDefiningInstruction()->getIterator()));
1043
4
    auto *dfi = context.createDifferentiableFunction(
1044
4
        dfiBuilder, loc, parameterIndices, resultIndices, returnValue);
1045
4
    dfiBuilder.setInsertionPoint(newThunk->findReturnBB());
1046
4
    dfiBuilder.createReturn(loc, dfi);
1047
4
    retInst->eraseFromParent();
1048
1049
4
    context.recordGeneratedFunction(newThunk);
1050
4
    context.getDifferentiableFunctionInstWorklist().push_back(dfi);
1051
4
    if (dt.processDifferentiableFunctionInst(dfi))
1052
0
      return nullptr;
1053
4
  }
1054
1055
  // Apply the new curry thunk.
1056
4
  auto *newThunkRef = builder.createFunctionRef(loc, newThunk);
1057
4
  context.recordGeneratedFunctionReference(newThunkRef);
1058
4
  SmallVector<SILValue, 8> newArgs;
1059
4
  SmallVector<SILValue, 8> newArgsToDestroy;
1060
4
  SmallVector<AllocStackInst *, 1> newBuffersToDealloc;
1061
4
  copyParameterArgumentsForApply(ai, newArgs, newArgsToDestroy,
1062
4
                                 newBuffersToDealloc);
1063
4
  auto *newApply = builder.createApply(
1064
4
      loc, newThunkRef, ai->getSubstitutionMap(), newArgs,
1065
4
      ai->getApplyOptions());
1066
4
  for (auto arg : newArgsToDestroy)
1067
0
    builder.emitDestroyOperation(loc, arg);
1068
4
  for (auto *alloc : newBuffersToDealloc)
1069
0
    builder.createDeallocStack(loc, alloc);
1070
4
  return newApply;
1071
4
}
1072
1073
/// Promotes `dfi`'s original function operand to a `@differentiable` function
/// value by building a new `differentiable_function` instruction.
///
/// Curry thunk applications are delegated to
/// `promoteCurryThunkApplicationToDifferentiableFunction`.
///
/// Otherwise, for each derivative kind (JVP, then VJP):
/// - A derivative function reference is emitted.
/// - If its actual configuration differs from the desired one, it is destroyed
///   and replaced with a subset-parameters thunk. This is an error if the
///   original function has partially-applied arguments.
/// - It is converted thin-to-thick if the expected derivative type is thick.
/// - It is reabstracted if its type is not ABI-compatible with the expected
///   derivative type.
///
/// Temporary stack buffers are deallocated in reverse order. The original
/// function is copied unless it has `none` ownership. `dfi` is pushed onto
/// the `differentiable_function` worklist.
///
/// Returns the new value, or a null value on error (after a diagnostic).
SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
    DifferentiableFunctionInst *dfi, SILBuilder &builder, SILLocation loc,
    DifferentiationInvoker invoker) {
  auto &astCtx = context.getASTContext();
  auto origFnOperand = dfi->getOriginalFunction();
  auto origFnTy = origFnOperand->getType().castTo<SILFunctionType>();
  auto *parameterIndices = dfi->getParameterIndices();
  auto *resultIndices = dfi->getResultIndices();

  if (auto curried = promoteCurryThunkApplicationToDifferentiableFunction(
          *this, dfi, builder, loc, invoker))
    return curried;

  AutoDiffConfig desiredConfig(parameterIndices, resultIndices);
  SmallVector<SILValue, 2> derivativeFns;
  SmallVector<AllocStackInst *, 2> newBuffersToDealloc;

  // True if `fn`, looking through function conversions, is a chain of
  // `partial_apply`s where some link captures at least one argument.
  // Subset-parameters thunks cannot forward such captured arguments.
  auto capturesPartialApplyArguments = [](SILValue fn) {
    while (true) {
      auto *pai = peerThroughFunctionConversions<PartialApplyInst>(fn);
      if (!pai)
        return false;
      if (pai->getNumArguments() > 0)
        return true;
      fn = pai->getCallee();
    }
  };

  for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
                                AutoDiffDerivativeFunctionKind::VJP}) {
    auto referenceAndConfig = emitDerivativeFunctionReference(
        *this, builder, desiredConfig, derivativeFnKind, origFnOperand, invoker,
        newBuffersToDealloc);
    // On failure, a diagnostic has already been emitted at the operator.
    if (!referenceAndConfig)
      return nullptr;

    auto derivativeFn = referenceAndConfig->first;
    context.recordGeneratedFunctionReference(derivativeFn);

    // The emitted derivative may have been found for a superset of the
    // desired indices. In that case, use a "subset parameters thunk":
    // - a thunked JVP returns a differential taking fewer parameters, using
    //   `.zero` for the dropped ones;
    // - a thunked VJP returns a pullback that drops unused tangent values.
    // NOTE: the desired indices may come from a partially-applied function
    // and have a smaller capacity than the actual indices, so they are
    // extended before comparison.
    AutoDiffConfig actualConfig = referenceAndConfig->second;
    auto extendedDesiredParameterIndices =
        desiredConfig.parameterIndices->extendingCapacity(
            astCtx, actualConfig.parameterIndices->getCapacity());
    bool configsMatch =
        actualConfig.parameterIndices->equals(extendedDesiredParameterIndices) &&
        actualConfig.resultIndices->equals(desiredConfig.resultIndices);
    if (!configsMatch) {
      // The emitted reference is superseded by the thunk below.
      builder.emitDestroyValueOperation(loc, derivativeFn);
      if (capturesPartialApplyArguments(origFnOperand)) {
        context.emitNondifferentiabilityError(
            origFnOperand, invoker,
            diag::autodiff_cannot_param_subset_thunk_partially_applied_orig_fn);
        return nullptr;
      }
      assert(actualConfig.parameterIndices->isSupersetOf(
          extendedDesiredParameterIndices));
      SILFunction *subsetThunk;
      SubstitutionMap interfaceSubs;
      SILOptFunctionBuilder fb(transform);
      std::tie(subsetThunk, interfaceSubs) =
          getOrCreateSubsetParametersThunkForDerivativeFunction(
              fb, origFnOperand, derivativeFn, derivativeFnKind, desiredConfig,
              actualConfig, context);
      auto *subsetThunkRef = builder.createFunctionRef(loc, subsetThunk);
      // Generic thunks are closed over their substitutions.
      if (subsetThunk->getLoweredFunctionType()->getSubstGenericSignature())
        derivativeFn = builder.createPartialApply(
            loc, subsetThunkRef, interfaceSubs, {},
            ParameterConvention::Direct_Guaranteed);
      else
        derivativeFn = subsetThunkRef;
    }

    auto expectedDerivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
        parameterIndices, resultIndices, derivativeFnKind,
        context.getTypeConverter(),
        LookUpConformanceInModule(context.getModule().getSwiftModule()));

    // A `@convention(thin)` derivative that is expected to be
    // `@convention(thick)` gets a `thin_to_thick_function`.
    bool needsThinToThick =
        expectedDerivativeFnTy->getRepresentation() ==
            SILFunctionTypeRepresentation::Thick &&
        derivativeFn->getType().castTo<SILFunctionType>()->getRepresentation() ==
            SILFunctionTypeRepresentation::Thin;
    if (needsThinToThick)
      derivativeFn = builder.createThinToThickFunction(
          loc, derivativeFn,
          SILType::getPrimitiveObjectType(expectedDerivativeFnTy));

    // Reabstract when parameter/result conventions are not ABI-compatible
    // with the expected derivative type.
    auto &parentFn = *dfi->getFunction();
    bool abiCompatible =
        expectedDerivativeFnTy
            ->isABICompatibleWith(
                derivativeFn->getType().castTo<SILFunctionType>(), parentFn)
            .isCompatible();
    if (!abiCompatible) {
      SILOptFunctionBuilder fb(context.getTransform());
      derivativeFn = reabstractFunction(
          builder, fb, loc, derivativeFn, expectedDerivativeFnTy,
          [](SubstitutionMap substMap) { return substMap; });
      assert(expectedDerivativeFnTy
                 ->isABICompatibleWith(
                     derivativeFn->getType().castTo<SILFunctionType>(),
                     parentFn)
                 .isCompatible());
    }

    derivativeFns.push_back(derivativeFn);
  }

  // Release temporary buffers in reverse allocation order.
  for (auto *buf : llvm::reverse(newBuffersToDealloc))
    builder.createDeallocStack(loc, buf);

  // Copy the original function unless it has `none` ownership.
  if (origFnOperand->getOwnershipKind() != OwnershipKind::None)
    origFnOperand = builder.emitCopyValueOperation(loc, origFnOperand);
  auto *newDiffFn = context.createDifferentiableFunction(
      builder, loc, parameterIndices, resultIndices, origFnOperand,
      std::make_pair(derivativeFns[0], derivativeFns[1]));
  context.getDifferentiableFunctionInstWorklist().push_back(dfi);
  return newDiffFn;
}
1208
1209
/// Promotes `lfi`'s original function operand to a linear function value by
/// creating a new `linear_function` instruction.
///
/// The original function is copied unless it has `none` ownership. The new
/// instruction's transpose operand is `undef` of the expected transpose type;
/// eventually a real transpose function should be created and used. `lfi` is
/// pushed onto the `linear_function` worklist.
SILValue DifferentiationTransformer::promoteToLinearFunction(
    LinearFunctionInst *lfi, SILBuilder &builder, SILLocation loc,
    DifferentiationInvoker invoker) {
  SILValue original = lfi->getOriginalFunction();
  if (original->getOwnershipKind() != OwnershipKind::None)
    original = builder.emitCopyValueOperation(loc, original);

  auto *paramIndices = lfi->getParameterIndices();
  auto originalFnType = original->getType().castTo<SILFunctionType>();
  auto transposeFnType = originalFnType->getAutoDiffTransposeFunctionType(
      paramIndices, context.getTypeConverter(),
      LookUpConformanceInModule(builder.getModule().getSwiftModule()));
  SILValue transposePlaceholder = SILUndef::get(
      SILType::getPrimitiveObjectType(transposeFnType), builder.getFunction());

  auto *linearFn = context.createLinearFunction(
      builder, loc, paramIndices, original, transposePlaceholder);
  context.getLinearFunctionInstWorklist().push_back(lfi);
  return linearFn;
}
1230
1231
bool DifferentiationTransformer::processDifferentiableFunctionInst(
1232
14.7k
    DifferentiableFunctionInst *dfi) {
1233
14.7k
  PrettyStackTraceSILNode dfiTrace("canonicalizing `differentiable_function`",
1234
14.7k
                                   dfi);
1235
14.7k
  PrettyStackTraceSILFunction fnTrace("...in", dfi->getFunction());
1236
14.7k
  LLVM_DEBUG({
1237
14.7k
    auto &s = getADDebugStream() << "Processing DifferentiableFunctionInst:\n";
1238
14.7k
    dfi->printInContext(s);
1239
14.7k
  });
1240
1241
  // If `dfi` already has derivative functions, do not process.
1242
14.7k
  if (dfi->hasDerivativeFunctions())
1243
3.09k
    return false;
1244
1245
11.6k
  SILFunction *parent = dfi->getFunction();
1246
11.6k
  auto loc = dfi->getLoc();
1247
11.6k
  SILBuilderWithScope builder(dfi);
1248
11.6k
  auto differentiableFnValue =
1249
11.6k
      promoteToDifferentiableFunction(dfi, builder, loc, dfi);
1250
  // Mark `dfi` as processed so that it won't be reprocessed after deletion.
1251
11.6k
  context.markDifferentiableFunctionInstAsProcessed(dfi);
1252
11.6k
  if (!differentiableFnValue)
1253
44
    return true;
1254
  // Replace all uses of `dfi`.
1255
11.6k
  dfi->replaceAllUsesWith(differentiableFnValue);
1256
  // Destroy the original operand.
1257
11.6k
  builder.emitDestroyValueOperation(loc, dfi->getOriginalFunction());
1258
11.6k
  dfi->eraseFromParent();
1259
11.6k
  transform.invalidateAnalysis(parent,
1260
11.6k
                               SILAnalysis::InvalidationKind::FunctionBody);
1261
11.6k
  return false;
1262
11.6k
}
1263
1264
bool DifferentiationTransformer::processLinearFunctionInst(
1265
12
    LinearFunctionInst *lfi) {
1266
12
  PrettyStackTraceSILNode dfiTrace("canonicalizing `linear_function`", lfi);
1267
12
  PrettyStackTraceSILFunction fnTrace("...in", lfi->getFunction());
1268
12
  LLVM_DEBUG({
1269
12
    auto &s = getADDebugStream() << "Processing LinearFunctionInst:\n";
1270
12
    lfi->printInContext(s);
1271
12
  });
1272
1273
  // If `lfi` already has a transpose function, do not process.
1274
12
  if (lfi->hasTransposeFunction())
1275
0
    return false;
1276
1277
12
  SILFunction *parent = lfi->getFunction();
1278
12
  auto loc = lfi->getLoc();
1279
12
  SILBuilderWithScope builder(lfi);
1280
12
  auto linearFnValue = promoteToLinearFunction(lfi, builder, loc, lfi);
1281
  // Mark `lfi` as processed so that it won't be reprocessed after deletion.
1282
12
  context.markLinearFunctionInstAsProcessed(lfi);
1283
12
  if (!linearFnValue)
1284
0
    return true;
1285
  // Replace all uses of `lfi`.
1286
12
  lfi->replaceAllUsesWith(linearFnValue);
1287
  // Destroy the original operand.
1288
12
  builder.emitDestroyValueOperation(loc, lfi->getOriginalFunction());
1289
12
  lfi->eraseFromParent();
1290
1291
12
  transform.invalidateAnalysis(parent,
1292
12
                               SILAnalysis::InvalidationKind::FunctionBody);
1293
12
  return false;
1294
12
}
1295
1296
/// Automatic differentiation transform entry.
1297
24.3k
void Differentiation::run() {
1298
24.3k
  auto &module = *getModule();
1299
24.3k
  auto &astCtx = module.getASTContext();
1300
24.3k
  debugDump(module);
1301
1302
  // A transformation helper.
1303
24.3k
  DifferentiationTransformer transformer(*this);
1304
24.3k
  ADContext &context = transformer.getContext();
1305
1306
24.3k
  bool errorOccurred = false;
1307
1308
  // Register all the SIL differentiability witnesses in the module that trigger
1309
  // differentiation.
1310
24.3k
  for (auto &witness : module.getDifferentiabilityWitnesses()) {
1311
2.29k
    if (witness.isDeclaration())
1312
36
      continue;
1313
2.26k
    context.addInvoker(&witness);
1314
2.26k
  }
1315
1316
  // Register all the `differentiable_function` and `linear_function`
1317
  // instructions in the module that trigger differentiation.
1318
1.55M
  for (SILFunction &f : module) {
1319
1.82M
    for (SILBasicBlock &bb : f) {
1320
19.9M
      for (SILInstruction &i : bb) {
1321
19.9M
        if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&i)) {
1322
7.10k
          context.getDifferentiableFunctionInstWorklist().push_back(dfi);
1323
19.9M
        } else if (auto *lfi = dyn_cast<LinearFunctionInst>(&i)) {
1324
          // If linear map transposition is not enabled and an uncanonical
1325
          // `linear_function` instruction is encountered, emit a diagnostic.
1326
          // FIXME(https://github.com/apple/swift/issues/54256): Finish support for linear map transposition.
1327
12
          if (!EnableExperimentalLinearMapTransposition) {
1328
4
            if (!lfi->hasTransposeFunction()) {
1329
4
              astCtx.Diags.diagnose(
1330
4
                lfi->getLoc().getSourceLoc(),
1331
4
                diag::autodiff_conversion_to_linear_function_not_supported);
1332
4
              errorOccurred = true;
1333
4
            }
1334
4
          }
1335
12
          context.getLinearFunctionInstWorklist().push_back(lfi);
1336
12
        }
1337
19.9M
      }
1338
1.82M
    }
1339
1.55M
  }
1340
1341
  // If nothing has triggered differentiation, there's nothing to do.
1342
24.3k
  if (context.getInvokers().empty() &&
1343
24.3k
      context.getDifferentiableFunctionInstWorklist().empty() &&
1344
24.3k
      context.getLinearFunctionInstWorklist().empty())
1345
24.0k
    return;
1346
1347
  // Differentiation relies on the stdlib (the Swift module).
1348
  // If it's not imported, it's an internal error.
1349
372
  if (!astCtx.getStdlibModule()) {
1350
0
    astCtx.Diags.diagnose(SourceLoc(),
1351
0
                          diag::autodiff_internal_swift_not_imported);
1352
0
    return;
1353
0
  }
1354
372
  if (!astCtx.getLoadedModule(astCtx.Id_Differentiation)) {
1355
0
    SourceLoc loc;
1356
0
    if (!context.getInvokers().empty()) {
1357
0
      loc = context.getInvokers().front().second.getLocation();
1358
0
    } else {
1359
0
      assert(!context.getDifferentiableFunctionInstWorklist().empty());
1360
0
      loc = context.getDifferentiableFunctionInstWorklist()
1361
0
                .pop_back_val()
1362
0
                ->getLoc()
1363
0
                .getSourceLoc();
1364
0
    }
1365
0
    astCtx.Diags.diagnose(loc,
1366
0
                          diag::autodiff_differentiation_module_not_imported);
1367
0
    return;
1368
0
  }
1369
1370
  // Process all invokers.
1371
2.26k
  for (auto invokerPair : context.getInvokers()) {
1372
2.26k
    auto *witness = invokerPair.first;
1373
2.26k
    auto invoker = invokerPair.second;
1374
2.26k
    if (transformer.canonicalizeDifferentiabilityWitness(
1375
2.26k
            witness, invoker, witness->getOriginalFunction()->isSerialized()))
1376
172
      errorOccurred = true;
1377
2.26k
  }
1378
1379
  // Iteratively process `differentiable_function` instruction worklist.
1380
26.7k
  while (!context.getDifferentiableFunctionInstWorklist().empty()) {
1381
26.3k
    auto *dfi = context.getDifferentiableFunctionInstWorklist().pop_back_val();
1382
    // Skip instructions that have been already been processed.
1383
26.3k
    if (context.isDifferentiableFunctionInstProcessed(dfi))
1384
11.6k
      continue;
1385
14.7k
    errorOccurred |= transformer.processDifferentiableFunctionInst(dfi);
1386
14.7k
  }
1387
1388
  // Iteratively process `linear_function` instruction worklist.
1389
396
  while (!context.getLinearFunctionInstWorklist().empty()) {
1390
24
    auto *lfi = context.getLinearFunctionInstWorklist().pop_back_val();
1391
    // Skip instructions that have been already been processed.
1392
24
    if (context.isLinearFunctionInstProcessed(lfi))
1393
12
      continue;
1394
12
    errorOccurred |= transformer.processLinearFunctionInst(lfi);
1395
12
  }
1396
1397
  // If any error occurred while processing witnesses or
1398
  // `differentiable_function` instructions, clean up.
1399
372
  if (errorOccurred) {
1400
24
    context.cleanUp();
1401
24
    return;
1402
24
  }
1403
1404
348
  LLVM_DEBUG(getADDebugStream() << "All differentiation finished\n");
1405
348
}
1406
1407
//===----------------------------------------------------------------------===//
1408
// Pass creation
1409
//===----------------------------------------------------------------------===//
1410
1411
24.3k
/// Factory used by the pass manager to instantiate the differentiation pass.
SILTransform *swift::createDifferentiation() {
  auto *pass = new Differentiation;
  return pass;
}