/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/TangentBuilder.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- TangentBuilder.cpp - Tangent SIL builder ------------*- C++ -*----===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // This file defines a helper class for emitting tangent code for automatic |
14 | | // differentiation. |
15 | | // |
16 | | //===----------------------------------------------------------------------===// |
17 | | |
18 | | #define DEBUG_TYPE "differentiation" |
19 | | |
20 | | #include "swift/SILOptimizer/Differentiation/TangentBuilder.h" |
21 | | #include "swift/SILOptimizer/Differentiation/ADContext.h" |
22 | | |
23 | | namespace swift { |
24 | | namespace autodiff { |
25 | | |
26 | | void TangentBuilder::emitZeroIntoBuffer(SILLocation loc, SILValue buffer, |
27 | 20.0k | IsInitialization_t isInit) { |
28 | 20.0k | if (!isInit) |
29 | 2.36k | emitDestroyAddr(loc, buffer); |
30 | 20.0k | if (auto tupleType = buffer->getType().getAs<TupleType>()) { |
31 | 880 | for (unsigned i : range(tupleType->getNumElements())) { |
32 | 880 | auto *eltAddr = createTupleElementAddr(loc, buffer, i); |
33 | 880 | emitZeroIntoBuffer(loc, eltAddr, IsInitialization); |
34 | 880 | } |
35 | 440 | return; |
36 | 440 | } |
37 | 19.5k | auto *swiftMod = getModule().getSwiftModule(); |
38 | | // Look up conformance to `AdditiveArithmetic`. |
39 | 19.5k | auto *additiveArithmeticProto = adContext.getAdditiveArithmeticProtocol(); |
40 | 19.5k | auto astType = buffer->getType().getASTType(); |
41 | 19.5k | auto confRef = swiftMod->lookupConformance(astType, additiveArithmeticProto); |
42 | 19.5k | assert(!confRef.isInvalid() && "Missing conformance to `AdditiveArithmetic`"); |
43 | 0 | SILDeclRef accessorDeclRef(adContext.getAdditiveArithmeticZeroGetter(), |
44 | 19.5k | SILDeclRef::Kind::Func); |
45 | 19.5k | auto silFnType = getModule().Types.getConstantType( |
46 | 19.5k | getTypeExpansionContext(), accessorDeclRef); |
47 | | // %wm = witness_method ... |
48 | 19.5k | auto *getter = createWitnessMethod( |
49 | 19.5k | loc, astType, confRef, accessorDeclRef, silFnType); |
50 | | // %metatype = metatype $T |
51 | 19.5k | auto metatypeType = CanMetatypeType::get(astType, |
52 | 19.5k | MetatypeRepresentation::Thick); |
53 | 19.5k | auto metatype = createMetatype( |
54 | 19.5k | loc, SILType::getPrimitiveObjectType(metatypeType)); |
55 | 19.5k | auto subMap = SubstitutionMap::getProtocolSubstitutions( |
56 | 19.5k | additiveArithmeticProto, astType, confRef); |
57 | 19.5k | createApply(loc, getter, subMap, {buffer, metatype}); |
58 | 19.5k | emitDestroyValueOperation(loc, getter); |
59 | 19.5k | } |
60 | | |
61 | 2.44k | SILValue TangentBuilder::emitZero(SILLocation loc, CanType type) { |
62 | 2.44k | auto silType = getModule().Types.getLoweredLoadableType( |
63 | 2.44k | type, TypeExpansionContext::minimal(), getModule()); |
64 | 2.44k | auto tempAllocLoc = RegularLocation::getAutoGeneratedLocation(); |
65 | 2.44k | auto *alloc = createAllocStack(tempAllocLoc, silType); |
66 | 2.44k | emitZeroIntoBuffer(loc, alloc, IsInitialization); |
67 | 2.44k | auto zeroValue = emitLoadValueOperation( |
68 | 2.44k | loc, alloc, LoadOwnershipQualifier::Take); |
69 | 2.44k | createDeallocStack(loc, alloc); |
70 | 2.44k | return zeroValue; |
71 | 2.44k | } |
72 | | |
73 | | void TangentBuilder::emitInPlaceAdd( |
74 | 8.21k | SILLocation loc, SILValue destinationBuffer, SILValue operand) { |
75 | 8.21k | assert(destinationBuffer->getType().isAddress()); |
76 | 0 | auto type = destinationBuffer->getType(); |
77 | 8.21k | if (auto tupleType = type.getAs<TupleType>()) { |
78 | 40 | for (unsigned i : range(tupleType->getNumElements())) { |
79 | 40 | auto *eltDestAddr = createTupleElementAddr(loc, destinationBuffer, i); |
80 | 40 | switch (operand->getType().getCategory()) { |
81 | 40 | case SILValueCategory::Address: { |
82 | 40 | auto *eltOperand = createTupleElementAddr(loc, operand, i); |
83 | 40 | emitInPlaceAdd(loc, eltDestAddr, eltOperand); |
84 | 40 | break; |
85 | 0 | } |
86 | 0 | case SILValueCategory::Object: { |
87 | 0 | auto borrowedOp = emitBeginBorrowOperation(loc, operand); |
88 | 0 | auto eltOperand = emitTupleExtract(loc, borrowedOp, i); |
89 | 0 | emitInPlaceAdd(loc, eltDestAddr, eltOperand); |
90 | 0 | emitEndBorrowOperation(loc, borrowedOp); |
91 | 0 | break; |
92 | 0 | } |
93 | 40 | } |
94 | 40 | } |
95 | 20 | return; |
96 | 20 | } |
97 | | // Call the combiner function and return. |
98 | 8.19k | auto *swiftMod = getModule().getSwiftModule(); |
99 | 8.19k | auto astType = type.getASTType(); |
100 | 8.19k | auto confRef = swiftMod->lookupConformance( |
101 | 8.19k | astType, adContext.getAdditiveArithmeticProtocol()); |
102 | 8.19k | assert(!confRef.isInvalid() && |
103 | 8.19k | "Missing conformance to `AdditiveArithmetic`"); |
104 | 0 | SILDeclRef declRef(adContext.getPlusEqualDecl(), SILDeclRef::Kind::Func); |
105 | 8.19k | auto silFnTy = getModule().Types.getConstantType( |
106 | 8.19k | getTypeExpansionContext(), declRef); |
107 | | // %0 = witness_method @+= |
108 | 8.19k | auto witnessMethod = |
109 | 8.19k | createWitnessMethod(loc, astType, confRef, declRef, silFnTy); |
110 | 8.19k | auto subMap = SubstitutionMap::getProtocolSubstitutions( |
111 | 8.19k | adContext.getAdditiveArithmeticProtocol(), astType, confRef); |
112 | | // %1 = metatype $T.Type |
113 | 8.19k | auto metatypeType = |
114 | 8.19k | CanMetatypeType::get(astType, MetatypeRepresentation::Thick); |
115 | 8.19k | auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType); |
116 | 8.19k | auto metatype = createMetatype(loc, metatypeSILType); |
117 | | // %2 = apply $0(%lhs, %rhs, %1) |
118 | 8.19k | createApply(loc, witnessMethod, subMap, |
119 | 8.19k | {destinationBuffer, operand, metatype}); |
120 | 8.19k | emitDestroyValueOperation(loc, witnessMethod); |
121 | 8.19k | } |
122 | | |
123 | | void TangentBuilder::emitAddIntoBuffer(SILLocation loc, |
124 | | SILValue destinationBuffer, |
125 | | SILValue lhsAddress, |
126 | 2.34k | SILValue rhsAddress) { |
127 | 2.34k | assert(lhsAddress->getType().getASTType() == |
128 | 2.34k | rhsAddress->getType().getASTType() && |
129 | 2.34k | "Adjoint values must have same type!"); |
130 | 0 | assert(lhsAddress->getType().isAddress() && |
131 | 2.34k | rhsAddress->getType().isAddress() && |
132 | 2.34k | "Adjoint values must both have address types!"); |
133 | 0 | auto type = lhsAddress->getType(); |
134 | 2.34k | if (auto tupleType = type.getAs<TupleType>()) { |
135 | 0 | for (unsigned i : range(tupleType->getNumElements())) { |
136 | 0 | auto *destAddr = createTupleElementAddr(loc, destinationBuffer, i); |
137 | 0 | auto *eltAddrLHS = createTupleElementAddr(loc, lhsAddress, i); |
138 | 0 | auto *eltAddrRHS = createTupleElementAddr(loc, rhsAddress, i); |
139 | 0 | emitAddIntoBuffer(loc, destAddr, eltAddrLHS, eltAddrRHS); |
140 | 0 | } |
141 | 0 | return; |
142 | 0 | } |
143 | 2.34k | auto astType = type.getASTType(); |
144 | 2.34k | auto *proto = adContext.getAdditiveArithmeticProtocol(); |
145 | 2.34k | auto *combinerFuncDecl = adContext.getPlusDecl(); |
146 | | // Call the combiner function and return. |
147 | 2.34k | auto *swiftMod = getModule().getSwiftModule(); |
148 | 2.34k | auto confRef = swiftMod->lookupConformance(astType, proto); |
149 | 2.34k | assert(!confRef.isInvalid() && |
150 | 2.34k | "Missing conformance to `AdditiveArithmetic`"); |
151 | 0 | SILDeclRef declRef(combinerFuncDecl, SILDeclRef::Kind::Func); |
152 | 2.34k | auto silFnTy = getModule().Types.getConstantType( |
153 | 2.34k | getTypeExpansionContext(), declRef); |
154 | | // %0 = witness_method @+ |
155 | 2.34k | auto witnessMethod = |
156 | 2.34k | createWitnessMethod(loc, astType, confRef, declRef, silFnTy); |
157 | 2.34k | auto subMap = |
158 | 2.34k | SubstitutionMap::getProtocolSubstitutions(proto, astType, confRef); |
159 | | // %1 = metatype $T.Type |
160 | 2.34k | auto metatypeType = |
161 | 2.34k | CanMetatypeType::get(astType, MetatypeRepresentation::Thick); |
162 | 2.34k | auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType); |
163 | 2.34k | auto metatype = createMetatype(loc, metatypeSILType); |
164 | | // %2 = apply %0(%result, %new, %old, %1) |
165 | 2.34k | createApply(loc, witnessMethod, subMap, |
166 | 2.34k | {destinationBuffer, rhsAddress, lhsAddress, metatype}); |
167 | 2.34k | emitDestroyValueOperation(loc, witnessMethod); |
168 | 2.34k | } |
169 | | |
170 | 2.34k | SILValue TangentBuilder::emitAdd(SILLocation loc, SILValue lhs, SILValue rhs) { |
171 | 2.34k | LLVM_DEBUG(getADDebugStream() << "Emitting adjoint accumulation for lhs: " |
172 | 2.34k | << lhs << " and rhs: " << rhs); |
173 | 2.34k | assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!"); |
174 | 0 | assert(lhs->getType().isObject() && rhs->getType().isObject() && |
175 | 2.34k | "Adjoint types must be both object types!"); |
176 | 0 | auto type = lhs->getType(); |
177 | 2.34k | auto lhsCopy = emitCopyValueOperation(loc, lhs); |
178 | 2.34k | auto rhsCopy = emitCopyValueOperation(loc, rhs); |
179 | | // Allocate buffers for inputs and output. |
180 | 2.34k | auto tempAllocLoc = RegularLocation::getAutoGeneratedLocation(); |
181 | 2.34k | auto *resultBuf = createAllocStack(tempAllocLoc, type); |
182 | 2.34k | auto *lhsBuf = createAllocStack(tempAllocLoc, type); |
183 | 2.34k | auto *rhsBuf = createAllocStack(tempAllocLoc, type); |
184 | | // Initialize input buffers. |
185 | 2.34k | emitStoreValueOperation(loc, lhsCopy, lhsBuf, |
186 | 2.34k | StoreOwnershipQualifier::Init); |
187 | 2.34k | emitStoreValueOperation(loc, rhsCopy, rhsBuf, |
188 | 2.34k | StoreOwnershipQualifier::Init); |
189 | 2.34k | emitAddIntoBuffer(loc, resultBuf, lhsBuf, rhsBuf); |
190 | 2.34k | emitDestroyAddr(loc, lhsBuf); |
191 | 2.34k | emitDestroyAddr(loc, rhsBuf); |
192 | | // Deallocate input buffers. |
193 | 2.34k | createDeallocStack(loc, rhsBuf); |
194 | 2.34k | createDeallocStack(loc, lhsBuf); |
195 | 2.34k | auto val = emitLoadValueOperation(loc, resultBuf, |
196 | 2.34k | LoadOwnershipQualifier::Take); |
197 | | // Deallocate result buffer. |
198 | 2.34k | createDeallocStack(loc, resultBuf); |
199 | 2.34k | return val; |
200 | 2.34k | } |
201 | | |
202 | | } // end namespace autodiff |
203 | | } // end namespace swift |