/Volumes/compiler/apple/swift/lib/Sema/DerivedConformanceDifferentiable.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- DerivedConformanceDifferentiable.cpp - Derived Differentiable ----===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // This file implements explicit derivation of the Differentiable protocol for |
14 | | // struct and class types. |
15 | | // |
16 | | //===----------------------------------------------------------------------===// |
17 | | |
18 | | #include "CodeSynthesis.h" |
19 | | #include "TypeChecker.h" |
20 | | #include "TypeCheckType.h" |
21 | | #include "llvm/ADT/SmallPtrSet.h" |
22 | | #include "swift/AST/AutoDiff.h" |
23 | | #include "swift/AST/Decl.h" |
24 | | #include "swift/AST/Expr.h" |
25 | | #include "swift/AST/Module.h" |
26 | | #include "swift/AST/ParameterList.h" |
27 | | #include "swift/AST/Pattern.h" |
28 | | #include "swift/AST/PropertyWrappers.h" |
29 | | #include "swift/AST/ProtocolConformance.h" |
30 | | #include "swift/AST/Stmt.h" |
31 | | #include "swift/AST/TypeCheckRequests.h" |
32 | | #include "swift/AST/Types.h" |
33 | | #include "DerivedConformances.h" |
34 | | |
35 | | using namespace swift; |
36 | | |
37 | | /// Return true if `move(by:)` can be invoked on the given `Differentiable`- |
38 | | /// conforming property. |
39 | | /// |
40 | | /// If the given property is a `var`, return true because `move(by:)` can be |
41 | | /// invoked regardless. Otherwise, return true if and only if the property's |
42 | | /// type's 'Differentiable.move(by:)' witness is non-mutating. |
43 | | static bool canInvokeMoveByOnProperty( |
44 | 5.49k | VarDecl *vd, ProtocolConformanceRef diffableConformance) { |
45 | 5.49k | assert(diffableConformance && "Property must conform to 'Differentiable'"); |
46 | | // `var` always supports `move(by:)` since it is mutable. |
47 | 5.49k | if (vd->getIntroducer() == VarDecl::Introducer::Var) |
48 | 5.30k | return true; |
49 | | // When the property is a `let`, the only case that would be supported is when |
50 | | // it has a `move(by:)` protocol requirement witness that is non-mutating. |
51 | 188 | auto interfaceType = vd->getInterfaceType(); |
52 | 188 | auto &C = vd->getASTContext(); |
53 | 188 | auto witness = diffableConformance.getWitnessByName( |
54 | 188 | interfaceType, DeclName(C, C.Id_move, {C.Id_by})); |
55 | 188 | if (!witness) |
56 | 0 | return false; |
57 | 188 | auto *decl = cast<FuncDecl>(witness.getDecl()); |
58 | 188 | return !decl->isMutating(); |
59 | 188 | } |
60 | | |
61 | | /// Get the stored properties of a nominal type that are relevant for |
62 | | /// differentiation, except the ones tagged `@noDerivative`. |
63 | | static void |
64 | | getStoredPropertiesForDifferentiation( |
65 | | NominalTypeDecl *nominal, DeclContext *DC, |
66 | | SmallVectorImpl<VarDecl *> &result, |
67 | 3.60k | bool includeLetPropertiesWithNonmutatingMoveBy = false) { |
68 | 3.60k | auto &C = nominal->getASTContext(); |
69 | 3.60k | auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
70 | 5.12k | for (auto *vd : nominal->getStoredProperties()) { |
71 | | // Peer through property wrappers: use original wrapped properties instead. |
72 | 5.12k | if (auto *originalProperty = vd->getOriginalWrappedProperty()) { |
73 | | // Skip immutable wrapped properties. `mutating func move(by:)` cannot |
74 | | // be synthesized to update these properties. |
75 | 668 | if (!originalProperty->isSettable(DC)) |
76 | 56 | continue; |
77 | | // Use the original wrapped property. |
78 | 612 | vd = originalProperty; |
79 | 612 | } |
80 | | // Skip stored properties with `@noDerivative` attribute. |
81 | 5.07k | if (vd->getAttrs().hasAttribute<NoDerivativeAttr>()) |
82 | 584 | continue; |
83 | 4.48k | if (vd->getInterfaceType()->hasError()) |
84 | 0 | continue; |
85 | 4.48k | auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType()); |
86 | 4.48k | auto conformance = TypeChecker::conformsToProtocol( |
87 | 4.48k | varType, diffableProto, DC->getParentModule()); |
88 | 4.48k | if (!conformance) |
89 | 128 | continue; |
90 | | // Skip `let` stored properties with a mutating `move(by:)` if requested. |
91 | | // `mutating func move(by:)` cannot be synthesized to update `let` |
92 | | // properties. |
93 | 4.36k | if (!includeLetPropertiesWithNonmutatingMoveBy && |
94 | 4.36k | !canInvokeMoveByOnProperty(vd, conformance)) |
95 | 40 | continue; |
96 | 4.32k | result.push_back(vd); |
97 | 4.32k | } |
98 | 3.60k | } |
99 | | |
100 | | /// Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a |
101 | | /// `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`. |
102 | 914 | static StructDecl *convertToStructDecl(ValueDecl *v) { |
103 | 914 | if (auto *structDecl = dyn_cast<StructDecl>(v)) |
104 | 906 | return structDecl; |
105 | 8 | auto *typeDecl = dyn_cast<TypeDecl>(v); |
106 | 8 | if (!typeDecl) |
107 | 0 | return nullptr; |
108 | 8 | return dyn_cast_or_null<StructDecl>( |
109 | 8 | typeDecl->getDeclaredInterfaceType()->getAnyNominal()); |
110 | 8 | } |
111 | | |
112 | | /// Get the `Differentiable` protocol `TangentVector` associated type witness |
113 | | /// for the given interface type and declaration context. |
114 | | static Type getTangentVectorInterfaceType(Type contextualType, |
115 | 2.00k | DeclContext *DC) { |
116 | 2.00k | auto &C = DC->getASTContext(); |
117 | 2.00k | auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
118 | 2.00k | assert(diffableProto && "`Differentiable` protocol not found"); |
119 | 0 | auto conf = |
120 | 2.00k | TypeChecker::conformsToProtocol(contextualType, diffableProto, |
121 | 2.00k | DC->getParentModule()); |
122 | 2.00k | assert(conf && "Contextual type must conform to `Differentiable`"); |
123 | 2.00k | if (!conf) |
124 | 0 | return nullptr; |
125 | 2.00k | auto tanType = conf.getTypeWitnessByName(contextualType, C.Id_TangentVector); |
126 | 2.00k | return tanType->hasArchetype() ? tanType->mapTypeOutOfContext() : tanType; |
127 | 2.00k | } |
128 | | |
129 | | /// Returns true iff the given nominal type declaration can derive |
130 | | /// `TangentVector` as `Self` in the given conformance context. |
131 | | static bool canDeriveTangentVectorAsSelf(NominalTypeDecl *nominal, |
132 | 3.08k | DeclContext *DC) { |
133 | | // `Self` must not be a class declaration. |
134 | 3.08k | if (nominal->getSelfClassDecl()) |
135 | 892 | return false; |
136 | | |
137 | 2.19k | auto nominalTypeInContext = |
138 | 2.19k | DC->mapTypeIntoContext(nominal->getDeclaredInterfaceType()); |
139 | 2.19k | auto &C = nominal->getASTContext(); |
140 | 2.19k | auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
141 | 2.19k | auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
142 | | // `Self` must conform to `AdditiveArithmetic`. |
143 | 2.19k | if (!TypeChecker::conformsToProtocol(nominalTypeInContext, addArithProto, |
144 | 2.19k | DC->getParentModule())) |
145 | 1.85k | return false; |
146 | 480 | for (auto *field : nominal->getStoredProperties()) { |
147 | | // `Self` must not have any `@noDerivative` stored properties. |
148 | 480 | if (field->getAttrs().hasAttribute<NoDerivativeAttr>()) |
149 | 0 | return false; |
150 | | // `Self` must have all stored properties satisfy `Self == TangentVector`. |
151 | 480 | auto fieldType = DC->mapTypeIntoContext(field->getValueInterfaceType()); |
152 | 480 | auto conf = TypeChecker::conformsToProtocol(fieldType, diffableProto, |
153 | 480 | DC->getParentModule()); |
154 | 480 | if (!conf) |
155 | 0 | return false; |
156 | 480 | auto tangentType = conf.getTypeWitnessByName(fieldType, C.Id_TangentVector); |
157 | 480 | if (!fieldType->isEqual(tangentType)) |
158 | 48 | return false; |
159 | 480 | } |
160 | 296 | return true; |
161 | 344 | } |
162 | | |
163 | | bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, |
164 | | DeclContext *DC, |
165 | 2.00k | ValueDecl *requirement) { |
166 | | // Experimental differentiable programming must be enabled. |
167 | 2.00k | if (auto *SF = DC->getParentSourceFile()) |
168 | 2.00k | if (!isDifferentiableProgrammingEnabled(*SF)) |
169 | 0 | return false; |
170 | | |
171 | 2.00k | auto &C = nominal->getASTContext(); |
172 | | // If there are any `TangentVector` type witness candidates, check whether |
173 | | // there exists only a single valid candidate. |
174 | 2.00k | bool canUseTangentVectorAsSelf = canDeriveTangentVectorAsSelf(nominal, DC); |
175 | 2.00k | auto isValidTangentVectorCandidate = [&](ValueDecl *v) -> bool { |
176 | | // Valid candidate must be a struct or a typealias to a struct. |
177 | 914 | auto *structDecl = convertToStructDecl(v); |
178 | 914 | if (!structDecl) |
179 | 4 | return false; |
180 | | // Valid candidate must either: |
181 | | // 1. Be implicit (previously synthesized). |
182 | 910 | if (structDecl->isImplicit()) |
183 | 898 | return true; |
184 | | // 2. Equal nominal, when the nominal can derive `TangentVector` as `Self`. |
185 | | // Nominal type must not customize `TangentVector` to anything other than |
186 | | // `Self`. Otherwise, synthesis is semantically unsupported. |
187 | 12 | if (structDecl == nominal && canUseTangentVectorAsSelf) |
188 | 0 | return true; |
189 | | // Otherwise, candidate is invalid. |
190 | 12 | return false; |
191 | 12 | }; |
192 | 2.00k | auto tangentDecls = nominal->lookupDirect(C.Id_TangentVector); |
193 | | // There can be at most one valid `TangentVector` type. |
194 | 2.00k | if (tangentDecls.size() > 1) |
195 | 0 | return false; |
196 | | // There cannot be any invalid `TangentVector` types. |
197 | 2.00k | if (tangentDecls.size() == 1) { |
198 | 914 | auto *tangentDecl = tangentDecls.front(); |
199 | 914 | if (!isValidTangentVectorCandidate(tangentDecl)) |
200 | 16 | return false; |
201 | 914 | } |
202 | | |
203 | | // Check `TangentVector` struct derivation conditions. |
204 | | // Nominal type must be a struct or class. (No stored properties is okay.) |
205 | 1.98k | if (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal)) |
206 | 0 | return false; |
207 | | // If there are no `TangentVector` candidates, derivation is possible if all |
208 | | // differentiation stored properties conform to `Differentiable`. |
209 | 1.98k | SmallVector<VarDecl *, 16> diffProperties; |
210 | 1.98k | getStoredPropertiesForDifferentiation(nominal, DC, diffProperties); |
211 | 1.98k | auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
212 | 2.38k | return llvm::all_of(diffProperties, [&](VarDecl *v) { |
213 | 2.38k | if (v->getInterfaceType()->hasError()) |
214 | 0 | return false; |
215 | 2.38k | auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType()); |
216 | 2.38k | return (bool)TypeChecker::conformsToProtocol(varType, diffableProto, |
217 | 2.38k | DC->getParentModule()); |
218 | 2.38k | }); |
219 | 1.98k | } |
220 | | |
221 | | /// Synthesize body for `move(by:)`. |
222 | | static std::pair<BraceStmt *, bool> |
223 | 674 | deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) { |
224 | 674 | auto &C = funcDecl->getASTContext(); |
225 | 674 | auto *parentDC = funcDecl->getParent(); |
226 | 674 | auto *nominal = parentDC->getSelfNominalTypeDecl(); |
227 | | |
228 | | // Get `Differentiable.move(by:)` protocol requirement. |
229 | 674 | auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable); |
230 | 674 | auto *requirement = getProtocolRequirement(diffProto, C.Id_move); |
231 | | |
232 | | // Get references to `self` and parameter declarations. |
233 | 674 | auto *selfDecl = funcDecl->getImplicitSelfDecl(); |
234 | 674 | auto *selfDRE = |
235 | 674 | new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true); |
236 | 674 | auto *paramDecl = funcDecl->getParameters()->get(0); |
237 | 674 | auto *paramDRE = |
238 | 674 | new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true); |
239 | | |
240 | 674 | SmallVector<VarDecl *, 8> diffProperties; |
241 | 674 | getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); |
242 | | |
243 | | // Create call expression applying a member `move(by:)` method to a |
244 | | // parameter member: `self.<member>.move(by: offset.<member>)`. |
245 | 826 | auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * { |
246 | 826 | auto *module = nominal->getModuleContext(); |
247 | 826 | auto memberType = |
248 | 826 | parentDC->mapTypeIntoContext(member->getValueInterfaceType()); |
249 | 826 | auto confRef = module->lookupConformance(memberType, diffProto); |
250 | 826 | assert(confRef && "Member does not conform to `Differentiable`"); |
251 | | |
252 | | // Get member type's requirement witness: `<Member>.move(by:)`. |
253 | 0 | ValueDecl *memberWitnessDecl = requirement; |
254 | 826 | if (confRef.isConcrete()) |
255 | 634 | if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement)) |
256 | 634 | memberWitnessDecl = witness; |
257 | 826 | assert(memberWitnessDecl && "Member witness declaration must exist"); |
258 | | |
259 | | // Create reference to member method: `self.<member>.move(by:)`. |
260 | 0 | Expr *memberExpr = |
261 | 826 | new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(), |
262 | 826 | /*Implicit*/ true); |
263 | 826 | auto *memberMethodExpr = |
264 | 826 | new (C) MemberRefExpr(memberExpr, SourceLoc(), memberWitnessDecl, |
265 | 826 | DeclNameLoc(), /*Implicit*/ true); |
266 | | |
267 | | // Create reference to parameter member: `offset.<member>`. |
268 | 826 | VarDecl *paramMember = nullptr; |
269 | 826 | auto *paramNominal = paramDecl->getTypeInContext()->getAnyNominal(); |
270 | 826 | assert(paramNominal && "Parameter should have a nominal type"); |
271 | | // Find parameter member corresponding to returned nominal member. |
272 | 1.23k | for (auto *candidate : paramNominal->getStoredProperties()) { |
273 | 1.23k | if (candidate->getName() == member->getName()) { |
274 | 826 | paramMember = candidate; |
275 | 826 | break; |
276 | 826 | } |
277 | 1.23k | } |
278 | 826 | assert(paramMember && "Could not find corresponding parameter member"); |
279 | 0 | auto *paramMemberExpr = |
280 | 826 | new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(), |
281 | 826 | /*Implicit*/ true); |
282 | | // Create expression: `self.<member>.move(by: offset.<member>)`. |
283 | 826 | auto *args = ArgumentList::forImplicitSingle(C, C.Id_by, paramMemberExpr); |
284 | 826 | return CallExpr::createImplicit(C, memberMethodExpr, args); |
285 | 826 | }; |
286 | | |
287 | | // Collect member `move(by:)` method call expressions. |
288 | 674 | SmallVector<ASTNode, 2> memberMethodCallExprs; |
289 | 674 | SmallVector<Identifier, 2> memberNames; |
290 | 826 | for (auto *member : diffProperties) { |
291 | 826 | memberMethodCallExprs.push_back(createMemberMethodCallExpr(member)); |
292 | 826 | memberNames.push_back(member->getName()); |
293 | 826 | } |
294 | 674 | auto *braceStmt = BraceStmt::create(C, SourceLoc(), memberMethodCallExprs, |
295 | 674 | SourceLoc(), true); |
296 | 674 | return std::pair<BraceStmt *, bool>(braceStmt, false); |
297 | 674 | } |
298 | | |
299 | | /// Synthesize function declaration for a `Differentiable` method requirement. |
300 | | static ValueDecl *deriveDifferentiable_method( |
301 | | DerivedConformance &derived, Identifier methodName, Identifier argumentName, |
302 | | Identifier parameterName, Type parameterType, Type returnType, |
303 | 898 | AbstractFunctionDecl::BodySynthesizer bodySynthesizer) { |
304 | 898 | auto *nominal = derived.Nominal; |
305 | 898 | auto &C = derived.Context; |
306 | 898 | auto *parentDC = derived.getConformanceContext(); |
307 | | |
308 | 898 | auto *param = new (C) ParamDecl(SourceLoc(), SourceLoc(), argumentName, |
309 | 898 | SourceLoc(), parameterName, parentDC); |
310 | 898 | param->setSpecifier(ParamDecl::Specifier::Default); |
311 | 898 | param->setInterfaceType(parameterType); |
312 | 898 | param->setImplicit(); |
313 | 898 | ParameterList *params = ParameterList::create(C, {param}); |
314 | | |
315 | 898 | DeclName declName(C, methodName, params); |
316 | 898 | auto *const funcDecl = FuncDecl::createImplicit( |
317 | 898 | C, StaticSpellingKind::None, declName, /*NameLoc=*/SourceLoc(), |
318 | 898 | /*Async=*/false, |
319 | 898 | /*Throws=*/false, |
320 | 898 | /*ThrownType=*/Type(), |
321 | 898 | /*GenericParams=*/nullptr, params, returnType, parentDC); |
322 | 898 | funcDecl->setSynthesized(); |
323 | 898 | if (!nominal->getSelfClassDecl()) |
324 | 614 | funcDecl->setSelfAccessKind(SelfAccessKind::Mutating); |
325 | 898 | funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context); |
326 | | |
327 | 898 | funcDecl->setGenericSignature(parentDC->getGenericSignatureOfContext()); |
328 | 898 | funcDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); |
329 | | |
330 | 898 | derived.addMembersToConformanceContext({funcDecl}); |
331 | 898 | return funcDecl; |
332 | 898 | } |
333 | | |
334 | | /// Synthesize the `move(by:)` function declaration. |
335 | 898 | static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) { |
336 | 898 | auto &C = derived.Context; |
337 | 898 | auto *parentDC = derived.getConformanceContext(); |
338 | 898 | auto tangentType = |
339 | 898 | getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC); |
340 | 898 | return deriveDifferentiable_method( |
341 | 898 | derived, C.Id_move, C.Id_by, C.Id_offset, tangentType, |
342 | 898 | C.TheEmptyTupleType, {deriveBodyDifferentiable_move, nullptr}); |
343 | 898 | } |
344 | | |
345 | | /// Return associated `TangentVector` struct for a nominal type, if it exists. |
346 | | /// If not, synthesize the struct. |
347 | | static StructDecl * |
348 | 942 | getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { |
349 | 942 | auto *parentDC = derived.getConformanceContext(); |
350 | 942 | auto *nominal = derived.Nominal; |
351 | 942 | auto &C = nominal->getASTContext(); |
352 | | |
353 | | // If the associated struct already exists, return it. |
354 | 942 | auto lookup = nominal->lookupDirect(C.Id_TangentVector); |
355 | 942 | assert(lookup.size() < 2 && |
356 | 942 | "Expected at most one associated type named `TangentVector`"); |
357 | 942 | if (lookup.size() == 1) { |
358 | 0 | auto *structDecl = convertToStructDecl(lookup.front()); |
359 | 0 | assert(structDecl && "Expected lookup result to be a struct"); |
360 | 0 | return structDecl; |
361 | 0 | } |
362 | | |
363 | | // Otherwise, synthesize a new struct. |
364 | | |
365 | | // Compute `tvDesiredProtos`, the set of protocols that the new `TangentVector` struct must |
366 | | // inherit, by collecting all the `TangentVector` conformance requirements imposed by the |
367 | | // protocols that `derived.ConformanceDecl` inherits. |
368 | | // |
369 | | // Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because |
370 | | // the `Differentiable` protocol itself requires that its `TangentVector` conforms to |
371 | | // `AdditiveArithmetic` and `Differentiable`. |
372 | 942 | llvm::SmallSetVector<ProtocolDecl *, 4> tvDesiredProtos; |
373 | | |
374 | 942 | auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
375 | 942 | auto *tvAssocType = diffableProto->getAssociatedType(C.Id_TangentVector); |
376 | | |
377 | 942 | auto localProtos = cast<IterableDeclContext>(derived.ConformanceDecl) |
378 | 942 | ->getLocalProtocols(); |
379 | 1.57k | for (auto proto : localProtos) { |
380 | 2.99k | for (auto req : proto->getRequirementSignature().getRequirements()) { |
381 | 2.99k | if (req.getKind() != RequirementKind::Conformance) |
382 | 942 | continue; |
383 | 2.05k | auto *firstType = req.getFirstType()->getAs<DependentMemberType>(); |
384 | 2.05k | if (!firstType || firstType->getAssocType() != tvAssocType) |
385 | 140 | continue; |
386 | 1.91k | tvDesiredProtos.insert(req.getProtocolDecl()); |
387 | 1.91k | } |
388 | 1.57k | } |
389 | 942 | SmallVector<InheritedEntry, 4> tvDesiredProtoInherited; |
390 | 942 | for (auto *p : tvDesiredProtos) |
391 | 1.91k | tvDesiredProtoInherited.push_back( |
392 | 1.91k | InheritedEntry(TypeLoc::withoutLoc(p->getDeclaredInterfaceType()))); |
393 | | |
394 | | // Cache original members and their associated types for later use. |
395 | 942 | SmallVector<VarDecl *, 8> diffProperties; |
396 | 942 | getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); |
397 | | |
398 | 942 | auto synthesizedLoc = derived.ConformanceDecl->getEndLoc(); |
399 | 942 | auto *structDecl = |
400 | 942 | new (C) StructDecl(synthesizedLoc, C.Id_TangentVector, synthesizedLoc, |
401 | 942 | /*Inherited*/ C.AllocateCopy(tvDesiredProtoInherited), |
402 | 942 | /*GenericParams*/ {}, parentDC); |
403 | 942 | structDecl->setBraces({synthesizedLoc, synthesizedLoc}); |
404 | 942 | structDecl->setImplicit(); |
405 | 942 | structDecl->setSynthesized(); |
406 | 942 | structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); |
407 | | |
408 | | // Add stored properties to the `TangentVector` struct. |
409 | 1.11k | for (auto *member : diffProperties) { |
410 | | // Add a tangent stored property to the `TangentVector` struct, with the |
411 | | // name and `TangentVector` type of the original property. |
412 | 1.11k | auto *tangentProperty = new (C) VarDecl( |
413 | 1.11k | member->isStatic(), member->getIntroducer(), |
414 | 1.11k | /*NameLoc*/ SourceLoc(), member->getName(), structDecl); |
415 | | // Note: `tangentProperty` is not marked as implicit or synthesized here, |
416 | | // because that incorrectly affects memberwise initializer synthesis and |
417 | | // causes the type checker to not guarantee the order of these members. |
418 | 1.11k | auto memberContextualType = |
419 | 1.11k | parentDC->mapTypeIntoContext(member->getValueInterfaceType()); |
420 | 1.11k | auto memberTanType = |
421 | 1.11k | getTangentVectorInterfaceType(memberContextualType, parentDC); |
422 | 1.11k | tangentProperty->setInterfaceType(memberTanType); |
423 | 1.11k | Pattern *memberPattern = |
424 | 1.11k | NamedPattern::createImplicit(C, tangentProperty, memberTanType); |
425 | 1.11k | memberPattern = |
426 | 1.11k | TypedPattern::createImplicit(C, memberPattern, memberTanType); |
427 | 1.11k | memberPattern->setType(memberTanType); |
428 | 1.11k | auto *memberBinding = PatternBindingDecl::createImplicit( |
429 | 1.11k | C, StaticSpellingKind::None, memberPattern, /*initExpr*/ nullptr, |
430 | 1.11k | structDecl); |
431 | 1.11k | structDecl->addMember(tangentProperty); |
432 | 1.11k | structDecl->addMember(memberBinding); |
433 | 1.11k | tangentProperty->copyFormalAccessFrom(member, |
434 | 1.11k | /*sourceIsParentContext*/ true); |
435 | 1.11k | tangentProperty->setSetterAccess(member->getFormalAccess()); |
436 | | |
437 | | // Cache the tangent property. |
438 | 1.11k | C.evaluator.cacheOutput(TangentStoredPropertyRequest{member, CanType()}, |
439 | 1.11k | TangentPropertyInfo(tangentProperty)); |
440 | | |
441 | | // Now that the original property has a corresponding tangent property, it |
442 | | // should be marked `@differentiable` so that the differentiation transform |
443 | | // will synthesize derivative functions for its accessors. We only add this |
444 | | // to public stored properties, because their access outside the module will |
445 | | // go through accessor declarations. |
446 | 1.11k | if (member->getEffectiveAccess() > AccessLevel::Internal && |
447 | 1.11k | !member->getAttrs().hasAttribute<DifferentiableAttr>()) { |
448 | 68 | auto *getter = member->getSynthesizedAccessor(AccessorKind::Get); |
449 | 68 | (void)getter->getInterfaceType(); |
450 | | // If member or its getter already has a `@differentiable` attribute, |
451 | | // continue. |
452 | 68 | if (member->getAttrs().hasAttribute<DifferentiableAttr>() || |
453 | 68 | getter->getAttrs().hasAttribute<DifferentiableAttr>()) |
454 | 0 | continue; |
455 | 68 | GenericSignature derivativeGenericSignature = |
456 | 68 | getter->getGenericSignature(); |
457 | | // If the parent declaration context is an extension, the nominal type may |
458 | | // conditionally conform to `Differentiable`. Use the extension generic |
459 | | // requirements in getter `@differentiable` attributes. |
460 | 68 | if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl())) |
461 | 12 | if (auto extGenSig = extDecl->getGenericSignature()) |
462 | 12 | derivativeGenericSignature = extGenSig; |
463 | 68 | auto *diffableAttr = DifferentiableAttr::create( |
464 | 68 | getter, /*implicit*/ true, SourceLoc(), SourceLoc(), |
465 | 68 | DifferentiabilityKind::Reverse, |
466 | 68 | /*parameterIndices*/ IndexSubset::get(C, 1, {0}), |
467 | 68 | derivativeGenericSignature); |
468 | 68 | member->getAttrs().add(diffableAttr); |
469 | 68 | } |
470 | 1.11k | } |
471 | | |
472 | | // If nominal type is `@frozen`, also mark `TangentVector` struct. |
473 | 942 | if (nominal->getAttrs().hasAttribute<FrozenAttr>()) |
474 | 8 | structDecl->getAttrs().add(new (C) FrozenAttr(/*implicit*/ true)); |
475 | | |
476 | | // Add `typealias TangentVector = Self` so that the `TangentVector` itself |
477 | | // won't need its own conformance derivation. |
478 | 942 | auto *tangentEqualsSelfAlias = new (C) TypeAliasDecl( |
479 | 942 | SourceLoc(), SourceLoc(), C.Id_TangentVector, SourceLoc(), |
480 | 942 | /*GenericParams*/ nullptr, structDecl); |
481 | 942 | tangentEqualsSelfAlias->setUnderlyingType(structDecl->getDeclaredInterfaceType()); |
482 | 942 | tangentEqualsSelfAlias->copyFormalAccessFrom(structDecl, |
483 | 942 | /*sourceIsParentContext*/ true); |
484 | 942 | tangentEqualsSelfAlias->setImplicit(); |
485 | 942 | tangentEqualsSelfAlias->setSynthesized(); |
486 | 942 | structDecl->addMember(tangentEqualsSelfAlias); |
487 | | |
488 | | // The implicit memberwise constructor must be explicitly created so that it |
489 | | // can called in `AdditiveArithmetic` and `Differentiable` methods. Normally, |
490 | | // the memberwise constructor is synthesized during SILGen, which is too late. |
491 | 942 | TypeChecker::addImplicitConstructors(structDecl); |
492 | | |
493 | | // After memberwise initializer is synthesized, mark members as implicit. |
494 | 942 | for (auto *member : structDecl->getStoredProperties()) |
495 | 1.11k | member->setImplicit(); |
496 | | |
497 | 942 | derived.addMembersToConformanceContext({structDecl}); |
498 | | |
499 | 942 | TypeChecker::checkConformancesInContext(structDecl); |
500 | | |
501 | 942 | return structDecl; |
502 | 942 | } |
503 | | |
504 | | /// Diagnose stored properties in the nominal that do not have an explicit |
505 | | /// `@noDerivative` attribute, but either: |
506 | | /// - Do not conform to `Differentiable`. |
507 | | /// - Are a `let` stored property. |
508 | | /// Emit a warning and a fixit so that users will make the attribute explicit. |
509 | | static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context, |
510 | | NominalTypeDecl *nominal, |
511 | 942 | DeclContext *DC) { |
512 | | // If nominal type can conform to `AdditiveArithmetic`, suggest adding a |
513 | | // conformance to `AdditiveArithmetic` in fix-its. |
514 | | // `Differentiable` protocol requirements all have default implementations |
515 | | // when `Self` conforms to `AdditiveArithmetic`, so `Differentiable` |
516 | | // derived conformances will no longer be necessary. |
517 | 942 | bool nominalCanDeriveAdditiveArithmetic = |
518 | 942 | DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC); |
519 | 942 | auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable); |
520 | | // Check all stored properties. |
521 | 1.34k | for (auto *vd : nominal->getStoredProperties()) { |
522 | | // Peer through property wrappers: use original wrapped properties. |
523 | 1.34k | if (auto *originalProperty = vd->getOriginalWrappedProperty()) { |
524 | | // Skip wrapped properties with `@noDerivative` attribute. |
525 | 176 | if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>()) |
526 | 24 | continue; |
527 | | // Diagnose wrapped properties whose property wrappers do not define |
528 | | // `wrappedValue.set`. `mutating func move(by:)` cannot be synthesized |
529 | | // to update these properties. |
530 | 152 | if (!originalProperty->isSettable(DC)) { |
531 | 8 | auto *wrapperDecl = |
532 | 8 | vd->getInterfaceType()->getNominalOrBoundGenericNominal(); |
533 | 8 | auto loc = |
534 | 8 | originalProperty->getAttributeInsertionLoc(/*forModifier*/ false); |
535 | 8 | Context.Diags |
536 | 8 | .diagnose( |
537 | 8 | loc, |
538 | 8 | diag:: |
539 | 8 | differentiable_immutable_wrapper_implicit_noderivative_fixit, |
540 | 8 | wrapperDecl->getName(), nominal->getName(), |
541 | 8 | nominalCanDeriveAdditiveArithmetic) |
542 | 8 | .fixItInsert(loc, "@noDerivative "); |
543 | | // Add an implicit `@noDerivative` attribute. |
544 | 8 | originalProperty->getAttrs().add( |
545 | 8 | new (Context) NoDerivativeAttr(/*Implicit*/ true)); |
546 | 8 | continue; |
547 | 8 | } |
548 | | // Use the original wrapped property. |
549 | 144 | vd = originalProperty; |
550 | 144 | } |
551 | 1.31k | if (vd->getInterfaceType()->hasError()) |
552 | 0 | continue; |
553 | | // Skip stored properties with `@noDerivative` attribute. |
554 | 1.31k | if (vd->getAttrs().hasAttribute<NoDerivativeAttr>()) |
555 | 116 | continue; |
556 | | // Check whether to diagnose stored property. |
557 | 1.19k | auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType()); |
558 | 1.19k | auto diffableConformance = |
559 | 1.19k | TypeChecker::conformsToProtocol(varType, diffableProto, |
560 | 1.19k | DC->getParentModule()); |
561 | | // If stored property should not be diagnosed, continue. |
562 | 1.19k | if (diffableConformance && |
563 | 1.19k | canInvokeMoveByOnProperty(vd, diffableConformance)) |
564 | 1.11k | continue; |
565 | | // Otherwise, add an implicit `@noDerivative` attribute. |
566 | 84 | vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true)); |
567 | 84 | auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false); |
568 | 84 | assert(loc.isValid() && "Expected valid source location"); |
569 | | // Diagnose properties that do not conform to `Differentiable`. |
570 | 84 | if (!diffableConformance) { |
571 | 64 | Context.Diags |
572 | 64 | .diagnose( |
573 | 64 | loc, |
574 | 64 | diag::differentiable_nondiff_type_implicit_noderivative_fixit, |
575 | 64 | vd->getName(), vd->getTypeInContext(), nominal->getName(), |
576 | 64 | nominalCanDeriveAdditiveArithmetic) |
577 | 64 | .fixItInsert(loc, "@noDerivative "); |
578 | 64 | continue; |
579 | 64 | } |
580 | | // Otherwise, diagnose `let` property. |
581 | 20 | Context.Diags |
582 | 20 | .diagnose(loc, |
583 | 20 | diag::differentiable_let_property_implicit_noderivative_fixit, |
584 | 20 | nominal->getName(), nominalCanDeriveAdditiveArithmetic) |
585 | 20 | .fixItInsert(loc, "@noDerivative "); |
586 | 20 | } |
587 | 942 | } |
588 | | |
589 | | /// Get or synthesize `TangentVector` struct type. |
590 | | static std::pair<Type, TypeDecl *> |
591 | 942 | getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) { |
592 | 942 | auto *parentDC = derived.getConformanceContext(); |
593 | 942 | auto *nominal = derived.Nominal; |
594 | 942 | auto &C = nominal->getASTContext(); |
595 | | |
596 | | // Get or synthesize `TangentVector` struct. |
597 | 942 | auto *tangentStruct = |
598 | 942 | getOrSynthesizeTangentVectorStruct(derived, C.Id_TangentVector); |
599 | 942 | if (!tangentStruct) |
600 | 0 | return std::make_pair(nullptr, nullptr); |
601 | | |
602 | | // Check and emit warnings for implicit `@noDerivative` members. |
603 | 942 | checkAndDiagnoseImplicitNoDerivative(C, nominal, parentDC); |
604 | | |
605 | | // Return the `TangentVector` struct type. |
606 | 942 | return std::make_pair( |
607 | 942 | parentDC->mapTypeIntoContext( |
608 | 942 | tangentStruct->getDeclaredInterfaceType()), |
609 | 942 | tangentStruct); |
610 | 942 | } |
611 | | |
612 | | /// Synthesize the `TangentVector` struct type. |
613 | | static std::pair<Type, TypeDecl *> |
614 | 1.08k | deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) { |
615 | 1.08k | auto *parentDC = derived.getConformanceContext(); |
616 | 1.08k | auto *nominal = derived.Nominal; |
617 | | |
618 | | // If nominal type can derive `TangentVector` as the contextual `Self` type, |
619 | | // return it. |
620 | 1.08k | if (canDeriveTangentVectorAsSelf(nominal, parentDC)) |
621 | 144 | return std::make_pair(parentDC->getSelfTypeInContext(), nullptr); |
622 | | |
623 | | // Otherwise, get or synthesize `TangentVector` struct type. |
624 | 942 | return getOrSynthesizeTangentVectorStructType(derived); |
625 | 1.08k | } |
626 | | |
627 | 930 | ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) { |
628 | | // Diagnose unknown requirements. |
629 | 930 | if (requirement->getBaseName() != Context.Id_move) { |
630 | 0 | Context.Diags.diagnose(requirement->getLoc(), |
631 | 0 | diag::broken_differentiable_requirement); |
632 | 0 | return nullptr; |
633 | 0 | } |
634 | | // Diagnose conformances in disallowed contexts. |
635 | 930 | if (checkAndDiagnoseDisallowedContext(requirement)) |
636 | 16 | return nullptr; |
637 | | |
638 | | // Start an error diagnostic before attempting derivation. |
639 | | // If derivation succeeds, cancel the diagnostic. |
640 | 914 | DiagnosticTransaction diagnosticTransaction(Context.Diags); |
641 | 914 | ConformanceDecl->diagnose(diag::type_does_not_conform, |
642 | 914 | Nominal->getDeclaredType(), getProtocolType()); |
643 | 914 | requirement->diagnose(diag::no_witnesses, |
644 | 914 | getProtocolRequirementKind(requirement), |
645 | 914 | requirement, getProtocolType(), /*AddFixIt=*/false); |
646 | | |
647 | | // If derivation is possible, cancel the diagnostic and perform derivation. |
648 | 914 | if (canDeriveDifferentiable(Nominal, getConformanceContext(), requirement)) { |
649 | 898 | diagnosticTransaction.abort(); |
650 | 898 | if (requirement->getBaseName() == Context.Id_move) |
651 | 898 | return deriveDifferentiable_move(*this); |
652 | 898 | } |
653 | | |
654 | | // Otherwise, return nullptr. |
655 | 16 | return nullptr; |
656 | 914 | } |
657 | | |
658 | | std::pair<Type, TypeDecl *> |
659 | 1.08k | DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) { |
660 | | // Diagnose unknown requirements. |
661 | 1.08k | if (requirement->getBaseName() != Context.Id_TangentVector) { |
662 | 0 | Context.Diags.diagnose(requirement->getLoc(), |
663 | 0 | diag::broken_differentiable_requirement); |
664 | 0 | return std::make_pair(nullptr, nullptr); |
665 | 0 | } |
666 | | |
667 | | // Start an error diagnostic before attempting derivation. |
668 | | // If derivation succeeds, cancel the diagnostic. |
669 | 1.08k | DiagnosticTransaction diagnosticTransaction(Context.Diags); |
670 | 1.08k | ConformanceDecl->diagnose(diag::type_does_not_conform, |
671 | 1.08k | Nominal->getDeclaredType(), getProtocolType()); |
672 | 1.08k | requirement->diagnose(diag::no_witnesses_type, requirement); |
673 | | |
674 | | // If derivation is possible, cancel the diagnostic and perform derivation. |
675 | 1.08k | if (canDeriveDifferentiable(Nominal, getConformanceContext(), requirement)) { |
676 | 1.08k | diagnosticTransaction.abort(); |
677 | 1.08k | return deriveDifferentiable_TangentVectorStruct(*this); |
678 | 1.08k | } |
679 | | |
680 | | // Otherwise, return nullptr. |
681 | 0 | return std::make_pair(nullptr, nullptr); |
682 | 1.08k | } |