Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/Common.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- Common.cpp - Automatic differentiation common utils --*- C++ -*---===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// Automatic differentiation common utilities.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#include "swift/Basic/STLExtras.h"
18
#define DEBUG_TYPE "differentiation"
19
20
#include "swift/SILOptimizer/Differentiation/Common.h"
21
#include "swift/AST/TypeCheckRequests.h"
22
#include "swift/SILOptimizer/Differentiation/ADContext.h"
23
24
namespace swift {
25
namespace autodiff {
26
27
25.9k
raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; }
28
29
//===----------------------------------------------------------------------===//
30
// Helpers
31
//===----------------------------------------------------------------------===//
32
33
10.5k
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) {
34
  // Find the `pointer_to_address` result, peering through `index_addr`.
35
10.5k
  auto *ptai = dyn_cast<PointerToAddressInst>(v);
36
10.5k
  if (auto *iai = dyn_cast<IndexAddrInst>(v))
37
124
    ptai = dyn_cast<PointerToAddressInst>(iai->getOperand(0));
38
10.5k
  if (!ptai)
39
10.0k
    return nullptr;
40
  // Return the `array.uninitialized_intrinsic` application, if it exists.
41
488
  if (auto *dti = dyn_cast<DestructureTupleInst>(
42
488
          ptai->getOperand()->getDefiningInstruction()))
43
488
    return ArraySemanticsCall(dti->getOperand(),
44
488
                              semantics::ARRAY_UNINITIALIZED_INTRINSIC);
45
0
  return nullptr;
46
488
}
47
48
30.9k
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
49
30.9k
  bool foundDestructureTupleUser = false;
50
30.9k
  if (!value->getType().is<TupleType>())
51
0
    return nullptr;
52
30.9k
  DestructureTupleInst *result = nullptr;
53
30.9k
  for (auto *use : value->getUses()) {
54
1.63k
    if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser())) {
55
1.63k
      assert(!foundDestructureTupleUser &&
56
1.63k
             "There should only be one `destructure_tuple` user of a tuple");
57
0
      foundDestructureTupleUser = true;
58
1.63k
      result = dti;
59
1.63k
    }
60
1.63k
  }
61
30.9k
  return result;
62
30.9k
}
63
64
47.0k
bool isSemanticMemberAccessor(SILFunction *original) {
65
47.0k
  auto *dc = original->getDeclContext();
66
47.0k
  if (!dc)
67
744
    return false;
68
46.2k
  auto *decl = dc->getAsDecl();
69
46.2k
  if (!decl)
70
10.3k
    return false;
71
35.9k
  auto *accessor = dyn_cast<AccessorDecl>(decl);
72
35.9k
  if (!accessor)
73
33.2k
    return false;
74
  // Currently, only getters and setters are supported.
75
  // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
76
2.63k
  if (accessor->getAccessorKind() != AccessorKind::Get &&
77
2.63k
      accessor->getAccessorKind() != AccessorKind::Set)
78
0
    return false;
79
  // Accessor must come from a `var` declaration.
80
2.63k
  auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());
81
2.63k
  if (!varDecl)
82
68
    return false;
83
  // Return true for stored property accessors.
84
2.56k
  if (varDecl->hasStorage() && varDecl->isInstanceMember())
85
716
    return true;
86
  // Return true for properties that have attached property wrappers.
87
1.85k
  if (varDecl->hasAttachedPropertyWrapper())
88
1.28k
    return true;
89
  // Otherwise, return false.
90
  // User-defined accessors can never be supported because they may use custom
91
  // logic that does not semantically perform a member access.
92
564
  return false;
93
1.85k
}
94
95
0
bool hasSemanticMemberAccessorCallee(ApplySite applySite) {
96
0
  if (auto *FRI = dyn_cast<FunctionRefBaseInst>(applySite.getCallee()))
97
0
    if (auto *F = FRI->getReferencedFunctionOrNull())
98
0
      return isSemanticMemberAccessor(F);
99
0
  return false;
100
0
}
101
102
void forEachApplyDirectResult(
103
    FullApplySite applySite,
104
77.1k
    llvm::function_ref<void(SILValue)> resultCallback) {
105
77.1k
  switch (applySite.getKind()) {
106
77.0k
  case FullApplySiteKind::ApplyInst: {
107
77.0k
    auto *ai = cast<ApplyInst>(applySite.getInstruction());
108
77.0k
    if (!ai->getType().is<TupleType>()) {
109
46.0k
      resultCallback(ai);
110
46.0k
      return;
111
46.0k
    }
112
30.9k
    if (auto *dti = getSingleDestructureTupleUser(ai))
113
1.63k
      for (auto directResult : dti->getResults())
114
3.26k
        resultCallback(directResult);
115
30.9k
    break;
116
77.0k
  }
117
96
  case FullApplySiteKind::BeginApplyInst: {
118
96
    auto *bai = cast<BeginApplyInst>(applySite.getInstruction());
119
96
    for (auto directResult : bai->getResults())
120
192
      resultCallback(directResult);
121
96
    break;
122
77.0k
  }
123
68
  case FullApplySiteKind::TryApplyInst: {
124
68
    auto *tai = cast<TryApplyInst>(applySite.getInstruction());
125
68
    for (auto *succBB : tai->getSuccessorBlocks())
126
136
      for (auto *arg : succBB->getArguments())
127
136
        resultCallback(arg);
128
68
    break;
129
77.0k
  }
130
77.1k
  }
131
77.1k
}
132
133
void collectAllFormalResultsInTypeOrder(SILFunction &function,
134
11.9k
                                        SmallVectorImpl<SILValue> &results) {
135
11.9k
  SILFunctionConventions convs(function.getLoweredFunctionType(),
136
11.9k
                               function.getModule());
137
11.9k
  auto indResults = function.getIndirectResults();
138
11.9k
  auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
139
11.9k
  auto retVal = retInst->getOperand();
140
11.9k
  SmallVector<SILValue, 8> dirResults;
141
11.9k
  if (auto *tupleInst =
142
11.9k
          dyn_cast_or_null<TupleInst>(retVal->getDefiningInstruction()))
143
3.39k
    dirResults.append(tupleInst->getElements().begin(),
144
3.39k
                      tupleInst->getElements().end());
145
8.54k
  else
146
8.54k
    dirResults.push_back(retVal);
147
11.9k
  unsigned indResIdx = 0, dirResIdx = 0;
148
11.9k
  for (auto &resInfo : convs.getResults())
149
11.4k
    results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++]
150
11.4k
                                               : indResults[indResIdx++]);
151
  // Treat semantic result parameters as semantic results.
152
  // Append them` parameters after formal results.
153
19.2k
  for (auto i : range(convs.getNumParameters())) {
154
19.2k
    auto paramInfo = convs.getParameters()[i];
155
19.2k
    if (!paramInfo.isAutoDiffSemanticResult())
156
18.3k
      continue;
157
816
    auto *argument = function.getArgumentsWithoutIndirectResults()[i];
158
816
    results.push_back(argument);
159
816
  }
160
11.9k
}
161
162
void collectAllDirectResultsInTypeOrder(SILFunction &function,
163
1.35k
                                        SmallVectorImpl<SILValue> &results) {
164
1.35k
  SILFunctionConventions convs(function.getLoweredFunctionType(),
165
1.35k
                               function.getModule());
166
1.35k
  auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
167
1.35k
  auto retVal = retInst->getOperand();
168
1.35k
  if (auto *tupleInst = dyn_cast<TupleInst>(retVal))
169
188
    results.append(tupleInst->getElements().begin(),
170
188
                   tupleInst->getElements().end());
171
1.16k
  else
172
1.16k
    results.push_back(retVal);
173
1.35k
}
174
175
void collectAllActualResultsInTypeOrder(
176
    ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
177
15.1k
    SmallVectorImpl<SILValue> &results) {
178
15.1k
  auto calleeConvs = ai->getSubstCalleeConv();
179
15.1k
  unsigned indResIdx = 0, dirResIdx = 0;
180
17.4k
  for (auto &resInfo : calleeConvs.getResults()) {
181
17.4k
    results.push_back(resInfo.isFormalDirect()
182
17.4k
                          ? extractedDirectResults[dirResIdx++]
183
17.4k
                          : ai->getIndirectSILResults()[indResIdx++]);
184
17.4k
  }
185
15.1k
}
186
187
void collectMinimalIndicesForFunctionCall(
188
    ApplyInst *ai, const AutoDiffConfig &parentConfig,
189
    const DifferentiableActivityInfo &activityInfo,
190
    SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
191
14.6k
    SmallVectorImpl<unsigned> &resultIndices) {
192
14.6k
  auto calleeFnTy = ai->getSubstCalleeType();
193
14.6k
  auto calleeConvs = ai->getSubstCalleeConv();
194
195
  // Parameter indices are indices (in the callee type signature) of parameter
196
  // arguments that are varied or are arguments.
197
  // Record all parameter indices in type order.
198
14.6k
  unsigned currentParamIdx = 0;
199
35.6k
  for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) {
200
35.6k
    if (activityInfo.isActive(applyArg, parentConfig))
201
22.8k
      paramIndices.push_back(currentParamIdx);
202
35.6k
    ++currentParamIdx;
203
35.6k
  }
204
205
  // Result indices are indices (in the callee type signature) of results that
206
  // are useful.
207
14.6k
  SmallVector<SILValue, 8> directResults;
208
14.6k
  forEachApplyDirectResult(ai, [&](SILValue directResult) {
209
9.00k
    directResults.push_back(directResult);
210
9.00k
  });
211
14.6k
  auto indirectResults = ai->getIndirectSILResults();
212
  // Record all results and result indices in type order.
213
14.6k
  results.reserve(calleeFnTy->getNumResults());
214
14.6k
  unsigned dirResIdx = 0;
215
14.6k
  unsigned indResIdx = calleeConvs.getSILArgIndexOfFirstIndirectResult();
216
14.6k
  for (const auto &resAndIdx : enumerate(calleeConvs.getResults())) {
217
13.8k
    const auto &res = resAndIdx.value();
218
13.8k
    unsigned idx = resAndIdx.index();
219
13.8k
    if (res.isFormalDirect()) {
220
9.00k
      results.push_back(directResults[dirResIdx]);
221
9.00k
      if (auto dirRes = directResults[dirResIdx])
222
9.00k
        if (dirRes && activityInfo.isActive(dirRes, parentConfig))
223
8.93k
          resultIndices.push_back(idx);
224
9.00k
      ++dirResIdx;
225
9.00k
    } else {
226
4.79k
      results.push_back(indirectResults[indResIdx]);
227
4.79k
      if (activityInfo.isActive(indirectResults[indResIdx], parentConfig))
228
4.76k
        resultIndices.push_back(idx);
229
4.79k
      ++indResIdx;
230
4.79k
    }
231
13.8k
  }
232
  
233
  // Record all semantic result parameters as results.
234
14.6k
  auto semanticResultParamResultIndex = calleeFnTy->getNumResults();
235
35.6k
  for (const auto &paramAndIdx : enumerate(calleeConvs.getParameters())) {
236
35.6k
    const auto &param = paramAndIdx.value();
237
35.6k
    if (!param.isAutoDiffSemanticResult())
238
34.4k
      continue;
239
1.21k
    unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults();
240
1.21k
    results.push_back(ai->getArgument(idx));
241
1.21k
    resultIndices.push_back(semanticResultParamResultIndex++);
242
1.21k
  }
243
244
  // Make sure the function call has active results.
245
14.6k
#ifndef NDEBUG
246
14.6k
  assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults());
247
0
  assert(llvm::any_of(results, [&](SILValue result) {
248
14.6k
    return activityInfo.isActive(result, parentConfig);
249
14.6k
  }));
250
14.6k
#endif
251
14.6k
}
252
253
llvm::Optional<std::pair<SILDebugLocation, SILDebugVariable>>
254
34.9k
findDebugLocationAndVariable(SILValue originalValue) {
255
34.9k
  if (auto *asi = dyn_cast<AllocStackInst>(originalValue))
256
6.66k
    return swift::transform(asi->getVarInfo(),  [&](SILDebugVariable var) {
257
2.90k
      return std::make_pair(asi->getDebugLocation(), var);
258
2.90k
    });
259
58.3k
  for (auto *use : originalValue->getUses()) {
260
58.3k
    if (auto *dvi = dyn_cast<DebugValueInst>(use->getUser()))
261
14.1k
      return swift::transform(dvi->getVarInfo(), [&](SILDebugVariable var) {
262
        // We need to drop `op_deref` here as we're transferring debug info
263
        // location from debug_value instruction (which describes how to get value)
264
        // into alloc_stack (which describes the location)
265
14.1k
        if (var.DIExpr.startsWithDeref())
266
2.26k
          var.DIExpr.eraseElement(var.DIExpr.element_begin());
267
14.1k
        return std::make_pair(dvi->getDebugLocation(), var);
268
14.1k
      });
269
58.3k
  }
270
14.1k
  return llvm::None;
271
28.2k
}
272
273
//===----------------------------------------------------------------------===//
274
// Diagnostic utilities
275
//===----------------------------------------------------------------------===//
276
277
92
SILLocation getValidLocation(SILValue v) {
278
92
  auto loc = v.getLoc();
279
92
  if (loc.isNull() || loc.getSourceLoc().isInvalid())
280
4
    loc = v->getFunction()->getLocation();
281
92
  return loc;
282
92
}
283
284
4.52k
SILLocation getValidLocation(SILInstruction *inst) {
285
4.52k
  auto loc = inst->getLoc();
286
4.52k
  if (loc.isNull() || loc.getSourceLoc().isInvalid())
287
484
    loc = inst->getFunction()->getLocation();
288
4.52k
  return loc;
289
4.52k
}
290
291
//===----------------------------------------------------------------------===//
292
// Tangent property lookup utilities
293
//===----------------------------------------------------------------------===//
294
295
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
296
                                  CanType baseType, SILLocation loc,
297
4.10k
                                  DifferentiationInvoker invoker) {
298
4.10k
  auto &astCtx = context.getASTContext();
299
4.10k
  auto tanFieldInfo = evaluateOrDefault(
300
4.10k
      astCtx.evaluator, TangentStoredPropertyRequest{originalField, baseType},
301
4.10k
      TangentPropertyInfo(nullptr));
302
  // If no error, return the tangent property.
303
4.10k
  if (tanFieldInfo)
304
4.04k
    return tanFieldInfo.tangentProperty;
305
  // Otherwise, diagnose error and return nullptr.
306
52
  assert(tanFieldInfo.error);
307
0
  auto *parentDC = originalField->getDeclContext();
308
52
  assert(parentDC->isTypeContext());
309
0
  auto parentDeclName = parentDC->getSelfNominalTypeDecl()->getNameStr();
310
52
  auto fieldName = originalField->getNameStr();
311
52
  auto sourceLoc = loc.getSourceLoc();
312
52
  switch (tanFieldInfo.error->kind) {
313
0
  case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
314
0
    llvm_unreachable(
315
0
        "`@noDerivative` stored property accesses should not be "
316
0
        "differentiated; activity analysis should not mark as varied");
317
0
  case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
318
0
    context.emitNondifferentiabilityError(
319
0
        sourceLoc, invoker,
320
0
        diag::autodiff_stored_property_parent_not_differentiable,
321
0
        parentDeclName, fieldName);
322
0
    break;
323
8
  case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
324
8
    context.emitNondifferentiabilityError(
325
8
        sourceLoc, invoker, diag::autodiff_stored_property_not_differentiable,
326
8
        parentDeclName, fieldName, originalField->getInterfaceType());
327
8
    break;
328
8
  case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
329
8
    context.emitNondifferentiabilityError(
330
8
        sourceLoc, invoker, diag::autodiff_stored_property_tangent_not_struct,
331
8
        parentDeclName, fieldName);
332
8
    break;
333
12
  case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
334
12
    context.emitNondifferentiabilityError(
335
12
        sourceLoc, invoker,
336
12
        diag::autodiff_stored_property_no_corresponding_tangent, parentDeclName,
337
12
        fieldName);
338
12
    break;
339
12
  case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
340
12
    context.emitNondifferentiabilityError(
341
12
        sourceLoc, invoker, diag::autodiff_tangent_property_wrong_type,
342
12
        parentDeclName, fieldName, tanFieldInfo.error->getType());
343
12
    break;
344
12
  case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
345
12
    context.emitNondifferentiabilityError(
346
12
        sourceLoc, invoker, diag::autodiff_tangent_property_not_stored,
347
12
        parentDeclName, fieldName);
348
12
    break;
349
52
  }
350
52
  return nullptr;
351
52
}
352
353
VarDecl *getTangentStoredProperty(ADContext &context,
354
                                  SingleValueInstruction *projectionInst,
355
                                  CanType baseType,
356
3.57k
                                  DifferentiationInvoker invoker) {
357
3.57k
  assert(isa<StructExtractInst>(projectionInst) ||
358
3.57k
         isa<StructElementAddrInst>(projectionInst) ||
359
3.57k
         isa<RefElementAddrInst>(projectionInst));
360
0
  Projection proj(projectionInst);
361
3.57k
  auto loc = getValidLocation(projectionInst);
362
3.57k
  auto *field = proj.getVarDecl(projectionInst->getOperand(0)->getType());
363
3.57k
  return getTangentStoredProperty(context, field, baseType,
364
3.57k
                                  loc, invoker);
365
3.57k
}
366
367
//===----------------------------------------------------------------------===//
368
// Code emission utilities
369
//===----------------------------------------------------------------------===//
370
371
SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
372
22.2k
                      SILLocation loc) {
373
22.2k
  if (elements.size() == 1)
374
10.2k
    return elements.front();
375
11.9k
  return builder.createTuple(loc, elements);
376
22.2k
}
377
378
void extractAllElements(SILValue value, SILBuilder &builder,
379
23.3k
                        SmallVectorImpl<SILValue> &results) {
380
23.3k
  auto tupleType = value->getType().getAs<TupleType>();
381
23.3k
  if (!tupleType) {
382
11.4k
    results.push_back(value);
383
11.4k
    return;
384
11.4k
  }
385
11.9k
  if (builder.hasOwnership()) {
386
11.9k
    auto *dti = builder.createDestructureTuple(value.getLoc(), value);
387
11.9k
    results.append(dti->getResults().begin(), dti->getResults().end());
388
11.9k
    return;
389
11.9k
  }
390
0
  for (auto i : range(tupleType->getNumElements()))
391
0
    results.push_back(builder.createTupleExtract(value.getLoc(), value, i));
392
0
}
393
394
SILValue emitMemoryLayoutSize(
395
0
    SILBuilder &builder, SILLocation loc, CanType type) {
396
0
  auto &ctx = builder.getASTContext();
397
0
  auto id = ctx.getIdentifier(getBuiltinName(BuiltinValueKind::Sizeof));
398
0
  auto *builtin = cast<FuncDecl>(getBuiltinValueDecl(ctx, id));
399
0
  auto metatypeTy = SILType::getPrimitiveObjectType(
400
0
      CanMetatypeType::get(type, MetatypeRepresentation::Thin));
401
0
  auto metatypeVal = builder.createMetatype(loc, metatypeTy);
402
0
  return builder.createBuiltin(
403
0
      loc, id, SILType::getBuiltinWordType(ctx),
404
0
      SubstitutionMap::get(
405
0
          builtin->getGenericSignature(), ArrayRef<Type>{type}, {}),
406
0
      {metatypeVal});
407
0
}
408
409
SILValue emitProjectTopLevelSubcontext(
410
    SILBuilder &builder, SILLocation loc, SILValue context,
411
204
    SILType subcontextType) {
412
204
  assert(context->getOwnershipKind() == OwnershipKind::Guaranteed);
413
0
  auto &ctx = builder.getASTContext();
414
204
  auto id = ctx.getIdentifier(
415
204
      getBuiltinName(BuiltinValueKind::AutoDiffProjectTopLevelSubcontext));
416
204
  assert(context->getType() == SILType::getNativeObjectType(ctx));
417
0
  auto *subcontextAddr = builder.createBuiltin(
418
204
      loc, id, SILType::getRawPointerType(ctx), SubstitutionMap(), {context});
419
204
  return builder.createPointerToAddress(
420
204
      loc, subcontextAddr, subcontextType.getAddressType(), /*isStrict*/ true);
421
204
}
422
423
//===----------------------------------------------------------------------===//
424
// Utilities for looking up derivatives of functions
425
//===----------------------------------------------------------------------===//
426
427
/// Returns the AbstractFunctionDecl corresponding to `F`. If there isn't one,
428
/// returns `nullptr`.
429
6.34k
static AbstractFunctionDecl *findAbstractFunctionDecl(SILFunction *F) {
430
6.34k
  auto *DC = F->getDeclContext();
431
6.34k
  if (!DC)
432
88
    return nullptr;
433
6.25k
  auto *D = DC->getAsDecl();
434
6.25k
  if (!D)
435
1.49k
    return nullptr;
436
4.76k
  return dyn_cast<AbstractFunctionDecl>(D);
437
6.25k
}
438
439
SILDifferentiabilityWitness *
440
getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
441
                                 IndexSubset *parameterIndices,
442
22.5k
                                 IndexSubset *resultIndices) {
443
22.5k
  for (auto *w : module.lookUpDifferentiabilityWitnessesForFunction(
444
22.5k
           original->getName())) {
445
18.4k
    if (w->getParameterIndices() == parameterIndices &&
446
18.4k
        w->getResultIndices() == resultIndices)
447
16.2k
      return w;
448
18.4k
  }
449
6.34k
  return nullptr;
450
22.5k
}
451
452
llvm::Optional<AutoDiffConfig>
453
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
454
                                   IndexSubset *parameterIndices,
455
5.43k
                                   IndexSubset *&minimalASTParameterIndices) {
456
5.43k
  llvm::Optional<AutoDiffConfig> minimalConfig = llvm::None;
457
5.43k
  auto configs = original->getDerivativeFunctionConfigurations();
458
5.43k
  for (auto &config : configs) {
459
3.85k
    auto *silParameterIndices = autodiff::getLoweredParameterIndices(
460
3.85k
        config.parameterIndices,
461
3.85k
        original->getInterfaceType()->castTo<AnyFunctionType>());
462
463
3.85k
    if (silParameterIndices->getCapacity() < parameterIndices->getCapacity()) {
464
0
      assert(original->getCaptureInfo().hasLocalCaptures());
465
0
      silParameterIndices =
466
0
        silParameterIndices->extendingCapacity(original->getASTContext(),
467
0
                                               parameterIndices->getCapacity());
468
0
    }
469
470
    // If all indices in `parameterIndices` are in `daParameterIndices`, and
471
    // it has fewer indices than our current candidate and a primitive VJP,
472
    // then `attr` is our new candidate.
473
    //
474
    // NOTE(TF-642): `attr` may come from a un-partial-applied function and
475
    // have larger capacity than the desired indices. We expect this logic to
476
    // go away when `partial_apply` supports `@differentiable` callees.
477
3.85k
    if (silParameterIndices->isSupersetOf(parameterIndices->extendingCapacity(
478
3.85k
            original->getASTContext(), silParameterIndices->getCapacity())) &&
479
        // fewer parameters than before
480
3.85k
        (!minimalConfig ||
481
3.63k
         silParameterIndices->getNumIndices() <
482
3.61k
             minimalConfig->parameterIndices->getNumIndices())) {
483
3.61k
      minimalASTParameterIndices = config.parameterIndices;
484
3.61k
      minimalConfig =
485
3.61k
          AutoDiffConfig(silParameterIndices, config.resultIndices,
486
3.61k
                         autodiff::getDifferentiabilityWitnessGenericSignature(
487
3.61k
                             original->getGenericSignature(),
488
3.61k
                             config.derivativeGenericSignature));
489
3.61k
    }
490
3.85k
  }
491
5.43k
  return minimalConfig;
492
5.43k
}
493
494
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
495
    SILModule &module, SILFunction *original, DifferentiabilityKind kind,
496
6.34k
    IndexSubset *parameterIndices, IndexSubset *resultIndices) {
497
  // Explicit differentiability witnesses only exist on SIL functions that come
498
  // from AST functions.
499
6.34k
  auto *originalAFD = findAbstractFunctionDecl(original);
500
6.34k
  if (!originalAFD)
501
1.58k
    return nullptr;
502
503
4.76k
  IndexSubset *minimalASTParameterIndices = nullptr;
504
4.76k
  auto minimalConfig = findMinimalDerivativeConfiguration(
505
4.76k
      originalAFD, parameterIndices, minimalASTParameterIndices);
506
4.76k
  if (!minimalConfig)
507
1.82k
    return nullptr;
508
509
2.94k
  std::string originalName = original->getName().str();
510
  // If original function requires a foreign entry point, use the foreign SIL
511
  // function to get or create the minimal differentiability witness.
512
2.94k
  if (requiresForeignEntryPoint(originalAFD)) {
513
304
    originalName = SILDeclRef(originalAFD).asForeign().mangle();
514
304
    original = module.lookUpFunction(SILDeclRef(originalAFD).asForeign());
515
304
  }
516
517
2.94k
  auto *existingWitness = module.lookUpDifferentiabilityWitness(
518
2.94k
      {originalName, kind, *minimalConfig});
519
2.94k
  if (existingWitness)
520
1.90k
    return existingWitness;
521
522
1.03k
  assert(original->isExternalDeclaration() &&
523
1.03k
         "SILGen should create differentiability witnesses for all function "
524
1.03k
         "definitions with explicit differentiable attributes");
525
526
0
  return SILDifferentiabilityWitness::createDeclaration(
527
1.03k
      module, SILLinkage::PublicExternal, original, kind,
528
1.03k
      minimalConfig->parameterIndices, minimalConfig->resultIndices,
529
1.03k
      minimalConfig->derivativeGenericSignature);
530
2.94k
}
531
532
} // end namespace autodiff
533
} // end namespace swift