Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/AST/AutoDiff.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- AutoDiff.cpp - Swift automatic differentiation utilities ---------===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
13
#include "swift/AST/AutoDiff.h"
14
#include "swift/AST/ASTContext.h"
15
#include "swift/AST/GenericEnvironment.h"
16
#include "swift/AST/ImportCache.h"
17
#include "swift/AST/Module.h"
18
#include "swift/AST/TypeCheckRequests.h"
19
#include "swift/AST/Types.h"
20
21
using namespace swift;
22
23
AutoDiffDerivativeFunctionKind::AutoDiffDerivativeFunctionKind(
24
48
    StringRef string) {
25
48
  llvm::Optional<innerty> result =
26
48
      llvm::StringSwitch<llvm::Optional<innerty>>(string)
27
48
          .Case("jvp", JVP)
28
48
          .Case("vjp", VJP);
29
48
  assert(result && "Invalid string");
30
0
  rawValue = *result;
31
48
}
32
33
NormalDifferentiableFunctionTypeComponent::
34
    NormalDifferentiableFunctionTypeComponent(
35
6.63k
        AutoDiffDerivativeFunctionKind kind) {
36
6.63k
  switch (kind) {
37
3.31k
  case AutoDiffDerivativeFunctionKind::JVP:
38
3.31k
    rawValue = JVP;
39
3.31k
    return;
40
3.32k
  case AutoDiffDerivativeFunctionKind::VJP:
41
3.32k
    rawValue = VJP;
42
3.32k
    return;
43
6.63k
  }
44
6.63k
}
45
46
NormalDifferentiableFunctionTypeComponent::
47
144
    NormalDifferentiableFunctionTypeComponent(StringRef string) {
48
144
  llvm::Optional<innerty> result =
49
144
      llvm::StringSwitch<llvm::Optional<innerty>>(string)
50
144
          .Case("original", Original)
51
144
          .Case("jvp", JVP)
52
144
          .Case("vjp", VJP);
53
144
  assert(result && "Invalid string");
54
0
  rawValue = *result;
55
144
}
56
57
llvm::Optional<AutoDiffDerivativeFunctionKind>
58
29.4k
NormalDifferentiableFunctionTypeComponent::getAsDerivativeFunctionKind() const {
59
29.4k
  switch (rawValue) {
60
3.19k
  case Original:
61
3.19k
    return llvm::None;
62
10.8k
  case JVP:
63
10.8k
    return {AutoDiffDerivativeFunctionKind::JVP};
64
15.4k
  case VJP:
65
15.4k
    return {AutoDiffDerivativeFunctionKind::VJP};
66
29.4k
  }
67
0
  llvm_unreachable("invalid derivative kind");
68
0
}
69
70
LinearDifferentiableFunctionTypeComponent::
71
24
    LinearDifferentiableFunctionTypeComponent(StringRef string) {
72
24
  llvm::Optional<innerty> result =
73
24
      llvm::StringSwitch<llvm::Optional<innerty>>(string)
74
24
          .Case("original", Original)
75
24
          .Case("transpose", Transpose);
76
24
  assert(result && "Invalid string");
77
0
  rawValue = *result;
78
24
}
79
80
DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind(
81
208
    StringRef string) {
82
208
  llvm::Optional<innerty> result =
83
208
      llvm::StringSwitch<llvm::Optional<innerty>>(string)
84
208
          .Case("jvp", JVP)
85
208
          .Case("vjp", VJP)
86
208
          .Case("transpose", Transpose);
87
208
  assert(result && "Invalid string");
88
0
  rawValue = *result;
89
208
}
90
91
llvm::Optional<AutoDiffDerivativeFunctionKind>
92
85.4k
DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
93
85.4k
  switch (rawValue) {
94
42.4k
  case JVP:
95
42.4k
    return {AutoDiffDerivativeFunctionKind::JVP};
96
42.6k
  case VJP:
97
42.6k
    return {AutoDiffDerivativeFunctionKind::VJP};
98
336
  case Transpose:
99
336
    return llvm::None;
100
85.4k
  }
101
0
  llvm_unreachable("invalid derivative kind");
102
0
}
103
104
616
void AutoDiffConfig::print(llvm::raw_ostream &s) const {
105
616
  s << "(parameters=";
106
616
  parameterIndices->print(s);
107
616
  s << " results=";
108
616
  resultIndices->print(s);
109
616
  if (derivativeGenericSignature) {
110
112
    s << " where=";
111
112
    derivativeGenericSignature->print(s);
112
112
  }
113
616
  s << ')';
114
616
}
115
116
70.9k
bool swift::isDifferentiableProgrammingEnabled(SourceFile &SF) {
117
70.9k
  auto &ctx = SF.getASTContext();
118
  // Return true if differentiable programming is explicitly enabled.
119
70.9k
  if (ctx.LangOpts.hasFeature(Feature::DifferentiableProgramming))
120
0
    return true;
121
  // Otherwise, return true iff the `_Differentiation` module is imported in
122
  // the given source file.
123
70.9k
  bool importsDifferentiationModule = false;
124
501k
  for (auto import : namelookup::getAllImports(&SF)) {
125
501k
    if (import.importedModule->getName() == ctx.Id_Differentiation) {
126
18.5k
      importsDifferentiationModule = true;
127
18.5k
      break;
128
18.5k
    }
129
501k
  }
130
70.9k
  return importsDifferentiationModule;
131
70.9k
}
132
133
// TODO(TF-874): This helper is inefficient and should be removed. Unwrapping at
134
// most once (for curried method types) is sufficient.
135
static void unwrapCurryLevels(AnyFunctionType *fnTy,
136
39.0k
                              SmallVectorImpl<AnyFunctionType *> &results) {
137
89.5k
  while (fnTy != nullptr) {
138
50.4k
    results.push_back(fnTy);
139
50.4k
    fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
140
50.4k
  }
141
39.0k
}
142
143
19.3k
static unsigned countNumFlattenedElementTypes(Type type) {
144
19.3k
  if (auto *tupleTy = type->getCanonicalType()->getAs<TupleType>())
145
20
    return accumulate(tupleTy->getElementTypes(), 0,
146
40
                      [&](unsigned num, Type type) {
147
40
                        return num + countNumFlattenedElementTypes(type);
148
40
                      });
149
19.3k
  return 1;
150
19.3k
}
151
152
// TODO(TF-874): Simplify this helper and remove the `reverseCurryLevels` flag.
153
void AnyFunctionType::getSubsetParameters(
154
    IndexSubset *parameterIndices,
155
29.7k
    SmallVectorImpl<AnyFunctionType::Param> &results, bool reverseCurryLevels) {
156
29.7k
  SmallVector<AnyFunctionType *, 2> curryLevels;
157
29.7k
  unwrapCurryLevels(this, curryLevels);
158
159
29.7k
  SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
160
29.7k
  unsigned currentOffset = 0;
161
34.1k
  for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) {
162
34.1k
    curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset;
163
34.1k
    currentOffset += curryLevels[curryLevelIndex]->getNumParams();
164
34.1k
  }
165
166
  // If `reverseCurryLevels` is true, reverse the curry levels and offsets.
167
29.7k
  if (reverseCurryLevels) {
168
25.0k
    std::reverse(curryLevels.begin(), curryLevels.end());
169
25.0k
    std::reverse(curryLevelParameterIndexOffsets.begin(),
170
25.0k
                 curryLevelParameterIndexOffsets.end());
171
25.0k
  }
172
173
34.1k
  for (unsigned curryLevelIndex : indices(curryLevels)) {
174
34.1k
    auto *curryLevel = curryLevels[curryLevelIndex];
175
34.1k
    unsigned parameterIndexOffset =
176
34.1k
        curryLevelParameterIndexOffsets[curryLevelIndex];
177
34.1k
    for (unsigned paramIndex : range(curryLevel->getNumParams()))
178
45.3k
      if (parameterIndices->contains(parameterIndexOffset + paramIndex))
179
41.8k
        results.push_back(curryLevel->getParams()[paramIndex]);
180
34.1k
  }
181
29.7k
}
182
183
void autodiff::getFunctionSemanticResults(
184
    const AnyFunctionType *functionType,
185
    const IndexSubset *parameterIndices,
186
39.9k
    SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes) {
187
39.9k
  auto &ctx = functionType->getASTContext();
188
189
  // Collect formal result type as a semantic result, unless it is
190
  // `Void`.
191
39.9k
  auto formalResultType = functionType->getResult();
192
39.9k
  if (auto *resultFunctionType =
193
39.9k
      functionType->getResult()->getAs<AnyFunctionType>())
194
11.2k
    formalResultType = resultFunctionType->getResult();
195
196
39.9k
  unsigned resultIdx = 0;
197
39.9k
  if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) {
198
    // Separate tuple elements into individual results.
199
38.1k
    if (formalResultType->is<TupleType>()) {
200
416
      for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
201
416
        resultTypes.emplace_back(elt.getType(), resultIdx++,
202
416
                                 /*isParameter*/ false);
203
416
      }
204
37.9k
    } else {
205
37.9k
      resultTypes.emplace_back(formalResultType, resultIdx++,
206
37.9k
                               /*isParameter*/ false);
207
37.9k
    }
208
38.1k
  }
209
210
  // Collect wrt semantic result (`inout`) parameters as
211
  // semantic results
212
39.9k
  auto collectSemanticResults = [&](const AnyFunctionType *functionType,
213
51.1k
                                    unsigned curryOffset = 0) {
214
63.8k
    for (auto paramAndIndex : enumerate(functionType->getParams())) {
215
63.8k
      if (!paramAndIndex.value().isAutoDiffSemanticResult())
216
61.7k
        continue;
217
218
2.12k
      unsigned idx = paramAndIndex.index() + curryOffset;
219
2.12k
      assert(idx < parameterIndices->getCapacity() &&
220
2.12k
             "invalid parameter index");
221
2.12k
      if (parameterIndices->contains(idx))
222
2.00k
        resultTypes.emplace_back(paramAndIndex.value().getPlainType(),
223
2.00k
                                 resultIdx, /*isParameter*/ true);
224
2.12k
      resultIdx += 1;
225
2.12k
    }
226
51.1k
  };
227
228
39.9k
  if (auto *resultFnType =
229
39.9k
      functionType->getResult()->getAs<AnyFunctionType>()) {
230
    // Here we assume that the input is a function type with curried `Self`
231
11.2k
    assert(functionType->getNumParams() == 1 && "unexpected function type");
232
233
0
    collectSemanticResults(resultFnType);
234
11.2k
    collectSemanticResults(functionType, resultFnType->getNumParams());
235
11.2k
  } else
236
28.7k
    collectSemanticResults(functionType);
237
39.9k
}
238
239
IndexSubset *
240
autodiff::getFunctionSemanticResultIndices(const AnyFunctionType *functionType,
241
10.7k
                                           const IndexSubset *parameterIndices) {
242
10.7k
  auto &ctx = functionType->getASTContext();
243
244
10.7k
  SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
245
10.7k
  autodiff::getFunctionSemanticResults(functionType, parameterIndices,
246
10.7k
                                       semanticResults);
247
10.7k
  SmallVector<unsigned> resultIndices;
248
10.7k
  unsigned cap = 0;
249
11.0k
  for (const auto& result : semanticResults) {
250
11.0k
    resultIndices.push_back(result.index);
251
11.0k
    cap = std::max(cap, result.index + 1U);
252
11.0k
  }
253
254
10.7k
  return IndexSubset::get(ctx, cap, resultIndices);
255
10.7k
}
256
257
IndexSubset *
258
autodiff::getFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD,
259
7.34k
                                           const IndexSubset *parameterIndices) {
260
7.34k
  return getFunctionSemanticResultIndices(AFD->getInterfaceType()->castTo<AnyFunctionType>(),
261
7.34k
                                          parameterIndices);
262
7.34k
}
263
264
// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
265
IndexSubset *
266
autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices,
267
9.36k
                                     AnyFunctionType *functionType) {
268
9.36k
  SmallVector<AnyFunctionType *, 2> curryLevels;
269
9.36k
  unwrapCurryLevels(functionType, curryLevels);
270
271
  // Compute the lowered sizes of all AST parameter types.
272
9.36k
  SmallVector<unsigned, 8> paramLoweredSizes;
273
9.36k
  unsigned totalLoweredSize = 0;
274
19.3k
  auto addLoweredParamInfo = [&](Type type) {
275
19.3k
    unsigned paramLoweredSize = countNumFlattenedElementTypes(type);
276
19.3k
    paramLoweredSizes.push_back(paramLoweredSize);
277
19.3k
    totalLoweredSize += paramLoweredSize;
278
19.3k
  };
279
9.36k
  for (auto *curryLevel : llvm::reverse(curryLevels))
280
16.3k
    for (auto &param : curryLevel->getParams())
281
19.3k
      addLoweredParamInfo(param.getPlainType());
282
283
  // Build lowered SIL parameter indices by setting the range of bits that
284
  // corresponds to each "set" AST parameter.
285
9.36k
  llvm::SmallVector<unsigned, 8> loweredSILIndices;
286
9.36k
  unsigned currentBitIndex = 0;
287
19.3k
  for (unsigned i : range(parameterIndices->getCapacity())) {
288
19.3k
    auto paramLoweredSize = paramLoweredSizes[i];
289
19.3k
    if (parameterIndices->contains(i)) {
290
14.2k
      auto indices = range(currentBitIndex, currentBitIndex + paramLoweredSize);
291
14.2k
      loweredSILIndices.append(indices.begin(), indices.end());
292
14.2k
    }
293
19.3k
    currentBitIndex += paramLoweredSize;
294
19.3k
  }
295
296
9.36k
  return IndexSubset::get(functionType->getASTContext(), totalLoweredSize,
297
9.36k
                          loweredSILIndices);
298
9.36k
}
299
300
/// Collects the semantic results of the given function type in
301
/// `originalResults`. The semantic results are formal results followed by
302
/// semantic result parameters, in type order.
303
void
304
autodiff::getSemanticResults(SILFunctionType *functionType,
305
                             IndexSubset *parameterIndices,
306
29.5k
                             SmallVectorImpl<SILResultInfo> &originalResults) {
307
  // Collect original formal results.
308
29.5k
  originalResults.append(functionType->getResults().begin(),
309
29.5k
                         functionType->getResults().end());
310
311
  // Collect original semantic result parameters.
312
51.3k
  for (auto i : range(functionType->getNumParameters())) {
313
51.3k
    auto param = functionType->getParameters()[i];
314
51.3k
    if (!param.isAutoDiffSemanticResult())
315
49.3k
      continue;
316
1.94k
    if (param.getDifferentiability() != SILParameterDifferentiability::NotDifferentiable)
317
1.94k
      originalResults.emplace_back(param.getInterfaceType(), ResultConvention::Indirect);
318
1.94k
  }
319
29.5k
}
320
321
GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
322
    SILFunctionType *originalFnTy,
323
    IndexSubset *diffParamIndices, IndexSubset *diffResultIndices,
324
    GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
325
11.0k
    bool isTranspose) {
326
11.0k
  if (!derivativeGenSig)
327
3.54k
    derivativeGenSig = originalFnTy->getInvocationGenericSignature();
328
11.0k
  if (!derivativeGenSig)
329
3.02k
    return nullptr;
330
8.06k
  auto &ctx = originalFnTy->getASTContext();
331
8.06k
  auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
332
8.06k
  SmallVector<Requirement, 4> requirements;
333
334
19.6k
  auto addRequirement = [&](CanType type) {
335
19.6k
    Requirement req(RequirementKind::Conformance, type,
336
19.6k
                    diffableProto->getDeclaredInterfaceType());
337
19.6k
    requirements.push_back(req);
338
19.6k
    if (isTranspose) {
339
      // Require linearity parameters to additionally satisfy
340
      // `Self == Self.TangentVector`.
341
90
      auto tanSpace = type->getAutoDiffTangentSpace(lookupConformance);
342
90
      auto tanType = tanSpace->getCanonicalType();
343
90
      Requirement req(RequirementKind::SameType, type, tanType);
344
90
      requirements.push_back(req);
345
90
    }
346
19.6k
  };
347
348
  // Require differentiability parameters to conform to `Differentiable`.
349
11.6k
  for (unsigned paramIdx : diffParamIndices->getIndices()) {
350
11.6k
    auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
351
11.6k
    addRequirement(paramType);
352
11.6k
  }
353
354
  // Require differentiability results to conform to `Differentiable`.
355
8.06k
  SmallVector<SILResultInfo, 2> originalResults;
356
8.06k
  getSemanticResults(originalFnTy, diffParamIndices, originalResults);
357
8.06k
  for (unsigned resultIdx : diffResultIndices->getIndices()) {
358
    // Handle formal original result.
359
8.03k
    if (resultIdx < originalFnTy->getNumResults()) {
360
7.58k
      auto resultType = originalResults[resultIdx].getInterfaceType();
361
7.58k
      addRequirement(resultType);
362
7.58k
      continue;
363
7.58k
    }
364
    // Handle original semantic result parameters.
365
    // FIXME: Constraint generic yields when we will start supporting them
366
456
    auto resultParamIndex = resultIdx - originalFnTy->getNumResults();
367
456
    auto resultParamIt = std::next(
368
456
      originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
369
456
      resultParamIndex);
370
456
    auto paramIndex =
371
456
      std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
372
456
    addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType());
373
456
  }
374
375
8.06k
  return buildGenericSignature(ctx, derivativeGenSig,
376
8.06k
                               /*addedGenericParams*/ {},
377
8.06k
                               std::move(requirements));
378
11.0k
}
379
380
// Given the rest of a `Builtin.applyDerivative_{jvp|vjp}` or
381
// `Builtin.applyTranspose` operation name, attempts to parse the arity and
382
// throwing-ness from the operation name. Modifies the operation name argument
383
// in place as substrings get dropped.
384
static void parseAutoDiffBuiltinCommonConfig(
385
28
    StringRef &operationName, unsigned &arity, bool &throws) {
386
  // Parse '_arity'.
387
28
  constexpr char arityPrefix[] = "_arity";
388
28
  if (operationName.startswith(arityPrefix)) {
389
8
    operationName = operationName.drop_front(sizeof(arityPrefix) - 1);
390
8
    auto arityStr = operationName.take_while(llvm::isDigit);
391
8
    operationName = operationName.drop_front(arityStr.size());
392
8
    auto converted = llvm::to_integer(arityStr, arity);
393
8
    assert(converted); (void)converted;
394
8
    assert(arity > 0);
395
20
  } else {
396
20
    arity = 1;
397
20
  }
398
  // Parse '_throws'.
399
0
  constexpr char throwsPrefix[] = "_throws";
400
28
  if (operationName.startswith(throwsPrefix)) {
401
0
    operationName = operationName.drop_front(sizeof(throwsPrefix) - 1);
402
0
    throws = true;
403
28
  } else {
404
28
    throws = false;
405
28
  }
406
28
}
407
408
bool autodiff::getBuiltinApplyDerivativeConfig(
409
    StringRef operationName, AutoDiffDerivativeFunctionKind &kind,
410
28
    unsigned &arity, bool &throws) {
411
28
  constexpr char prefix[] = "applyDerivative";
412
28
  if (!operationName.startswith(prefix))
413
0
    return false;
414
28
  operationName = operationName.drop_front(sizeof(prefix) - 1);
415
  // Parse 'jvp' or 'vjp'.
416
28
  constexpr char jvpPrefix[] = "_jvp";
417
28
  constexpr char vjpPrefix[] = "_vjp";
418
28
  if (operationName.startswith(jvpPrefix))
419
8
    kind = AutoDiffDerivativeFunctionKind::JVP;
420
20
  else if (operationName.startswith(vjpPrefix))
421
20
    kind = AutoDiffDerivativeFunctionKind::VJP;
422
28
  operationName = operationName.drop_front(sizeof(jvpPrefix) - 1);
423
28
  parseAutoDiffBuiltinCommonConfig(operationName, arity, throws);
424
28
  return operationName.empty();
425
28
}
426
427
bool autodiff::getBuiltinApplyTransposeConfig(
428
0
    StringRef operationName, unsigned &arity, bool &throws) {
429
0
  constexpr char prefix[] = "applyTranspose";
430
0
  if (!operationName.startswith(prefix))
431
0
    return false;
432
0
  operationName = operationName.drop_front(sizeof(prefix) - 1);
433
0
  parseAutoDiffBuiltinCommonConfig(operationName, arity, throws);
434
0
  return operationName.empty();
435
0
}
436
437
bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
438
0
    StringRef operationName, unsigned &arity, bool &throws) {
439
0
  constexpr char differentiablePrefix[] = "differentiableFunction";
440
0
  constexpr char linearPrefix[] = "linearFunction";
441
0
  if (operationName.startswith(differentiablePrefix))
442
0
    operationName = operationName.drop_front(sizeof(differentiablePrefix) - 1);
443
0
  else if (operationName.startswith(linearPrefix))
444
0
    operationName = operationName.drop_front(sizeof(linearPrefix) - 1);
445
0
  else
446
0
    return false;
447
0
  parseAutoDiffBuiltinCommonConfig(operationName, arity, throws);
448
0
  return operationName.empty();
449
0
}
450
451
GenericSignature autodiff::getDifferentiabilityWitnessGenericSignature(
452
7.50k
    GenericSignature origGenSig, GenericSignature derivativeGenSig) {
453
  // If there is no derivative generic signature, return the original generic
454
  // signature.
455
7.50k
  if (!derivativeGenSig)
456
5.17k
    return origGenSig;
457
  // If derivative generic signature has all concrete generic parameters and is
458
  // equal to the original generic signature, return `nullptr`.
459
2.33k
  auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
460
2.33k
  auto origCanGenSig = origGenSig.getCanonicalSignature();
461
2.33k
  if (origCanGenSig == derivativeCanGenSig &&
462
2.33k
      derivativeCanGenSig->areAllParamsConcrete())
463
100
    return GenericSignature();
464
  // Otherwise, return the derivative generic signature.
465
2.23k
  return derivativeGenSig;
466
2.33k
}
467
468
317k
Type TangentSpace::getType() const {
469
317k
  switch (kind) {
470
314k
  case Kind::TangentVector:
471
314k
    return value.tangentVectorType;
472
3.48k
  case Kind::Tuple:
473
3.48k
    return value.tupleType;
474
317k
  }
475
0
  llvm_unreachable("invalid tangent space kind");
476
0
}
477
478
230k
CanType TangentSpace::getCanonicalType() const {
479
230k
  return getType()->getCanonicalType();
480
230k
}
481
482
0
NominalTypeDecl *TangentSpace::getNominal() const {
483
0
  assert(isTangentVector());
484
0
  return getTangentVector()->getNominalOrBoundGenericNominal();
485
0
}
486
487
const char DerivativeFunctionTypeError::ID = '\0';
488
489
0
void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
490
0
  OS << "original function type '";
491
0
  functionType->print(OS);
492
0
  OS << "' ";
493
0
  switch (kind) {
494
0
  case Kind::NoSemanticResults:
495
0
    OS << "has no semantic results ('Void' result)";
496
0
    break;
497
0
  case Kind::NoDifferentiabilityParameters:
498
0
    OS << "has no differentiability parameters";
499
0
    break;
500
0
  case Kind::NonDifferentiableDifferentiabilityParameter: {
501
0
    auto nonDiffParam = getNonDifferentiableTypeAndIndex();
502
0
    OS << "has non-differentiable differentiability parameter "
503
0
       << nonDiffParam.second << ": " << nonDiffParam.first;
504
0
    break;
505
0
  }
506
0
  case Kind::NonDifferentiableResult: {
507
0
    auto nonDiffResult = getNonDifferentiableTypeAndIndex();
508
0
    OS << "has non-differentiable result " << nonDiffResult.second << ": "
509
0
       << nonDiffResult.first;
510
0
    break;
511
0
  }
512
0
  }
513
0
}
514
515
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
516
0
                                     const DeclNameRefWithLoc &name) {
517
0
  os << name.Name;
518
0
  if (auto accessorKind = name.AccessorKind)
519
0
    os << '.' << getAccessorLabel(*accessorKind);
520
0
  return os;
521
0
}
522
523
bool swift::operator==(const TangentPropertyInfo::Error &lhs,
524
0
                       const TangentPropertyInfo::Error &rhs) {
525
0
  if (lhs.kind != rhs.kind)
526
0
    return false;
527
0
  switch (lhs.kind) {
528
0
  case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
529
0
  case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
530
0
  case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
531
0
  case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
532
0
  case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
533
0
  case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
534
0
    return true;
535
0
  case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
536
0
    return lhs.getType()->isEqual(rhs.getType());
537
0
  }
538
0
  llvm_unreachable("unhandled tangent property!");
539
0
}
540
541
0
void swift::simple_display(llvm::raw_ostream &os, TangentPropertyInfo info) {
542
0
  os << "{ ";
543
0
  os << "tangent property: "
544
0
     << (info.tangentProperty ? info.tangentProperty->printRef() : "null");
545
0
  if (info.error) {
546
0
    os << ", error: ";
547
0
    switch (info.error->kind) {
548
0
    case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
549
0
      os << "'@noDerivative' original property has no tangent property";
550
0
      break;
551
0
    case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
552
0
      os << "nominal parent does not conform to 'Differentiable'";
553
0
      break;
554
0
    case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
555
0
      os << "original property type does not conform to 'Differentiable'";
556
0
      break;
557
0
    case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
558
0
      os << "'TangentVector' type is not a struct";
559
0
      break;
560
0
    case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
561
0
      os << "'TangentVector' struct does not have stored property with the "
562
0
            "same name as the original property";
563
0
      break;
564
0
    case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
565
0
      os << "tangent property's type is not equal to the original property's "
566
0
            "'TangentVector' type";
567
0
      break;
568
0
    case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
569
0
      os << "'TangentVector' property '" << info.tangentProperty->getName()
570
0
         << "' is not a stored property";
571
0
      break;
572
0
    }
573
0
  }
574
0
  os << " }";
575
0
}
576
577
TangentPropertyInfo TangentStoredPropertyRequest::evaluate(
578
748
    Evaluator &evaluator, VarDecl *originalField, CanType baseType) const {
579
748
  assert(((originalField->hasStorage() && originalField->isInstanceMember()) ||
580
748
          originalField->hasAttachedPropertyWrapper()) &&
581
748
         "Expected a stored property or a property-wrapped property");
582
0
  auto *parentDC = originalField->getDeclContext();
583
748
  assert(parentDC->isTypeContext());
584
0
  auto *moduleDecl = originalField->getModuleContext();
585
748
  auto parentTan =
586
748
      baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl));
587
  // Error if parent nominal type does not conform to `Differentiable`.
588
748
  if (!parentTan) {
589
0
    return TangentPropertyInfo(
590
0
        TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable);
591
0
  }
592
  // Error if original stored property is `@noDerivative`.
593
748
  if (originalField->getAttrs().hasAttribute<NoDerivativeAttr>()) {
594
0
    return TangentPropertyInfo(
595
0
        TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty);
596
0
  }
597
  // Error if original property's type does not conform to `Differentiable`.
598
748
  auto originalFieldType = baseType->getTypeOfMember(
599
748
      originalField->getModuleContext(), originalField);
600
748
  auto originalFieldTan = originalFieldType->getAutoDiffTangentSpace(
601
748
      LookUpConformanceInModule(moduleDecl));
602
748
  if (!originalFieldTan) {
603
8
    return TangentPropertyInfo(
604
8
        TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable);
605
8
  }
606
  // Get the parent `TangentVector` type.
607
740
  auto parentTanType =
608
740
      baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl))
609
740
          ->getType();
610
740
  auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct();
611
  // Error if parent `TangentVector` is not a struct.
612
740
  if (!parentTanStruct) {
613
8
    return TangentPropertyInfo(
614
8
        TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct);
615
8
  }
616
  // Find the corresponding field in the tangent space.
617
732
  VarDecl *tanField = nullptr;
618
  // If `TangentVector` is the original struct, then the tangent property is the
619
  // original property.
620
732
  if (parentTanStruct == parentDC->getSelfStructDecl()) {
621
148
    tanField = originalField;
622
148
  }
623
  // Otherwise, look up the field by name.
624
584
  else {
625
584
    auto tanFieldLookup =
626
584
        parentTanStruct->lookupDirect(originalField->getName());
627
584
    llvm::erase_if(tanFieldLookup,
628
584
                   [](ValueDecl *v) { return !isa<VarDecl>(v); });
629
    // Error if tangent property could not be found.
630
584
    if (tanFieldLookup.empty()) {
631
12
      return TangentPropertyInfo(
632
12
          TangentPropertyInfo::Error::Kind::TangentPropertyNotFound);
633
12
    }
634
572
    tanField = cast<VarDecl>(tanFieldLookup.front());
635
572
  }
636
  // Error if tangent property's type is not equal to the original property's
637
  // `TangentVector` type.
638
720
  auto originalFieldTanType = originalFieldTan->getType();
639
720
  auto tanFieldType =
640
720
      parentTanType->getTypeOfMember(tanField->getModuleContext(), tanField);
641
720
  if (!originalFieldTanType->isEqual(tanFieldType)) {
642
12
    return TangentPropertyInfo(
643
12
        TangentPropertyInfo::Error::Kind::TangentPropertyWrongType,
644
12
        originalFieldTanType);
645
12
  }
646
  // Error if tangent property is not a stored property.
647
708
  if (!tanField->hasStorage()) {
648
12
    return TangentPropertyInfo(
649
12
        TangentPropertyInfo::Error::Kind::TangentPropertyNotStored);
650
12
  }
651
  // Otherwise, tangent property is valid.
652
696
  return TangentPropertyInfo(tanField);
653
708
}
654
655
0
void SILDifferentiabilityWitnessKey::print(llvm::raw_ostream &s) const {
656
0
  s << "(original=@" << originalFunctionName << " kind=";
657
0
  switch (kind) {
658
0
  case DifferentiabilityKind::NonDifferentiable:
659
0
    s << "nondifferentiable";
660
0
    break;
661
0
  case DifferentiabilityKind::Forward:
662
0
    s << "forward";
663
0
    break;
664
0
  case DifferentiabilityKind::Reverse:
665
0
    s << "reverse";
666
0
    break;
667
0
  case DifferentiabilityKind::Normal:
668
0
    s << "normal";
669
0
    break;
670
0
  case DifferentiabilityKind::Linear:
671
0
    s << "linear";
672
0
    break;
673
0
  }
674
0
  s << " config=" << config << ')';
675
0
}
676
677
Demangle::AutoDiffFunctionKind Demangle::getAutoDiffFunctionKind(
678
14.4k
    AutoDiffDerivativeFunctionKind kind) {
679
14.4k
  switch (kind) {
680
7.24k
  case AutoDiffDerivativeFunctionKind::JVP:
681
7.24k
    return Demangle::AutoDiffFunctionKind::JVP;
682
7.20k
  case AutoDiffDerivativeFunctionKind::VJP: return Demangle::AutoDiffFunctionKind::VJP;
683
14.4k
  }
684
14.4k
}
685
686
Demangle::AutoDiffFunctionKind Demangle::getAutoDiffFunctionKind(
687
7.44k
    AutoDiffLinearMapKind kind) {
688
7.44k
  switch (kind) {
689
1.72k
  case AutoDiffLinearMapKind::Differential:
690
1.72k
    return Demangle::AutoDiffFunctionKind::Differential;
691
5.72k
  case AutoDiffLinearMapKind::Pullback:
692
5.72k
    return Demangle::AutoDiffFunctionKind::Pullback;
693
7.44k
  }
694
7.44k
}
695
696
Demangle::MangledDifferentiabilityKind
697
20.8k
Demangle::getMangledDifferentiabilityKind(DifferentiabilityKind kind) {
698
20.8k
  using namespace Demangle;
699
20.8k
  switch (kind) {
700
0
  #define SIMPLE_CASE(CASE) \
701
20.8k
    case DifferentiabilityKind::CASE: return MangledDifferentiabilityKind::CASE;
702
0
  SIMPLE_CASE(NonDifferentiable)
703
0
  SIMPLE_CASE(Forward)
704
20.8k
  SIMPLE_CASE(Reverse)
705
0
  SIMPLE_CASE(Normal)
706
15
  SIMPLE_CASE(Linear)
707
20.8k
  #undef SIMPLE_CASE
708
20.8k
  }
709
20.8k
}