Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/Sema/DerivedConformanceDifferentiable.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- DerivedConformanceDifferentiable.cpp - Derived Differentiable ----===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// This file implements explicit derivation of the Differentiable protocol for
14
// struct and class types.
15
//
16
//===----------------------------------------------------------------------===//
17
18
#include "CodeSynthesis.h"
19
#include "TypeChecker.h"
20
#include "TypeCheckType.h"
21
#include "llvm/ADT/SmallPtrSet.h"
22
#include "swift/AST/AutoDiff.h"
23
#include "swift/AST/Decl.h"
24
#include "swift/AST/Expr.h"
25
#include "swift/AST/Module.h"
26
#include "swift/AST/ParameterList.h"
27
#include "swift/AST/Pattern.h"
28
#include "swift/AST/PropertyWrappers.h"
29
#include "swift/AST/ProtocolConformance.h"
30
#include "swift/AST/Stmt.h"
31
#include "swift/AST/TypeCheckRequests.h"
32
#include "swift/AST/Types.h"
33
#include "DerivedConformances.h"
34
35
using namespace swift;
36
37
/// Return true if `move(by:)` can be invoked on the given `Differentiable`-
38
/// conforming property.
39
///
40
/// If the given property is a `var`, return true because `move(by:)` can be
41
/// invoked regardless.  Otherwise, return true if and only if the property's
42
/// type's 'Differentiable.move(by:)' witness is non-mutating.
43
static bool canInvokeMoveByOnProperty(
44
5.49k
    VarDecl *vd, ProtocolConformanceRef diffableConformance) {
45
5.49k
  assert(diffableConformance && "Property must conform to 'Differentiable'");
46
  // `var` always supports `move(by:)` since it is mutable.
47
5.49k
  if (vd->getIntroducer() == VarDecl::Introducer::Var)
48
5.30k
    return true;
49
  // When the property is a `let`, the only case that would be supported is when
50
  // it has a `move(by:)` protocol requirement witness that is non-mutating.
51
188
  auto interfaceType = vd->getInterfaceType();
52
188
  auto &C = vd->getASTContext();
53
188
  auto witness = diffableConformance.getWitnessByName(
54
188
      interfaceType, DeclName(C, C.Id_move, {C.Id_by}));
55
188
  if (!witness)
56
0
    return false;
57
188
  auto *decl = cast<FuncDecl>(witness.getDecl());
58
188
  return !decl->isMutating();
59
188
}
60
61
/// Get the stored properties of a nominal type that are relevant for
62
/// differentiation, except the ones tagged `@noDerivative`.
63
static void
64
getStoredPropertiesForDifferentiation(
65
    NominalTypeDecl *nominal, DeclContext *DC,
66
    SmallVectorImpl<VarDecl *> &result,
67
3.60k
    bool includeLetPropertiesWithNonmutatingMoveBy = false) {
68
3.60k
  auto &C = nominal->getASTContext();
69
3.60k
  auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
70
5.12k
  for (auto *vd : nominal->getStoredProperties()) {
71
    // Peer through property wrappers: use original wrapped properties instead.
72
5.12k
    if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
73
      // Skip immutable wrapped properties. `mutating func move(by:)` cannot
74
      // be synthesized to update these properties.
75
668
      if (!originalProperty->isSettable(DC))
76
56
        continue;
77
      // Use the original wrapped property.
78
612
      vd = originalProperty;
79
612
    }
80
    // Skip stored properties with `@noDerivative` attribute.
81
5.07k
    if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
82
584
      continue;
83
4.48k
    if (vd->getInterfaceType()->hasError())
84
0
      continue;
85
4.48k
    auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
86
4.48k
    auto conformance = TypeChecker::conformsToProtocol(
87
4.48k
        varType, diffableProto, DC->getParentModule());
88
4.48k
    if (!conformance)
89
128
      continue;
90
    // Skip `let` stored properties with a mutating `move(by:)` if requested.
91
    // `mutating func move(by:)` cannot be synthesized to update `let`
92
    // properties.
93
4.36k
    if (!includeLetPropertiesWithNonmutatingMoveBy && 
94
4.36k
        !canInvokeMoveByOnProperty(vd, conformance))
95
40
      continue;
96
4.32k
    result.push_back(vd);
97
4.32k
  }
98
3.60k
}
99
100
/// Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a
101
/// `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`.
102
914
static StructDecl *convertToStructDecl(ValueDecl *v) {
103
914
  if (auto *structDecl = dyn_cast<StructDecl>(v))
104
906
    return structDecl;
105
8
  auto *typeDecl = dyn_cast<TypeDecl>(v);
106
8
  if (!typeDecl)
107
0
    return nullptr;
108
8
  return dyn_cast_or_null<StructDecl>(
109
8
      typeDecl->getDeclaredInterfaceType()->getAnyNominal());
110
8
}
111
112
/// Get the `Differentiable` protocol `TangentVector` associated type witness
113
/// for the given interface type and declaration context.
114
static Type getTangentVectorInterfaceType(Type contextualType,
115
2.00k
                                          DeclContext *DC) {
116
2.00k
  auto &C = DC->getASTContext();
117
2.00k
  auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
118
2.00k
  assert(diffableProto && "`Differentiable` protocol not found");
119
0
  auto conf =
120
2.00k
      TypeChecker::conformsToProtocol(contextualType, diffableProto,
121
2.00k
                                      DC->getParentModule());
122
2.00k
  assert(conf && "Contextual type must conform to `Differentiable`");
123
2.00k
  if (!conf)
124
0
    return nullptr;
125
2.00k
  auto tanType = conf.getTypeWitnessByName(contextualType, C.Id_TangentVector);
126
2.00k
  return tanType->hasArchetype() ? tanType->mapTypeOutOfContext() : tanType;
127
2.00k
}
128
129
/// Returns true iff the given nominal type declaration can derive
130
/// `TangentVector` as `Self` in the given conformance context.
131
static bool canDeriveTangentVectorAsSelf(NominalTypeDecl *nominal,
132
3.08k
                                         DeclContext *DC) {
133
  // `Self` must not be a class declaration.
134
3.08k
  if (nominal->getSelfClassDecl())
135
892
    return false;
136
137
2.19k
  auto nominalTypeInContext =
138
2.19k
      DC->mapTypeIntoContext(nominal->getDeclaredInterfaceType());
139
2.19k
  auto &C = nominal->getASTContext();
140
2.19k
  auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
141
2.19k
  auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
142
  // `Self` must conform to `AdditiveArithmetic`.
143
2.19k
  if (!TypeChecker::conformsToProtocol(nominalTypeInContext, addArithProto,
144
2.19k
                                       DC->getParentModule()))
145
1.85k
    return false;
146
480
  for (auto *field : nominal->getStoredProperties()) {
147
    // `Self` must not have any `@noDerivative` stored properties.
148
480
    if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
149
0
      return false;
150
    // `Self` must have all stored properties satisfy `Self == TangentVector`.
151
480
    auto fieldType = DC->mapTypeIntoContext(field->getValueInterfaceType());
152
480
    auto conf = TypeChecker::conformsToProtocol(fieldType, diffableProto,
153
480
                                                DC->getParentModule());
154
480
    if (!conf)
155
0
      return false;
156
480
    auto tangentType = conf.getTypeWitnessByName(fieldType, C.Id_TangentVector);
157
480
    if (!fieldType->isEqual(tangentType))
158
48
      return false;
159
480
  }
160
296
  return true;
161
344
}
162
163
bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
164
                                                 DeclContext *DC,
165
2.00k
                                                 ValueDecl *requirement) {
166
  // Experimental differentiable programming must be enabled.
167
2.00k
  if (auto *SF = DC->getParentSourceFile())
168
2.00k
    if (!isDifferentiableProgrammingEnabled(*SF))
169
0
      return false;
170
171
2.00k
  auto &C = nominal->getASTContext();
172
  // If there are any `TangentVector` type witness candidates, check whether
173
  // there exists only a single valid candidate.
174
2.00k
  bool canUseTangentVectorAsSelf = canDeriveTangentVectorAsSelf(nominal, DC);
175
2.00k
  auto isValidTangentVectorCandidate = [&](ValueDecl *v) -> bool {
176
    // Valid candidate must be a struct or a typealias to a struct.
177
914
    auto *structDecl = convertToStructDecl(v);
178
914
    if (!structDecl)
179
4
      return false;
180
    // Valid candidate must either:
181
    // 1. Be implicit (previously synthesized).
182
910
    if (structDecl->isImplicit())
183
898
      return true;
184
    // 2. Equal nominal, when the nominal can derive `TangentVector` as `Self`.
185
    // Nominal type must not customize `TangentVector` to anything other than
186
    // `Self`. Otherwise, synthesis is semantically unsupported.
187
12
    if (structDecl == nominal && canUseTangentVectorAsSelf)
188
0
      return true;
189
    // Otherwise, candidate is invalid.
190
12
    return false;
191
12
  };
192
2.00k
  auto tangentDecls = nominal->lookupDirect(C.Id_TangentVector);
193
  // There can be at most one valid `TangentVector` type.
194
2.00k
  if (tangentDecls.size() > 1)
195
0
    return false;
196
  // There cannot be any invalid `TangentVector` types.
197
2.00k
  if (tangentDecls.size() == 1) {
198
914
    auto *tangentDecl = tangentDecls.front();
199
914
    if (!isValidTangentVectorCandidate(tangentDecl))
200
16
      return false;
201
914
  }
202
203
  // Check `TangentVector` struct derivation conditions.
204
  // Nominal type must be a struct or class. (No stored properties is okay.)
205
1.98k
  if (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal))
206
0
    return false;
207
  // If there are no `TangentVector` candidates, derivation is possible if all
208
  // differentiation stored properties conform to `Differentiable`.
209
1.98k
  SmallVector<VarDecl *, 16> diffProperties;
210
1.98k
  getStoredPropertiesForDifferentiation(nominal, DC, diffProperties);
211
1.98k
  auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
212
2.38k
  return llvm::all_of(diffProperties, [&](VarDecl *v) {
213
2.38k
    if (v->getInterfaceType()->hasError())
214
0
      return false;
215
2.38k
    auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
216
2.38k
    return (bool)TypeChecker::conformsToProtocol(varType, diffableProto,
217
2.38k
                                                 DC->getParentModule());
218
2.38k
  });
219
1.98k
}
220
221
/// Synthesize body for `move(by:)`.
222
static std::pair<BraceStmt *, bool>
223
674
deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
224
674
  auto &C = funcDecl->getASTContext();
225
674
  auto *parentDC = funcDecl->getParent();
226
674
  auto *nominal = parentDC->getSelfNominalTypeDecl();
227
228
  // Get `Differentiable.move(by:)` protocol requirement.
229
674
  auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable);
230
674
  auto *requirement = getProtocolRequirement(diffProto, C.Id_move);
231
232
  // Get references to `self` and parameter declarations.
233
674
  auto *selfDecl = funcDecl->getImplicitSelfDecl();
234
674
  auto *selfDRE =
235
674
      new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
236
674
  auto *paramDecl = funcDecl->getParameters()->get(0);
237
674
  auto *paramDRE =
238
674
      new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);
239
240
674
  SmallVector<VarDecl *, 8> diffProperties;
241
674
  getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);
242
243
  // Create call expression applying a member `move(by:)` method to a
244
  // parameter member: `self.<member>.move(by: offset.<member>)`.
245
826
  auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
246
826
    auto *module = nominal->getModuleContext();
247
826
    auto memberType =
248
826
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
249
826
    auto confRef = module->lookupConformance(memberType, diffProto);
250
826
    assert(confRef && "Member does not conform to `Differentiable`");
251
252
    // Get member type's requirement witness: `<Member>.move(by:)`.
253
0
    ValueDecl *memberWitnessDecl = requirement;
254
826
    if (confRef.isConcrete())
255
634
      if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement))
256
634
        memberWitnessDecl = witness;
257
826
    assert(memberWitnessDecl && "Member witness declaration must exist");
258
259
    // Create reference to member method: `self.<member>.move(by:)`.
260
0
    Expr *memberExpr =
261
826
        new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
262
826
                              /*Implicit*/ true);
263
826
    auto *memberMethodExpr =
264
826
        new (C) MemberRefExpr(memberExpr, SourceLoc(), memberWitnessDecl,
265
826
                              DeclNameLoc(), /*Implicit*/ true);
266
267
    // Create reference to parameter member: `offset.<member>`.
268
826
    VarDecl *paramMember = nullptr;
269
826
    auto *paramNominal = paramDecl->getTypeInContext()->getAnyNominal();
270
826
    assert(paramNominal && "Parameter should have a nominal type");
271
    // Find parameter member corresponding to returned nominal member.
272
1.23k
    for (auto *candidate : paramNominal->getStoredProperties()) {
273
1.23k
      if (candidate->getName() == member->getName()) {
274
826
        paramMember = candidate;
275
826
        break;
276
826
      }
277
1.23k
    }
278
826
    assert(paramMember && "Could not find corresponding parameter member");
279
0
    auto *paramMemberExpr =
280
826
        new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(),
281
826
                              /*Implicit*/ true);
282
    // Create expression: `self.<member>.move(by: offset.<member>)`.
283
826
    auto *args = ArgumentList::forImplicitSingle(C, C.Id_by, paramMemberExpr);
284
826
    return CallExpr::createImplicit(C, memberMethodExpr, args);
285
826
  };
286
287
  // Collect member `move(by:)` method call expressions.
288
674
  SmallVector<ASTNode, 2> memberMethodCallExprs;
289
674
  SmallVector<Identifier, 2> memberNames;
290
826
  for (auto *member : diffProperties) {
291
826
    memberMethodCallExprs.push_back(createMemberMethodCallExpr(member));
292
826
    memberNames.push_back(member->getName());
293
826
  }
294
674
  auto *braceStmt = BraceStmt::create(C, SourceLoc(), memberMethodCallExprs,
295
674
                                      SourceLoc(), true);
296
674
  return std::pair<BraceStmt *, bool>(braceStmt, false);
297
674
}
298
299
/// Synthesize function declaration for a `Differentiable` method requirement.
300
static ValueDecl *deriveDifferentiable_method(
301
    DerivedConformance &derived, Identifier methodName, Identifier argumentName,
302
    Identifier parameterName, Type parameterType, Type returnType,
303
898
    AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
304
898
  auto *nominal = derived.Nominal;
305
898
  auto &C = derived.Context;
306
898
  auto *parentDC = derived.getConformanceContext();
307
308
898
  auto *param = new (C) ParamDecl(SourceLoc(), SourceLoc(), argumentName,
309
898
                                  SourceLoc(), parameterName, parentDC);
310
898
  param->setSpecifier(ParamDecl::Specifier::Default);
311
898
  param->setInterfaceType(parameterType);
312
898
  param->setImplicit();
313
898
  ParameterList *params = ParameterList::create(C, {param});
314
315
898
  DeclName declName(C, methodName, params);
316
898
  auto *const funcDecl = FuncDecl::createImplicit(
317
898
      C, StaticSpellingKind::None, declName, /*NameLoc=*/SourceLoc(),
318
898
      /*Async=*/false,
319
898
      /*Throws=*/false,
320
898
      /*ThrownType=*/Type(),
321
898
      /*GenericParams=*/nullptr, params, returnType, parentDC);
322
898
  funcDecl->setSynthesized();
323
898
  if (!nominal->getSelfClassDecl())
324
614
    funcDecl->setSelfAccessKind(SelfAccessKind::Mutating);
325
898
  funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);
326
327
898
  funcDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
328
898
  funcDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
329
330
898
  derived.addMembersToConformanceContext({funcDecl});
331
898
  return funcDecl;
332
898
}
333
334
/// Synthesize the `move(by:)` function declaration.
335
898
static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
336
898
  auto &C = derived.Context;
337
898
  auto *parentDC = derived.getConformanceContext();
338
898
  auto tangentType =
339
898
      getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC);
340
898
  return deriveDifferentiable_method(
341
898
      derived, C.Id_move, C.Id_by, C.Id_offset, tangentType,
342
898
      C.TheEmptyTupleType, {deriveBodyDifferentiable_move, nullptr});
343
898
}
344
345
/// Return associated `TangentVector` struct for a nominal type, if it exists.
346
/// If not, synthesize the struct.
347
static StructDecl *
348
942
getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
349
942
  auto *parentDC = derived.getConformanceContext();
350
942
  auto *nominal = derived.Nominal;
351
942
  auto &C = nominal->getASTContext();
352
353
  // If the associated struct already exists, return it.
354
942
  auto lookup = nominal->lookupDirect(C.Id_TangentVector);
355
942
  assert(lookup.size() < 2 &&
356
942
         "Expected at most one associated type named `TangentVector`");
357
942
  if (lookup.size() == 1) {
358
0
    auto *structDecl = convertToStructDecl(lookup.front());
359
0
    assert(structDecl && "Expected lookup result to be a struct");
360
0
    return structDecl;
361
0
  }
362
363
  // Otherwise, synthesize a new struct.
364
365
  // Compute `tvDesiredProtos`, the set of protocols that the new `TangentVector` struct must
366
  // inherit, by collecting all the `TangentVector` conformance requirements imposed by the
367
  // protocols that `derived.ConformanceDecl` inherits.
368
  //
369
  // Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because
370
  // the `Differentiable` protocol itself requires that its `TangentVector` conforms to
371
  // `AdditiveArithmetic` and `Differentiable`.
372
942
  llvm::SmallSetVector<ProtocolDecl *, 4> tvDesiredProtos;
373
374
942
  auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
375
942
  auto *tvAssocType = diffableProto->getAssociatedType(C.Id_TangentVector);
376
377
942
  auto localProtos = cast<IterableDeclContext>(derived.ConformanceDecl)
378
942
      ->getLocalProtocols();
379
1.57k
  for (auto proto : localProtos) {
380
2.99k
    for (auto req : proto->getRequirementSignature().getRequirements()) {
381
2.99k
      if (req.getKind() != RequirementKind::Conformance)
382
942
        continue;
383
2.05k
      auto *firstType = req.getFirstType()->getAs<DependentMemberType>();
384
2.05k
      if (!firstType || firstType->getAssocType() != tvAssocType)
385
140
        continue;
386
1.91k
      tvDesiredProtos.insert(req.getProtocolDecl());
387
1.91k
    }
388
1.57k
  }
389
942
  SmallVector<InheritedEntry, 4> tvDesiredProtoInherited;
390
942
  for (auto *p : tvDesiredProtos)
391
1.91k
    tvDesiredProtoInherited.push_back(
392
1.91k
      InheritedEntry(TypeLoc::withoutLoc(p->getDeclaredInterfaceType())));
393
394
  // Cache original members and their associated types for later use.
395
942
  SmallVector<VarDecl *, 8> diffProperties;
396
942
  getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);
397
398
942
  auto synthesizedLoc = derived.ConformanceDecl->getEndLoc();
399
942
  auto *structDecl =
400
942
      new (C) StructDecl(synthesizedLoc, C.Id_TangentVector, synthesizedLoc,
401
942
                         /*Inherited*/ C.AllocateCopy(tvDesiredProtoInherited),
402
942
                         /*GenericParams*/ {}, parentDC);
403
942
  structDecl->setBraces({synthesizedLoc, synthesizedLoc});
404
942
  structDecl->setImplicit();
405
942
  structDecl->setSynthesized();
406
942
  structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
407
408
  // Add stored properties to the `TangentVector` struct.
409
1.11k
  for (auto *member : diffProperties) {
410
    // Add a tangent stored property to the `TangentVector` struct, with the
411
    // name and `TangentVector` type of the original property.
412
1.11k
    auto *tangentProperty = new (C) VarDecl(
413
1.11k
        member->isStatic(), member->getIntroducer(),
414
1.11k
        /*NameLoc*/ SourceLoc(), member->getName(), structDecl);
415
    // Note: `tangentProperty` is not marked as implicit or synthesized here,
416
    // because that incorrectly affects memberwise initializer synthesis and
417
    // causes the type checker to not guarantee the order of these members.
418
1.11k
    auto memberContextualType =
419
1.11k
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
420
1.11k
    auto memberTanType =
421
1.11k
        getTangentVectorInterfaceType(memberContextualType, parentDC);
422
1.11k
    tangentProperty->setInterfaceType(memberTanType);
423
1.11k
    Pattern *memberPattern =
424
1.11k
        NamedPattern::createImplicit(C, tangentProperty, memberTanType);
425
1.11k
    memberPattern =
426
1.11k
        TypedPattern::createImplicit(C, memberPattern, memberTanType);
427
1.11k
    memberPattern->setType(memberTanType);
428
1.11k
    auto *memberBinding = PatternBindingDecl::createImplicit(
429
1.11k
        C, StaticSpellingKind::None, memberPattern, /*initExpr*/ nullptr,
430
1.11k
        structDecl);
431
1.11k
    structDecl->addMember(tangentProperty);
432
1.11k
    structDecl->addMember(memberBinding);
433
1.11k
    tangentProperty->copyFormalAccessFrom(member,
434
1.11k
                                          /*sourceIsParentContext*/ true);
435
1.11k
    tangentProperty->setSetterAccess(member->getFormalAccess());
436
437
    // Cache the tangent property.
438
1.11k
    C.evaluator.cacheOutput(TangentStoredPropertyRequest{member, CanType()},
439
1.11k
                            TangentPropertyInfo(tangentProperty));
440
441
    // Now that the original property has a corresponding tangent property, it
442
    // should be marked `@differentiable` so that the differentiation transform
443
    // will synthesize derivative functions for its accessors. We only add this
444
    // to public stored properties, because their access outside the module will
445
    // go through accessor declarations.
446
1.11k
    if (member->getEffectiveAccess() > AccessLevel::Internal &&
447
1.11k
        !member->getAttrs().hasAttribute<DifferentiableAttr>()) {
448
68
      auto *getter = member->getSynthesizedAccessor(AccessorKind::Get);
449
68
      (void)getter->getInterfaceType();
450
      // If member or its getter already has a `@differentiable` attribute,
451
      // continue.
452
68
      if (member->getAttrs().hasAttribute<DifferentiableAttr>() ||
453
68
          getter->getAttrs().hasAttribute<DifferentiableAttr>())
454
0
        continue;
455
68
      GenericSignature derivativeGenericSignature =
456
68
          getter->getGenericSignature();
457
      // If the parent declaration context is an extension, the nominal type may
458
      // conditionally conform to `Differentiable`. Use the extension generic
459
      // requirements in getter `@differentiable` attributes.
460
68
      if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
461
12
        if (auto extGenSig = extDecl->getGenericSignature())
462
12
          derivativeGenericSignature = extGenSig;
463
68
      auto *diffableAttr = DifferentiableAttr::create(
464
68
          getter, /*implicit*/ true, SourceLoc(), SourceLoc(),
465
68
          DifferentiabilityKind::Reverse,
466
68
          /*parameterIndices*/ IndexSubset::get(C, 1, {0}),
467
68
          derivativeGenericSignature);
468
68
      member->getAttrs().add(diffableAttr);
469
68
    }
470
1.11k
  }
471
472
  // If nominal type is `@frozen`, also mark `TangentVector` struct.
473
942
  if (nominal->getAttrs().hasAttribute<FrozenAttr>())
474
8
    structDecl->getAttrs().add(new (C) FrozenAttr(/*implicit*/ true));
475
  
476
  // Add `typealias TangentVector = Self` so that the `TangentVector` itself
477
  // won't need its own conformance derivation.
478
942
  auto *tangentEqualsSelfAlias = new (C) TypeAliasDecl(
479
942
      SourceLoc(), SourceLoc(), C.Id_TangentVector, SourceLoc(),
480
942
      /*GenericParams*/ nullptr, structDecl);
481
942
  tangentEqualsSelfAlias->setUnderlyingType(structDecl->getDeclaredInterfaceType());
482
942
  tangentEqualsSelfAlias->copyFormalAccessFrom(structDecl,
483
942
                                               /*sourceIsParentContext*/ true);
484
942
  tangentEqualsSelfAlias->setImplicit();
485
942
  tangentEqualsSelfAlias->setSynthesized();
486
942
  structDecl->addMember(tangentEqualsSelfAlias);
487
488
  // The implicit memberwise constructor must be explicitly created so that it
489
  // can called in `AdditiveArithmetic` and `Differentiable` methods. Normally,
490
  // the memberwise constructor is synthesized during SILGen, which is too late.
491
942
  TypeChecker::addImplicitConstructors(structDecl);
492
493
  // After memberwise initializer is synthesized, mark members as implicit.
494
942
  for (auto *member : structDecl->getStoredProperties())
495
1.11k
    member->setImplicit();
496
497
942
  derived.addMembersToConformanceContext({structDecl});
498
499
942
  TypeChecker::checkConformancesInContext(structDecl);
500
501
942
  return structDecl;
502
942
}
503
504
/// Diagnose stored properties in the nominal that do not have an explicit
505
/// `@noDerivative` attribute, but either:
506
/// - Do not conform to `Differentiable`.
507
/// - Are a `let` stored property.
508
/// Emit a warning and a fixit so that users will make the attribute explicit.
509
static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
510
                                                 NominalTypeDecl *nominal,
511
942
                                                 DeclContext *DC) {
512
  // If nominal type can conform to `AdditiveArithmetic`, suggest adding a
513
  // conformance to `AdditiveArithmetic` in fix-its.
514
  // `Differentiable` protocol requirements all have default implementations
515
  // when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
516
  // derived conformances will no longer be necessary.
517
942
  bool nominalCanDeriveAdditiveArithmetic =
518
942
      DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
519
942
  auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
520
  // Check all stored properties.
521
1.34k
  for (auto *vd : nominal->getStoredProperties()) {
522
    // Peer through property wrappers: use original wrapped properties.
523
1.34k
    if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
524
      // Skip wrapped properties with `@noDerivative` attribute.
525
176
      if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
526
24
        continue;
527
      // Diagnose wrapped properties whose property wrappers do not define
528
      // `wrappedValue.set`. `mutating func move(by:)` cannot be synthesized
529
      // to update these properties.
530
152
      if (!originalProperty->isSettable(DC)) {
531
8
        auto *wrapperDecl =
532
8
            vd->getInterfaceType()->getNominalOrBoundGenericNominal();
533
8
        auto loc =
534
8
            originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
535
8
        Context.Diags
536
8
            .diagnose(
537
8
                loc,
538
8
                diag::
539
8
                    differentiable_immutable_wrapper_implicit_noderivative_fixit,
540
8
                wrapperDecl->getName(), nominal->getName(),
541
8
                nominalCanDeriveAdditiveArithmetic)
542
8
            .fixItInsert(loc, "@noDerivative ");
543
        // Add an implicit `@noDerivative` attribute.
544
8
        originalProperty->getAttrs().add(
545
8
            new (Context) NoDerivativeAttr(/*Implicit*/ true));
546
8
        continue;
547
8
      }
548
      // Use the original wrapped property.
549
144
      vd = originalProperty;
550
144
    }
551
1.31k
    if (vd->getInterfaceType()->hasError())
552
0
      continue;
553
    // Skip stored properties with `@noDerivative` attribute.
554
1.31k
    if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
555
116
      continue;
556
    // Check whether to diagnose stored property.
557
1.19k
    auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
558
1.19k
    auto diffableConformance =
559
1.19k
        TypeChecker::conformsToProtocol(varType, diffableProto,
560
1.19k
                                        DC->getParentModule());
561
    // If stored property should not be diagnosed, continue.
562
1.19k
    if (diffableConformance && 
563
1.19k
        canInvokeMoveByOnProperty(vd, diffableConformance))
564
1.11k
      continue;
565
    // Otherwise, add an implicit `@noDerivative` attribute.
566
84
    vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true));
567
84
    auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false);
568
84
    assert(loc.isValid() && "Expected valid source location");
569
    // Diagnose properties that do not conform to `Differentiable`.
570
84
    if (!diffableConformance) {
571
64
      Context.Diags
572
64
          .diagnose(
573
64
              loc,
574
64
              diag::differentiable_nondiff_type_implicit_noderivative_fixit,
575
64
              vd->getName(), vd->getTypeInContext(), nominal->getName(),
576
64
              nominalCanDeriveAdditiveArithmetic)
577
64
          .fixItInsert(loc, "@noDerivative ");
578
64
      continue;
579
64
    }
580
    // Otherwise, diagnose `let` property.
581
20
    Context.Diags
582
20
        .diagnose(loc,
583
20
                  diag::differentiable_let_property_implicit_noderivative_fixit,
584
20
                  nominal->getName(), nominalCanDeriveAdditiveArithmetic)
585
20
        .fixItInsert(loc, "@noDerivative ");
586
20
  }
587
942
}
588
589
/// Get or synthesize `TangentVector` struct type.
590
static std::pair<Type, TypeDecl *>
591
942
getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
592
942
  auto *parentDC = derived.getConformanceContext();
593
942
  auto *nominal = derived.Nominal;
594
942
  auto &C = nominal->getASTContext();
595
596
  // Get or synthesize `TangentVector` struct.
597
942
  auto *tangentStruct =
598
942
      getOrSynthesizeTangentVectorStruct(derived, C.Id_TangentVector);
599
942
  if (!tangentStruct)
600
0
    return std::make_pair(nullptr, nullptr);
601
602
  // Check and emit warnings for implicit `@noDerivative` members.
603
942
  checkAndDiagnoseImplicitNoDerivative(C, nominal, parentDC);
604
605
  // Return the `TangentVector` struct type.
606
942
  return std::make_pair(
607
942
    parentDC->mapTypeIntoContext(
608
942
      tangentStruct->getDeclaredInterfaceType()),
609
942
    tangentStruct);
610
942
}
611
612
/// Synthesize the `TangentVector` struct type.
613
static std::pair<Type, TypeDecl *>
614
1.08k
deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) {
615
1.08k
  auto *parentDC = derived.getConformanceContext();
616
1.08k
  auto *nominal = derived.Nominal;
617
618
  // If nominal type can derive `TangentVector` as the contextual `Self` type,
619
  // return it.
620
1.08k
  if (canDeriveTangentVectorAsSelf(nominal, parentDC))
621
144
    return std::make_pair(parentDC->getSelfTypeInContext(), nullptr);
622
623
  // Otherwise, get or synthesize `TangentVector` struct type.
624
942
  return getOrSynthesizeTangentVectorStructType(derived);
625
1.08k
}
626
627
930
/// Derives the witness for a `Differentiable` value requirement. Only
/// `move(by:)` is supported.
///
/// Returns nullptr in these cases:
/// - the requirement is unknown (diagnosed as broken);
/// - the context is disallowed;
/// - derivation is not possible. The staged "does not conform" / "no
///   witnesses" diagnostics are kept, and are cancelled only on success.
ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) {
  if (requirement->getBaseName() != Context.Id_move) {
    Context.Diags.diagnose(requirement->getLoc(),
                           diag::broken_differentiable_requirement);
    return nullptr;
  }
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;

  // Stage the failure diagnostics; `abort()` cancels them once derivation is
  // known to be possible.
  DiagnosticTransaction pendingDiags(Context.Diags);
  ConformanceDecl->diagnose(diag::type_does_not_conform,
                            Nominal->getDeclaredType(), getProtocolType());
  requirement->diagnose(diag::no_witnesses,
                        getProtocolRequirementKind(requirement), requirement,
                        getProtocolType(), /*AddFixIt=*/false);

  if (!canDeriveDifferentiable(Nominal, getConformanceContext(), requirement))
    return nullptr;

  pendingDiags.abort();
  if (requirement->getBaseName() == Context.Id_move)
    return deriveDifferentiable_move(*this);
  return nullptr;
}
657
658
std::pair<Type, TypeDecl *>
659
1.08k
DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
660
  // Diagnose unknown requirements.
661
1.08k
  if (requirement->getBaseName() != Context.Id_TangentVector) {
662
0
    Context.Diags.diagnose(requirement->getLoc(),
663
0
                           diag::broken_differentiable_requirement);
664
0
    return std::make_pair(nullptr, nullptr);
665
0
  }
666
667
  // Start an error diagnostic before attempting derivation.
668
  // If derivation succeeds, cancel the diagnostic.
669
1.08k
  DiagnosticTransaction diagnosticTransaction(Context.Diags);
670
1.08k
  ConformanceDecl->diagnose(diag::type_does_not_conform,
671
1.08k
                            Nominal->getDeclaredType(), getProtocolType());
672
1.08k
  requirement->diagnose(diag::no_witnesses_type, requirement);
673
674
  // If derivation is possible, cancel the diagnostic and perform derivation.
675
1.08k
  if (canDeriveDifferentiable(Nominal, getConformanceContext(), requirement)) {
676
1.08k
    diagnosticTransaction.abort();
677
1.08k
    return deriveDifferentiable_TangentVectorStruct(*this);
678
1.08k
  }
679
680
  // Otherwise, return nullptr.
681
0
  return std::make_pair(nullptr, nullptr);
682
1.08k
}