Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- LinearMapInfo.cpp ------------------------------------*- C++ -*---===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// Linear map tuple and branching trace enum information for differentiation.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#define DEBUG_TYPE "differentiation"
18
19
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
20
#include "swift/SILOptimizer/Differentiation/ADContext.h"
21
22
#include "swift/AST/DeclContext.h"
23
#include "swift/AST/ParameterList.h"
24
#include "swift/AST/SourceFile.h"
25
#include "swift/SIL/LoopInfo.h"
26
27
namespace swift {
28
namespace autodiff {
29
30
//===----------------------------------------------------------------------===//
31
// Local helpers
32
//===----------------------------------------------------------------------===//
33
34
/// Clone the generic parameters of the given generic signature and return a new
35
/// `GenericParamList`.
36
static GenericParamList *cloneGenericParameters(ASTContext &ctx,
37
                                                DeclContext *dc,
38
1.54k
                                                CanGenericSignature sig) {
39
1.54k
  SmallVector<GenericTypeParamDecl *, 2> clonedParams;
40
1.71k
  for (auto paramType : sig.getGenericParams()) {
41
1.71k
    auto *clonedParam = GenericTypeParamDecl::createImplicit(
42
1.71k
        dc, paramType->getName(), paramType->getDepth(), paramType->getIndex(),
43
1.71k
        paramType->isParameterPack());
44
1.71k
    clonedParam->setDeclContext(dc);
45
1.71k
    clonedParams.push_back(clonedParam);
46
1.71k
  }
47
1.54k
  return GenericParamList::create(ctx, SourceLoc(), clonedParams, SourceLoc());
48
1.54k
}
49
50
//===----------------------------------------------------------------------===//
51
// LinearMapInfo methods
52
//===----------------------------------------------------------------------===//
53
54
LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
55
                             SILFunction *original, SILFunction *derivative,
56
                             const AutoDiffConfig &config,
57
                             const DifferentiableActivityInfo &activityInfo,
58
                             SILLoopInfo *loopInfo)
59
    : kind(kind), original(original), derivative(derivative),
60
      activityInfo(activityInfo), loopInfo(loopInfo), config(config),
61
      synthesizedFile(context.getOrCreateSynthesizedFile(original)),
62
6.61k
      typeConverter(context.getTypeConverter()) {
63
6.61k
  generateDifferentiationDataStructures(context, derivative);
64
6.61k
}
65
66
7.35k
SILType LinearMapInfo::remapTypeInDerivative(SILType ty) {
67
7.35k
  if (ty.hasArchetype())
68
656
    return derivative->mapTypeIntoContext(ty.mapTypeOutOfContext());
69
6.69k
  return derivative->mapTypeIntoContext(ty);
70
7.35k
}
71
72
EnumDecl *
73
LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
74
8.60k
                                        CanGenericSignature genericSig) {
75
8.60k
  assert(originalBB->getParent() == original);
76
0
  auto &astCtx = original->getASTContext();
77
8.60k
  auto &file = getSynthesizedFile();
78
  // Create a branching trace enum.
79
8.60k
  Mangle::ASTMangler mangler;
80
8.60k
  auto config = this->config.withGenericSignature(genericSig);
81
8.60k
  auto enumName = mangler.mangleAutoDiffGeneratedDeclaration(
82
8.60k
      AutoDiffGeneratedDeclarationKind::BranchingTraceEnum,
83
8.60k
      original->getName().str(), originalBB->getDebugID(), kind, config);
84
8.60k
  auto enumId = astCtx.getIdentifier(enumName);
85
8.60k
  auto loc = original->getLocation().getSourceLoc();
86
8.60k
  GenericParamList *genericParams = nullptr;
87
8.60k
  if (genericSig)
88
1.54k
    genericParams = cloneGenericParameters(astCtx, &file, genericSig);
89
8.60k
  auto *branchingTraceDecl = new (astCtx) EnumDecl(
90
8.60k
      /*EnumLoc*/ SourceLoc(), /*Name*/ enumId, /*NameLoc*/ loc,
91
8.60k
      /*Inherited*/ {}, /*GenericParams*/ genericParams, /*DC*/ &file);
92
  // Note: must mark enum as implicit to satisfy assertion in
93
  // `Parser::parseDeclListDelayed`.
94
8.60k
  branchingTraceDecl->setImplicit();
95
8.60k
  if (genericSig)
96
1.54k
    branchingTraceDecl->setGenericSignature(genericSig);
97
8.60k
  switch (original->getEffectiveSymbolLinkage()) {
98
252
  case swift::SILLinkage::Public:
99
252
  case swift::SILLinkage::PublicNonABI:
100
    // Branching trace enums shall not be resilient.
101
252
    branchingTraceDecl->getAttrs().add(new (astCtx) FrozenAttr(/*implicit*/ true));
102
252
    branchingTraceDecl->getAttrs().add(new (astCtx) UsableFromInlineAttr(/*Implicit*/ true));
103
252
    LLVM_FALLTHROUGH;
104
3.08k
  case swift::SILLinkage::Hidden:
105
3.19k
  case swift::SILLinkage::Shared:
106
3.19k
    branchingTraceDecl->setAccess(AccessLevel::Internal);
107
3.19k
    break;
108
5.41k
  case swift::SILLinkage::Private:
109
5.41k
    branchingTraceDecl->setAccess(AccessLevel::FilePrivate);
110
5.41k
    break;
111
0
  default:
112
    // When the original function has external linkage, we create an internal
113
    // struct for use by our own module. This is necessary for cross-cell
114
    // differentiation in Jupyter.
115
    // TODO: Add a test in the compiler that exercises a similar situation as
116
    // cross-cell differentiation in Jupyter.
117
0
    branchingTraceDecl->setAccess(AccessLevel::Internal);
118
8.60k
  }
119
8.60k
  file.addTopLevelDecl(branchingTraceDecl);
120
8.60k
  file.getParentModule()->clearLookupCache();
121
122
8.60k
  return branchingTraceDecl;
123
8.60k
}
124
125
void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
126
1.99k
                                               SILLoopInfo *loopInfo) {
127
1.99k
  auto &astCtx = original->getASTContext();
128
1.99k
  auto *moduleDecl = original->getModule().getSwiftModule();
129
1.99k
  auto loc = original->getLocation().getSourceLoc();
130
1.99k
  auto *branchingTraceDecl = getBranchingTraceDecl(originalBB);
131
132
  // Add basic block enum cases.
133
2.64k
  for (auto *predBB : originalBB->getPredecessorBlocks()) {
134
    // Create dummy declaration representing enum case parameter.
135
2.64k
    auto *decl = new (astCtx)
136
2.64k
        ParamDecl(loc, loc, Identifier(), loc, Identifier(), moduleDecl);
137
2.64k
    decl->setSpecifier(ParamDecl::Specifier::Default);
138
    // If predecessor block is in a loop, its linear map tuple will be
139
    // indirectly referenced in memory owned by the context object. The payload
140
    // is just a raw pointer.
141
2.64k
    if (loopInfo->getLoopFor(predBB)) {
142
408
      heapAllocatedContext = true;
143
408
      decl->setInterfaceType(astCtx.TheRawPointerType);
144
2.24k
    } else { // Otherwise the payload is the linear map tuple.
145
2.24k
      auto *linearMapStructTy = getLinearMapTupleType(predBB);
146
2.24k
      assert(linearMapStructTy && "must have linear map struct type for predecessor BB");
147
0
      auto canLinearMapStructTy = linearMapStructTy->getCanonicalType();
148
2.24k
      decl->setInterfaceType(
149
2.24k
          canLinearMapStructTy->hasArchetype()
150
2.24k
              ? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy);
151
2.24k
    }
152
    // Create enum element and enum case declarations.
153
0
    auto *paramList = ParameterList::create(astCtx, {decl});
154
2.64k
    auto bbId = "bb" + std::to_string(predBB->getDebugID());
155
2.64k
    auto *enumEltDecl = new (astCtx) EnumElementDecl(
156
2.64k
        /*IdentifierLoc*/ loc, DeclName(astCtx.getIdentifier(bbId)), paramList,
157
2.64k
        loc, /*RawValueExpr*/ nullptr, branchingTraceDecl);
158
2.64k
    enumEltDecl->setImplicit();
159
2.64k
    auto *enumCaseDecl = EnumCaseDecl::create(
160
2.64k
        /*CaseLoc*/ loc, {enumEltDecl}, branchingTraceDecl);
161
2.64k
    enumCaseDecl->setImplicit();
162
2.64k
    branchingTraceDecl->addMember(enumEltDecl);
163
2.64k
    branchingTraceDecl->addMember(enumCaseDecl);
164
    // Record enum element declaration.
165
2.64k
    branchingTraceEnumCases.insert({{predBB, originalBB}, enumEltDecl});
166
2.64k
  }
167
1.99k
}
168
169
170
7.35k
Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) {
171
7.35k
  SmallVector<SILValue, 4> allResults;
172
7.35k
  SmallVector<unsigned, 8> activeParamIndices;
173
7.35k
  SmallVector<unsigned, 8> activeResultIndices;
174
7.35k
  collectMinimalIndicesForFunctionCall(ai, config, activityInfo, allResults,
175
7.35k
                                       activeParamIndices, activeResultIndices);
176
177
  // Check if there are any active results or arguments. If not, skip
178
  // this instruction.
179
7.37k
  auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) {
180
7.37k
    return activityInfo.isActive(res, config);
181
7.37k
  });
182
7.35k
  bool hasActiveSemanticResultArgument = false;
183
7.35k
  bool hasActiveArguments = false;
184
7.35k
  auto numIndirectResults = ai->getNumIndirectResults();
185
17.8k
  for (auto argIdx : range(ai->getSubstCalleeConv().getNumParameters())) {
186
17.8k
    auto arg = ai->getArgumentsWithoutIndirectResults()[argIdx];
187
17.8k
    if (activityInfo.isActive(arg, config)) {
188
11.4k
      hasActiveArguments = true;
189
11.4k
      auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg(
190
11.4k
          numIndirectResults + argIdx);
191
11.4k
      if (paramInfo.isAutoDiffSemanticResult())
192
608
        hasActiveSemanticResultArgument = true;
193
11.4k
    }
194
17.8k
  }
195
7.35k
  if (!hasActiveArguments)
196
0
    return {};
197
7.35k
  if (!hasActiveResults && !hasActiveSemanticResultArgument)
198
0
    return {};
199
200
  // Compute differentiability parameters.
201
  // - If the callee has `@differentiable` function type, use differentiation
202
  //   parameters from the function type.
203
  // - Otherwise, use the active parameters.
204
7.35k
  IndexSubset *parameters;
205
7.35k
  auto origFnSubstTy = ai->getSubstCalleeType();
206
7.35k
  auto remappedOrigFnSubstTy =
207
7.35k
      remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy))
208
7.35k
          .castTo<SILFunctionType>()
209
7.35k
          ->getUnsubstitutedType(original->getModule());
210
7.35k
  if (remappedOrigFnSubstTy->isDifferentiable()) {
211
80
    parameters = remappedOrigFnSubstTy->getDifferentiabilityParameterIndices();
212
7.27k
  } else {
213
7.27k
    parameters = IndexSubset::get(
214
7.27k
        original->getASTContext(),
215
7.27k
        ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices);
216
7.27k
  }
217
  // Compute differentiability results.
218
7.35k
  auto *results = IndexSubset::get(original->getASTContext(),
219
7.35k
                                   remappedOrigFnSubstTy->getNumAutoDiffSemanticResults(),
220
7.35k
                                   activeResultIndices);
221
  // Create autodiff indices for the `apply` instruction.
222
7.35k
  AutoDiffConfig applyConfig(parameters, results);
223
224
  // Check for non-differentiable original function type.
225
7.35k
  auto checkNondifferentiableOriginalFunctionType = [&](CanSILFunctionType
226
7.35k
                                                            origFnTy) {
227
    // Check non-differentiable arguments.
228
11.4k
    for (auto paramIndex : applyConfig.parameterIndices->getIndices()) {
229
11.4k
      auto remappedParamType =
230
11.4k
          origFnTy->getParameters()[paramIndex].getSILStorageInterfaceType();
231
11.4k
      if (!remappedParamType.isDifferentiable(derivative->getModule()))
232
20
        return true;
233
11.4k
    }
234
    // Check non-differentiable results.
235
7.44k
    for (auto resultIndex : applyConfig.resultIndices->getIndices()) {
236
7.44k
      SILType remappedResultType;
237
7.44k
      if (resultIndex >= origFnTy->getNumResults()) {
238
604
        auto semanticResultArgIdx = resultIndex - origFnTy->getNumResults();
239
604
        auto semanticResultArg =
240
604
            *std::next(ai->getAutoDiffSemanticResultArguments().begin(),
241
604
                       semanticResultArgIdx);
242
604
        remappedResultType = semanticResultArg->getType();
243
6.83k
      } else {
244
6.83k
        remappedResultType =
245
6.83k
            origFnTy->getResults()[resultIndex].getSILStorageInterfaceType();
246
6.83k
      }
247
7.44k
      if (!remappedResultType.isDifferentiable(derivative->getModule()))
248
12
        return true;
249
7.44k
    }
250
7.32k
    return false;
251
7.33k
  };
252
7.35k
  if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy))
253
32
    return nullptr;
254
255
7.32k
  AutoDiffDerivativeFunctionKind derivativeFnKind(kind);
256
7.32k
  auto derivativeFnType =
257
7.32k
      remappedOrigFnSubstTy
258
7.32k
          ->getAutoDiffDerivativeFunctionType(
259
7.32k
              parameters, results, derivativeFnKind, context.getTypeConverter(),
260
7.32k
              LookUpConformanceInModule(
261
7.32k
                  derivative->getModule().getSwiftModule()))
262
7.32k
          ->getUnsubstitutedType(original->getModule());
263
264
7.32k
  auto derivativeFnResultTypes = derivativeFnType->getAllResultsInterfaceType();
265
7.32k
  auto linearMapSILType = derivativeFnResultTypes;
266
7.32k
  if (auto tupleType = linearMapSILType.getAs<TupleType>()) {
267
6.78k
    linearMapSILType = SILType::getPrimitiveObjectType(
268
6.78k
        tupleType.getElementType(tupleType->getElements().size() - 1));
269
6.78k
  }
270
7.32k
  if (auto fnTy = linearMapSILType.getAs<SILFunctionType>()) {
271
7.32k
    linearMapSILType = SILType::getPrimitiveObjectType(
272
7.32k
        fnTy->getUnsubstitutedType(original->getModule()));
273
7.32k
  }
274
275
  // IRGen requires decls to have AST types (not `SILFunctionType`), so we
276
  // convert the `SILFunctionType` of the linear map to a `FunctionType` with
277
  // the same parameters and results.
278
7.32k
  auto silFnTy = linearMapSILType.castTo<SILFunctionType>();
279
7.32k
  SmallVector<AnyFunctionType::Param, 8> params;
280
8.37k
  for (auto &param : silFnTy->getParameters()) {
281
8.37k
    ParameterTypeFlags flags;
282
8.37k
    if (param.isAutoDiffSemanticResult())
283
604
      flags = flags.withInOut(true);
284
285
8.37k
    params.push_back(
286
8.37k
        AnyFunctionType::Param(param.getInterfaceType(), Identifier(), flags));
287
8.37k
  }
288
289
7.32k
  AnyFunctionType *astFnTy;
290
7.32k
  if (auto genSig = silFnTy->getSubstGenericSignature()) {
291
    // FIXME: Verify ExtInfo state is correct, not working by accident.
292
0
    GenericFunctionType::ExtInfo info;
293
0
    astFnTy = GenericFunctionType::get(
294
0
        genSig, params, silFnTy->getAllResultsInterfaceType().getASTType(),
295
0
        info);
296
7.32k
  } else {
297
7.32k
    FunctionType::ExtInfo info;
298
7.32k
    astFnTy = FunctionType::get(
299
7.32k
        params, silFnTy->getAllResultsInterfaceType().getASTType(), info);
300
7.32k
  }
301
302
7.32k
  if (astFnTy->hasArchetype())
303
628
    return astFnTy->mapTypeOutOfContext();
304
305
6.69k
  return astFnTy;
306
7.32k
}
307
308
void LinearMapInfo::generateDifferentiationDataStructures(
309
6.61k
    ADContext &context, SILFunction *derivativeFn) {
310
6.61k
  auto &astCtx = original->getASTContext();
311
  // Get the derivative function generic signature.
312
6.61k
  CanGenericSignature derivativeFnGenSig = nullptr;
313
6.61k
  if (auto *derivativeFnGenEnv = derivativeFn->getGenericEnvironment())
314
1.10k
    derivativeFnGenSig =
315
1.10k
        derivativeFnGenEnv->getGenericSignature().getCanonicalSignature();
316
317
  // Create branching trace enum for each original block and add it as a field
318
  // in the corresponding struct.
319
6.61k
  StringRef traceEnumFieldName;
320
6.61k
  switch (kind) {
321
1.35k
  case AutoDiffLinearMapKind::Differential:
322
1.35k
    traceEnumFieldName = "successor";
323
1.35k
    break;
324
5.25k
  case AutoDiffLinearMapKind::Pullback:
325
5.25k
    traceEnumFieldName = "predecessor";
326
5.25k
    break;
327
6.61k
  }
328
329
8.60k
  for (auto &origBB : *original) {
330
8.60k
    auto *traceEnum =
331
8.60k
        createBranchingTraceDecl(&origBB, derivativeFnGenSig);
332
8.60k
    branchingTraceDecls.insert({&origBB, traceEnum});
333
8.60k
  }
334
335
  // Add linear map fields to the linear map tuples.
336
  //
337
  // Now we need to be very careful as we're having a very subtle
338
  // chicken-and-egg problem. We need lowered branch trace enum type for the
339
  // linear map typle type. However branch trace enum type lowering depends on
340
  // the lowering of its elements (at very least, the type classification of
341
  // being trivial / non-trivial). As the lowering is cached we need to ensure
342
  // we compute lowered type for the branch trace enum when the corresponding
343
  // EnumDecl is fully complete: we cannot add more entries without causing some
344
  // very subtle issues later on. However, the elements of the enum are linear
345
  // map tuples of predecessors, that correspondingly may contain branch trace
346
  // enums of corresponding predecessor BBs.
347
  //
348
  // Traverse all BBs in reverse post-order traversal order to ensure we process
349
  // each BB before its predecessors.
350
6.61k
  llvm::ReversePostOrderTraversal<SILFunction *> RPOT(original);
351
15.2k
  for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) {
352
8.60k
    auto *origBB = *Iter;
353
8.60k
    SmallVector<TupleTypeElt, 4> linearTupleTypes;
354
8.60k
    if (!origBB->isEntry()) {
355
1.99k
      populateBranchingTraceDecl(origBB, loopInfo);
356
357
1.99k
      CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType();
358
1.99k
      linearTupleTypes.emplace_back(traceEnumType,
359
1.99k
                                    astCtx.getIdentifier(traceEnumFieldName));
360
1.99k
    }
361
362
8.60k
    if (isSemanticMemberAccessor(original)) {
363
      // Do not add linear map fields for semantic member accessors, which have
364
      // special-case pullback generation. Linear map tuples should be empty.
365
8.33k
    } else {
366
100k
      for (auto &inst : *origBB) {
367
100k
        if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
368
          // Add linear map field to struct for active `apply` instructions.
369
          // Skip array literal intrinsic applications since array literal
370
          // initialization is linear and handled separately.
371
11.5k
          if (!shouldDifferentiateApplySite(ai) ||
372
11.5k
              ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC))
373
3.99k
            continue;
374
7.56k
          if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC))
375
212
            continue;
376
7.35k
          LLVM_DEBUG(getADDebugStream()
377
7.35k
                     << "Adding linear map tuple field for " << *ai);
378
7.35k
          if (Type linearMapType = getLinearMapType(context, ai)) {
379
7.32k
            linearMapIndexMap.insert({ai, linearTupleTypes.size()});
380
7.32k
            linearTupleTypes.emplace_back(linearMapType);
381
7.32k
          }
382
7.35k
        }
383
100k
      }
384
8.33k
    }
385
386
8.60k
    linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)});
387
8.60k
  }
388
389
  // Print generated linear map structs and branching trace enums.
390
  // These declarations do not show up with `-emit-sil` because they are
391
  // implicit. Instead, use `-Xllvm -debug-only=differentiation` to test
392
  // declarations with FileCheck.
393
6.61k
  LLVM_DEBUG({
394
6.61k
    auto &s = getADDebugStream();
395
6.61k
    PrintOptions printOptions;
396
6.61k
    printOptions.TypeDefinitions = true;
397
6.61k
    printOptions.ExplodePatternBindingDecls = true;
398
6.61k
    printOptions.SkipImplicit = false;
399
6.61k
    s << "Generated linear map tuples and branching trace enums for @"
400
6.61k
      << original->getName() << ":\n";
401
6.61k
    for (auto &origBB : *original) {
402
6.61k
      auto *linearMapTuple = getLinearMapTupleType(&origBB);
403
6.61k
      linearMapTuple->print(s, printOptions);
404
6.61k
      s << '\n';
405
6.61k
    }
406
407
6.61k
    for (auto &origBB : *original) {
408
6.61k
      auto *traceEnum = getBranchingTraceDecl(&origBB);
409
6.61k
      traceEnum->print(s, printOptions);
410
6.61k
      s << '\n';
411
6.61k
    }
412
6.61k
  });
413
6.61k
}
414
415
/// Returns a flag that indicates whether the `apply` instruction should be
416
/// differentiated, given the differentiation indices of the instruction's
417
/// parent function. Whether the `apply` should be differentiated is determined
418
/// sequentially from the following conditions:
419
/// 1. The instruction has an active `inout` argument.
420
/// 2. The instruction is a call to the array literal initialization intrinsic
421
///    ("array.uninitialized_intrinsic"), where the result is active and where
422
///    there is a `store` of an active value into the array's buffer.
423
/// 3. The instruction has both an active result (direct or indirect) and an
424
///    active argument.
425
42.6k
bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) {
426
  // Function applications with an active inout argument should be
427
  // differentiated.
428
42.6k
  for (auto inoutArg : applySite.getInoutArguments())
429
2.80k
    if (activityInfo.isActive(inoutArg, config))
430
2.28k
      return true;
431
432
40.3k
  bool hasActiveDirectResults = false;
433
40.3k
  forEachApplyDirectResult(applySite, [&](SILValue directResult) {
434
29.1k
    hasActiveDirectResults |= activityInfo.isActive(directResult, config);
435
29.1k
  });
436
40.3k
  bool hasActiveIndirectResults =
437
40.3k
      llvm::any_of(applySite.getIndirectSILResults(), [&](SILValue result) {
438
12.2k
        return activityInfo.isActive(result, config);
439
12.2k
      });
440
40.3k
  bool hasActiveResults = hasActiveDirectResults || hasActiveIndirectResults;
441
442
  // TODO: Pattern match to make sure there is at least one `store` to the
443
  // array's active buffer.
444
40.3k
  if (ArraySemanticsCall(applySite.getInstruction(),
445
40.3k
                         semantics::ARRAY_UNINITIALIZED_INTRINSIC) &&
446
40.3k
      hasActiveResults)
447
868
    return true;
448
449
39.5k
  auto arguments = applySite.getArgumentsWithoutIndirectResults();
450
53.7k
  bool hasActiveArguments = llvm::any_of(arguments, [&](SILValue arg) {
451
53.7k
    return activityInfo.isActive(arg, config);
452
53.7k
  });
453
39.5k
  return hasActiveResults && hasActiveArguments;
454
40.3k
}
455
456
static bool shouldDifferentiateInjectEnumAddr(
457
    const InjectEnumAddrInst &inject,
458
    const DifferentiableActivityInfo &activityInfo,
459
12
    const AutoDiffConfig &config) {
460
12
  SILValue en = inject.getOperand();
461
24
  for (auto use : en->getUses()) {
462
24
    auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser());
463
24
    if (init && activityInfo.isActive(init, config))
464
8
      return true;
465
24
  }
466
4
  return false;
467
12
}
468
469
/// Returns a flag indicating whether the instruction should be differentiated,
470
/// given the differentiation indices of the instruction's parent function.
471
/// Whether the instruction should be differentiated is determined sequentially
472
/// from any of the following conditions:
473
/// 1. The instruction is a full apply site and `shouldDifferentiateApplyInst`
474
///    returns true.
475
/// 2. The instruction has a source operand and a destination operand, both
476
///    being active.
477
/// 3. The instruction is an allocation instruction and has an active result.
478
/// 4. The instruction performs reference counting, lifetime ending, access
479
///    ending, or destroying on an active operand.
480
/// 5. The instruction creates an SSA copy of an active operand.
481
108k
bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
482
  // A full apply site with an active argument and an active result (direct or
483
  // indirect) should be differentiated.
484
108k
  if (FullApplySite::isa(inst))
485
11.7k
    return shouldDifferentiateApplySite(FullApplySite(inst));
486
  // Anything with an active result and an active operand should be
487
  // differentiated.
488
97.1k
  auto hasActiveOperands =
489
97.1k
      llvm::any_of(inst->getAllOperands(), [&](Operand &op) {
490
65.5k
        return activityInfo.isActive(op.get(), config);
491
65.5k
      });
492
97.1k
  auto hasActiveResults = llvm::any_of(inst->getResults(), [&](SILValue val) {
493
51.5k
    return activityInfo.isActive(val, config);
494
51.5k
  });
495
97.1k
  if (hasActiveOperands && hasActiveResults)
496
14.1k
    return true;
497
    // `store`-like instructions do not have an SSA result, but have two
498
    // operands that represent the source and the destination. We treat them as
499
    // the input and the output, respectively.
500
    // For `store`-like instructions whose destination is an element address
501
    // from an `array.uninitialized_intrinsic` application, return true if the
502
    // intrinsic application (representing the semantic destination) is active.
503
82.9k
#define CHECK_INST_TYPE_ACTIVE_DEST(INST)                                      \
504
311k
  if (auto *castInst = dyn_cast<INST##Inst>(inst))                             \
505
311k
    return activityInfo.isActive(castInst->getDest(), config);
506
82.9k
  CHECK_INST_TYPE_ACTIVE_DEST(Store)
507
77.1k
  CHECK_INST_TYPE_ACTIVE_DEST(StoreBorrow)
508
77.0k
  CHECK_INST_TYPE_ACTIVE_DEST(CopyAddr)
509
74.3k
  CHECK_INST_TYPE_ACTIVE_DEST(UnconditionalCheckedCastAddr)
510
74.2k
#undef CHECK_INST_TYPE_ACTIVE_DEST
511
  // Should differentiate any allocation instruction that has an active result.
512
74.2k
  if ((isa<AllocationInst>(inst) && hasActiveResults))
513
6.83k
    return true;
514
67.4k
  if (hasActiveOperands) {
515
    // Should differentiate any instruction that performs reference counting,
516
    // lifetime ending, access ending, or destroying on an active operand.
517
29.7k
    if (isa<RefCountingInst>(inst) || isa<EndAccessInst>(inst) ||
518
29.7k
        isa<EndBorrowInst>(inst) || isa<DeallocationInst>(inst) ||
519
29.7k
        isa<DestroyValueInst>(inst) || isa<DestroyAddrInst>(inst))
520
14.6k
      return true;
521
29.7k
  }
522
523
  // Should differentiate `inject_enum_addr` if the corresponding
524
  // `init_enum_addr` has an active operand.
525
52.8k
  if (auto inject = dyn_cast<InjectEnumAddrInst>(inst))
526
12
    if (shouldDifferentiateInjectEnumAddr(*inject, activityInfo, config))
527
8
      return true;
528
529
52.8k
  return false;
530
52.8k
}
531
532
} // end namespace autodiff
533
} // end namespace swift