Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/TangentBuilder.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- TangentBuilder.cpp - Tangent SIL builder ------------*- C++ -*----===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// This file defines a helper class for emitting tangent code for automatic
14
// differentiation.
15
//
16
//===----------------------------------------------------------------------===//
17
18
#define DEBUG_TYPE "differentiation"
19
20
#include "swift/SILOptimizer/Differentiation/TangentBuilder.h"
21
#include "swift/SILOptimizer/Differentiation/ADContext.h"
22
23
namespace swift {
24
namespace autodiff {
25
26
void TangentBuilder::emitZeroIntoBuffer(SILLocation loc, SILValue buffer,
27
20.0k
                                        IsInitialization_t isInit) {
28
20.0k
  if (!isInit)
29
2.36k
    emitDestroyAddr(loc, buffer);
30
20.0k
  if (auto tupleType = buffer->getType().getAs<TupleType>()) {
31
880
    for (unsigned i : range(tupleType->getNumElements())) {
32
880
      auto *eltAddr = createTupleElementAddr(loc, buffer, i);
33
880
      emitZeroIntoBuffer(loc, eltAddr, IsInitialization);
34
880
    }
35
440
    return;
36
440
  }
37
19.5k
  auto *swiftMod = getModule().getSwiftModule();
38
  // Look up conformance to `AdditiveArithmetic`.
39
19.5k
  auto *additiveArithmeticProto = adContext.getAdditiveArithmeticProtocol();
40
19.5k
  auto astType = buffer->getType().getASTType();
41
19.5k
  auto confRef = swiftMod->lookupConformance(astType, additiveArithmeticProto);
42
19.5k
  assert(!confRef.isInvalid() && "Missing conformance to `AdditiveArithmetic`");
43
0
  SILDeclRef accessorDeclRef(adContext.getAdditiveArithmeticZeroGetter(),
44
19.5k
                             SILDeclRef::Kind::Func);
45
19.5k
  auto silFnType = getModule().Types.getConstantType(
46
19.5k
      getTypeExpansionContext(), accessorDeclRef);
47
  // %wm = witness_method ...
48
19.5k
  auto *getter = createWitnessMethod(
49
19.5k
      loc, astType, confRef, accessorDeclRef, silFnType);
50
  // %metatype = metatype $T
51
19.5k
  auto metatypeType = CanMetatypeType::get(astType,
52
19.5k
                                           MetatypeRepresentation::Thick);
53
19.5k
  auto metatype = createMetatype(
54
19.5k
      loc, SILType::getPrimitiveObjectType(metatypeType));
55
19.5k
  auto subMap = SubstitutionMap::getProtocolSubstitutions(
56
19.5k
      additiveArithmeticProto, astType, confRef);
57
19.5k
  createApply(loc, getter, subMap, {buffer, metatype});
58
19.5k
  emitDestroyValueOperation(loc, getter);
59
19.5k
}
60
61
2.44k
SILValue TangentBuilder::emitZero(SILLocation loc, CanType type) {
62
2.44k
  auto silType = getModule().Types.getLoweredLoadableType(
63
2.44k
      type, TypeExpansionContext::minimal(), getModule());
64
2.44k
  auto tempAllocLoc = RegularLocation::getAutoGeneratedLocation();
65
2.44k
  auto *alloc = createAllocStack(tempAllocLoc, silType);
66
2.44k
  emitZeroIntoBuffer(loc, alloc, IsInitialization);
67
2.44k
  auto zeroValue = emitLoadValueOperation(
68
2.44k
      loc, alloc, LoadOwnershipQualifier::Take);
69
2.44k
  createDeallocStack(loc, alloc);
70
2.44k
  return zeroValue;
71
2.44k
}
72
73
void TangentBuilder::emitInPlaceAdd(
74
8.21k
    SILLocation loc, SILValue destinationBuffer, SILValue operand) {
75
8.21k
  assert(destinationBuffer->getType().isAddress());
76
0
  auto type = destinationBuffer->getType();
77
8.21k
  if (auto tupleType = type.getAs<TupleType>()) {
78
40
    for (unsigned i : range(tupleType->getNumElements())) {
79
40
      auto *eltDestAddr = createTupleElementAddr(loc, destinationBuffer, i);
80
40
      switch (operand->getType().getCategory()) {
81
40
      case SILValueCategory::Address: {
82
40
        auto *eltOperand = createTupleElementAddr(loc, operand, i);
83
40
        emitInPlaceAdd(loc, eltDestAddr, eltOperand);
84
40
        break;
85
0
      }
86
0
      case SILValueCategory::Object: {
87
0
        auto borrowedOp = emitBeginBorrowOperation(loc, operand);
88
0
        auto eltOperand = emitTupleExtract(loc, borrowedOp, i);
89
0
        emitInPlaceAdd(loc, eltDestAddr, eltOperand);
90
0
        emitEndBorrowOperation(loc, borrowedOp);
91
0
        break;
92
0
      }
93
40
      }
94
40
    }
95
20
    return;
96
20
  }
97
  // Call the combiner function and return.
98
8.19k
  auto *swiftMod = getModule().getSwiftModule();
99
8.19k
  auto astType = type.getASTType();
100
8.19k
  auto confRef = swiftMod->lookupConformance(
101
8.19k
      astType, adContext.getAdditiveArithmeticProtocol());
102
8.19k
  assert(!confRef.isInvalid() &&
103
8.19k
         "Missing conformance to `AdditiveArithmetic`");
104
0
  SILDeclRef declRef(adContext.getPlusEqualDecl(), SILDeclRef::Kind::Func);
105
8.19k
  auto silFnTy = getModule().Types.getConstantType(
106
8.19k
      getTypeExpansionContext(), declRef);
107
  // %0 = witness_method @+=
108
8.19k
  auto witnessMethod =
109
8.19k
      createWitnessMethod(loc, astType, confRef, declRef, silFnTy);
110
8.19k
  auto subMap = SubstitutionMap::getProtocolSubstitutions(
111
8.19k
      adContext.getAdditiveArithmeticProtocol(), astType, confRef);
112
  // %1 = metatype $T.Type
113
8.19k
  auto metatypeType =
114
8.19k
      CanMetatypeType::get(astType, MetatypeRepresentation::Thick);
115
8.19k
  auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType);
116
8.19k
  auto metatype = createMetatype(loc, metatypeSILType);
117
  // %2 = apply $0(%lhs, %rhs, %1)
118
8.19k
  createApply(loc, witnessMethod, subMap,
119
8.19k
              {destinationBuffer, operand, metatype});
120
8.19k
  emitDestroyValueOperation(loc, witnessMethod);
121
8.19k
}
122
123
void TangentBuilder::emitAddIntoBuffer(SILLocation loc,
124
                                       SILValue destinationBuffer,
125
                                       SILValue lhsAddress,
126
2.34k
                                       SILValue rhsAddress) {
127
2.34k
  assert(lhsAddress->getType().getASTType() ==
128
2.34k
             rhsAddress->getType().getASTType() &&
129
2.34k
         "Adjoint values must have same type!");
130
0
  assert(lhsAddress->getType().isAddress() &&
131
2.34k
         rhsAddress->getType().isAddress() &&
132
2.34k
         "Adjoint values must both have address types!");
133
0
  auto type = lhsAddress->getType();
134
2.34k
  if (auto tupleType = type.getAs<TupleType>()) {
135
0
    for (unsigned i : range(tupleType->getNumElements())) {
136
0
      auto *destAddr = createTupleElementAddr(loc, destinationBuffer, i);
137
0
      auto *eltAddrLHS = createTupleElementAddr(loc, lhsAddress, i);
138
0
      auto *eltAddrRHS = createTupleElementAddr(loc, rhsAddress, i);
139
0
      emitAddIntoBuffer(loc, destAddr, eltAddrLHS, eltAddrRHS);
140
0
    }
141
0
    return;
142
0
  }
143
2.34k
  auto astType = type.getASTType();
144
2.34k
  auto *proto = adContext.getAdditiveArithmeticProtocol();
145
2.34k
  auto *combinerFuncDecl = adContext.getPlusDecl();
146
  // Call the combiner function and return.
147
2.34k
  auto *swiftMod = getModule().getSwiftModule();
148
2.34k
  auto confRef = swiftMod->lookupConformance(astType, proto);
149
2.34k
  assert(!confRef.isInvalid() &&
150
2.34k
         "Missing conformance to `AdditiveArithmetic`");
151
0
  SILDeclRef declRef(combinerFuncDecl, SILDeclRef::Kind::Func);
152
2.34k
  auto silFnTy = getModule().Types.getConstantType(
153
2.34k
      getTypeExpansionContext(), declRef);
154
  // %0 = witness_method @+
155
2.34k
  auto witnessMethod =
156
2.34k
      createWitnessMethod(loc, astType, confRef, declRef, silFnTy);
157
2.34k
  auto subMap =
158
2.34k
      SubstitutionMap::getProtocolSubstitutions(proto, astType, confRef);
159
  // %1 = metatype $T.Type
160
2.34k
  auto metatypeType =
161
2.34k
      CanMetatypeType::get(astType, MetatypeRepresentation::Thick);
162
2.34k
  auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType);
163
2.34k
  auto metatype = createMetatype(loc, metatypeSILType);
164
  // %2 = apply %0(%result, %new, %old, %1)
165
2.34k
  createApply(loc, witnessMethod, subMap,
166
2.34k
              {destinationBuffer, rhsAddress, lhsAddress, metatype});
167
2.34k
  emitDestroyValueOperation(loc, witnessMethod);
168
2.34k
}
169
170
2.34k
SILValue TangentBuilder::emitAdd(SILLocation loc, SILValue lhs, SILValue rhs) {
171
2.34k
  LLVM_DEBUG(getADDebugStream() << "Emitting adjoint accumulation for lhs: "
172
2.34k
                                << lhs << " and rhs: " << rhs);
173
2.34k
  assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!");
174
0
  assert(lhs->getType().isObject() && rhs->getType().isObject() &&
175
2.34k
         "Adjoint types must be both object types!");
176
0
  auto type = lhs->getType();
177
2.34k
  auto lhsCopy = emitCopyValueOperation(loc, lhs);
178
2.34k
  auto rhsCopy = emitCopyValueOperation(loc, rhs);
179
  // Allocate buffers for inputs and output.
180
2.34k
  auto tempAllocLoc = RegularLocation::getAutoGeneratedLocation();
181
2.34k
  auto *resultBuf = createAllocStack(tempAllocLoc, type);
182
2.34k
  auto *lhsBuf = createAllocStack(tempAllocLoc, type);
183
2.34k
  auto *rhsBuf = createAllocStack(tempAllocLoc, type);
184
  // Initialize input buffers.
185
2.34k
  emitStoreValueOperation(loc, lhsCopy, lhsBuf,
186
2.34k
                          StoreOwnershipQualifier::Init);
187
2.34k
  emitStoreValueOperation(loc, rhsCopy, rhsBuf,
188
2.34k
                          StoreOwnershipQualifier::Init);
189
2.34k
  emitAddIntoBuffer(loc, resultBuf, lhsBuf, rhsBuf);
190
2.34k
  emitDestroyAddr(loc, lhsBuf);
191
2.34k
  emitDestroyAddr(loc, rhsBuf);
192
  // Deallocate input buffers.
193
2.34k
  createDeallocStack(loc, rhsBuf);
194
2.34k
  createDeallocStack(loc, lhsBuf);
195
2.34k
  auto val = emitLoadValueOperation(loc, resultBuf,
196
2.34k
                                    LoadOwnershipQualifier::Take);
197
  // Deallocate result buffer.
198
2.34k
  createDeallocStack(loc, resultBuf);
199
2.34k
  return val;
200
2.34k
}
201
202
} // end namespace autodiff
203
} // end namespace swift