Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- PullbackCloner.cpp - Pullback function generation ---*- C++ -*----===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// This file defines a helper class for generating pullback functions for
14
// automatic differentiation.
15
//
16
//===----------------------------------------------------------------------===//
17
18
#include "swift/Basic/STLExtras.h"
19
#define DEBUG_TYPE "differentiation"
20
21
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
22
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
23
#include "swift/SILOptimizer/Differentiation/ADContext.h"
24
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
25
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
26
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
27
#include "swift/SILOptimizer/Differentiation/Thunk.h"
28
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
29
30
#include "swift/AST/Expr.h"
31
#include "swift/AST/PropertyWrappers.h"
32
#include "swift/AST/TypeCheckRequests.h"
33
#include "swift/SIL/InstructionUtils.h"
34
#include "swift/SIL/Projection.h"
35
#include "swift/SIL/TypeSubstCloner.h"
36
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
37
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
38
#include "llvm/ADT/DenseMap.h"
39
#include "llvm/ADT/SmallSet.h"
40
41
namespace swift {
42
43
class SILDifferentiabilityWitness;
44
class SILBasicBlock;
45
class SILFunction;
46
class SILInstruction;
47
48
namespace autodiff {
49
50
class ADContext;
51
class VJPCloner;
52
53
/// The implementation class for `PullbackCloner`.
54
///
55
/// The implementation class is a `SILInstructionVisitor`. Effectively, it acts
56
/// as a `SILCloner` that visits basic blocks in post-order and that visits
57
/// instructions per basic block in reverse order. This visitation order is
58
/// necessary for generating pullback functions, whose control flow graph is
59
/// ~a transposed version of the original function's control flow graph.
60
class PullbackCloner::Implementation final
61
    : public SILInstructionVisitor<PullbackCloner::Implementation> {
62
63
public:
  explicit Implementation(VJPCloner &vjpCloner);

private:
  /// The parent VJP cloner.
  VJPCloner &vjpCloner;

  /// Dominance info for the original function.
  DominanceInfo *domInfo = nullptr;

  /// Post-dominance info for the original function.
  PostDominanceInfo *postDomInfo = nullptr;

  /// Post-order info for the original function.
  PostOrderFunctionInfo *postOrderInfo = nullptr;

  /// Mapping from original basic blocks to corresponding pullback basic blocks.
  /// Pullback basic blocks always have the predecessor as the single argument.
  /// NOTE(review): `initializePullbackTupleElements` also accepts an array of
  /// block arguments, so "single argument" may be stale -- confirm.
  llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> pullbackBBMap;

  /// Mapping from original basic blocks and original values to corresponding
  /// adjoint values (used for values whose tangent category is "object").
  llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, AdjointValue> valueMap;

  /// Mapping from original basic blocks and original values to corresponding
  /// adjoint buffers (used for values whose tangent category is "address").
  llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap;

  /// Mapping from original basic blocks to the pullback (linear map) tuple
  /// elements available in the corresponding pullback block. The elements are
  /// recorded once per block by `initializePullbackTupleElements`. For a
  /// non-entry block, element 0 is the predecessor element; the remaining
  /// elements are indexed by `LinearMapInfo::lookUpLinearMapIndex`.
  llvm::DenseMap<SILBasicBlock*, SmallVector<SILValue, 4>> pullbackTupleElements;

  /// Mapping from original basic blocks and successor basic blocks to
  /// corresponding pullback trampoline basic blocks. Trampoline basic blocks
  /// take additional arguments in addition to the predecessor enum argument.
  llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, SILBasicBlock *>
      pullbackTrampolineBBMap;

  /// Mapping from original basic blocks to dominated active values.
  llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues;

  /// Mapping from original basic blocks and original active values to
  /// corresponding pullback block arguments.
  llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *>
      activeValuePullbackBBArgumentMap;

  /// Mapping from *pullback* basic blocks to the temporary values created in
  /// them that must be destroyed before the block's terminator. Populated by
  /// `recordTemporary` and drained by `cleanUpTemporariesForBlock`.
  llvm::DenseMap<SILBasicBlock *, llvm::SmallSetVector<SILValue, 32>>
      blockTemporaries;

  /// The scope cloner.
  ScopeCloner scopeCloner;

  /// The main builder.
  TangentBuilder builder;

  /// An auxiliary local allocation builder.
  TangentBuilder localAllocBuilder;

  /// The original function's exit block.
  SILBasicBlock *originalExitBlock = nullptr;

  /// Stack buffers allocated for storing local adjoint values.
  SmallVector<AllocStackInst *, 64> functionLocalAllocations;

  /// A set used to remember local allocations that were destroyed.
  llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;

  /// The seed arguments of the pullback function.
  SmallVector<SILArgument *, 4> seeds;

  /// The `AutoDiffLinearMapContext` object, if any.
  SILValue contextValue = nullptr;

  /// Arena passed to the `AdjointValue::create*` factories by the
  /// `make*AdjointValue` helpers.
  llvm::BumpPtrAllocator allocator;

  /// Error flag; presumably set when pullback generation fails -- confirm
  /// against the out-of-line definitions.
  bool errorOccurred = false;
144
145
185k
  ADContext &getContext() const { return vjpCloner.getContext(); }
146
153k
  SILModule &getModule() const { return getContext().getModule(); }
147
7.58k
  ASTContext &getASTContext() const { return getPullback().getASTContext(); }
148
175k
  SILFunction &getOriginal() const { return vjpCloner.getOriginal(); }
149
144k
  SILDifferentiabilityWitness *getWitness() const {
150
144k
    return vjpCloner.getWitness();
151
144k
  }
152
3.81k
  DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); }
153
134k
  LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); }
154
173k
  const AutoDiffConfig &getConfig() const { return vjpCloner.getConfig(); }
155
152k
  const DifferentiableActivityInfo &getActivityInfo() const {
156
152k
    return vjpCloner.getActivityInfo();
157
152k
  }
158
159
  //--------------------------------------------------------------------------//
160
  // Pullback struct mapping
161
  //--------------------------------------------------------------------------//
162
163
  void initializePullbackTupleElements(SILBasicBlock *origBB,
164
1.85k
                                       SILInstructionResultArray values) {
165
1.85k
    auto *pbTupleTyple = getPullbackInfo().getLinearMapTupleType(origBB);
166
1.85k
    assert(pbTupleTyple->getNumElements() == values.size() &&
167
1.85k
           "The number of pullback tuple fields must equal the number of "
168
1.85k
           "pullback tuple element values");
169
0
    auto res = pullbackTupleElements.insert({origBB, { values.begin(), values.end() }});
170
1.85k
    (void)res;
171
1.85k
    assert(res.second && "A pullback tuple element already exists!");
172
1.85k
  }
173
174
  void initializePullbackTupleElements(SILBasicBlock *origBB,
175
4.94k
                                       const llvm::ArrayRef<SILArgument *> &values) {
176
4.94k
    auto *pbTupleTyple = getPullbackInfo().getLinearMapTupleType(origBB);
177
4.94k
    assert(pbTupleTyple->getNumElements() == values.size() &&
178
4.94k
           "The number of pullback tuple fields must equal the number of "
179
4.94k
           "pullback tuple element values");
180
0
    auto res = pullbackTupleElements.insert({origBB, { values.begin(), values.end() }});
181
4.94k
    (void)res;
182
4.94k
    assert(res.second && "A pullback struct element already exists!");
183
4.94k
  }
184
185
  /// Returns the pullback tuple element value corresponding to the given
186
  /// original block and apply inst.
187
5.65k
  SILValue getPullbackTupleElement(ApplyInst *ai) {
188
5.65k
    unsigned idx = getPullbackInfo().lookUpLinearMapIndex(ai);
189
5.65k
    assert((idx > 0 || (idx == 0 && ai->getParentBlock()->isEntry())) &&
190
5.65k
           "impossible linear map index");
191
0
    auto values = pullbackTupleElements.lookup(ai->getParentBlock());
192
5.65k
    assert(idx < values.size() &&
193
5.65k
           "pullback tuple element for this apply does not exist!");
194
0
    return values[idx];
195
5.65k
  }
196
197
  /// Returns the pullback tuple element value corresponding to the predecessor
198
  /// for the given original block.
199
1.75k
  SILValue getPullbackPredTupleElement(SILBasicBlock *origBB) {
200
1.75k
    assert(!origBB->isEntry() && "no predecessors for entry block");
201
0
    auto values = pullbackTupleElements.lookup(origBB);
202
1.75k
    assert(values.size() && "pullback tuple cannot be empty");
203
0
    return values[0];
204
1.75k
  }
205
206
  //--------------------------------------------------------------------------//
207
  // Type transformer
208
  //--------------------------------------------------------------------------//
209
210
  /// Get the type lowering for the given AST type.
211
61.2k
  const Lowering::TypeLowering &getTypeLowering(Type type) {
212
61.2k
    auto pbGenSig =
213
61.2k
        getPullback().getLoweredFunctionType()->getSubstGenericSignature();
214
61.2k
    Lowering::AbstractionPattern pattern(pbGenSig,
215
61.2k
                                         type->getReducedType(pbGenSig));
216
61.2k
    return getPullback().getTypeLowering(pattern, type);
217
61.2k
  }
218
219
  /// Remap any archetypes into the current function's context.
220
191k
  SILType remapType(SILType ty) {
221
191k
    if (ty.hasArchetype())
222
15.1k
      ty = ty.mapTypeOutOfContext();
223
191k
    auto remappedType = ty.getASTType()->getReducedType(
224
191k
        getPullback().getLoweredFunctionType()->getSubstGenericSignature());
225
191k
    auto remappedSILType =
226
191k
        SILType::getPrimitiveType(remappedType, ty.getCategory());
227
191k
    return getPullback().mapTypeIntoContext(remappedSILType);
228
191k
  }
229
230
144k
  /// Returns the autodiff tangent space of `type`, or none if it has none.
  /// Conformances are looked up in the module's Swift module.
  llvm::Optional<TangentSpace> getTangentSpace(CanType type) {
    // Use witness generic signature to remap types.
    auto derivativeSig = getWitness()->getDerivativeGenericSignature();
    auto reducedType = derivativeSig.getReducedType(type);
    LookUpConformanceInModule lookupConformance(getModule().getSwiftModule());
    return reducedType->getAutoDiffTangentSpace(lookupConformance);
  }
238
239
  /// Returns the tangent value category of the given value.
240
124k
  SILValueCategory getTangentValueCategory(SILValue v) {
241
    // Tangent value category table:
242
    //
243
    // Let $L be a loadable type and $*A be an address-only type.
244
    //
245
    // Original type | Tangent type loadable? | Tangent value category and type
246
    // --------------|------------------------|--------------------------------
247
    // $L            | loadable               | object, $L' (no mismatch)
248
    // $*A           | loadable               | address, $*L' (create a buffer)
249
    // $L            | address-only           | address, $*A' (no alternative)
250
    // $*A           | address-only           | address, $*A' (no alternative)
251
252
    // TODO(https://github.com/apple/swift/issues/55523): Make "tangent value category" depend solely on whether the tangent type is loadable or address-only.
253
    //
254
    // For loadable tangent types, using symbolic adjoint values instead of
255
    // concrete adjoint buffers is more efficient.
256
257
    // Quick check: if the value has an address type, the tangent value category
258
    // is currently always "address".
259
124k
    if (v->getType().isAddress())
260
63.3k
      return SILValueCategory::Address;
261
    // If the value has an object type and the tangent type is not address-only,
262
    // then the tangent value category is "object".
263
61.1k
    auto tanSpace = getTangentSpace(remapType(v->getType()).getASTType());
264
61.1k
    auto tanASTType = tanSpace->getCanonicalType();
265
61.1k
    if (v->getType().isObject() && getTypeLowering(tanASTType).isLoadable())
266
58.7k
      return SILValueCategory::Object;
267
    // Otherwise, the tangent value category is "address".
268
2.44k
    return SILValueCategory::Address;
269
61.1k
  }
270
271
  /// Assuming the given type conforms to `Differentiable` after remapping,
272
  /// returns the associated tangent space type.
273
58.4k
  SILType getRemappedTangentType(SILType type) {
274
58.4k
    return SILType::getPrimitiveType(
275
58.4k
        getTangentSpace(remapType(type).getASTType())->getCanonicalType(),
276
58.4k
        type.getCategory());
277
58.4k
  }
278
279
  /// Substitutes all replacement types of the given substitution map using the
280
  /// pullback function's substitution map.
281
1.46k
  SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) {
282
1.46k
    return substMap.subst(getPullback().getForwardingSubstitutionMap());
283
1.46k
  }
284
285
  //--------------------------------------------------------------------------//
286
  // Temporary value management
287
  //--------------------------------------------------------------------------//
288
289
  /// Record a temporary value for cleanup before its block's terminator.
290
18.0k
  SILValue recordTemporary(SILValue value) {
291
18.0k
    assert(value->getType().isObject());
292
0
    assert(value->getFunction() == &getPullback());
293
0
    auto inserted = blockTemporaries[value->getParentBlock()].insert(value);
294
18.0k
    (void)inserted;
295
18.0k
    LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value);
296
18.0k
    assert(inserted && "Temporary already recorded?");
297
0
    return value;
298
18.0k
  }
299
300
  /// Clean up all temporary values for the given pullback block.
301
6.74k
  void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) {
302
6.74k
    assert(bb->getParent() == &getPullback());
303
6.74k
    LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb"
304
6.74k
                                  << bb->getDebugID() << '\n');
305
6.74k
    for (auto temp : blockTemporaries[bb])
306
18.4k
      builder.emitDestroyValueOperation(loc, temp);
307
6.74k
    blockTemporaries[bb].clear();
308
6.74k
  }
309
310
  //--------------------------------------------------------------------------//
311
  // Adjoint value factory methods
312
  //--------------------------------------------------------------------------//
313
314
19.1k
  /// Creates a symbolic zero adjoint of `type` (remapped into the pullback's
  /// context), stored in `allocator`.
  AdjointValue makeZeroAdjointValue(SILType type) {
    SILType zeroType = remapType(type);
    return AdjointValue::createZero(allocator, zeroType);
  }
317
318
23.1k
  /// Wraps an already-materialized pullback value as a concrete adjoint,
  /// stored in `allocator`. Unlike the zero/aggregate factories, the type is
  /// not remapped.
  AdjointValue makeConcreteAdjointValue(SILValue value) {
    auto concreteAdjoint = AdjointValue::createConcrete(allocator, value);
    return concreteAdjoint;
  }
321
322
  /// Creates a symbolic aggregate (tuple or struct) adjoint of `type`
  /// (remapped into the pullback's context) from per-element adjoints.
  AdjointValue makeAggregateAdjointValue(SILType type,
                                         ArrayRef<AdjointValue> elements) {
    SILType aggregateType = remapType(type);
    return AdjointValue::createAggregate(allocator, aggregateType, elements);
  }
326
327
  /// Creates a symbolic "add element" adjoint. It represents `baseAdjoint`
  /// with `eltToAdd` to be added to the field identified by `fieldLocator`.
  /// The result has the type of `baseAdjoint`.
  AdjointValue makeAddElementAdjointValue(AdjointValue baseAdjoint,
                                          AdjointValue eltToAdd,
                                          FieldLocator fieldLocator) {
    // Allocate the payload in `allocator`, like the other adjoint-value
    // storage. It used to come from a plain `new` with no matching `delete`
    // (nothing here frees it), so every `AddElementValue` leaked. The arena
    // does not run destructors, which matches the old behavior (the object
    // was never destroyed).
    auto *storage = allocator.Allocate<AddElementValue>();
    auto *addElementValue =
        new (storage) AddElementValue(baseAdjoint, eltToAdd, fieldLocator);
    return AdjointValue::createAddElement(allocator, baseAdjoint.getType(),
                                          addElementValue);
  }
335
336
  //--------------------------------------------------------------------------//
337
  // Adjoint value materialization
338
  //--------------------------------------------------------------------------//
339
340
  /// Materializes an adjoint value. The type of the given adjoint value must be
  /// loadable.
  ///
  /// Emits instructions with `builder` and returns the resulting object value:
  /// - Zero / Aggregate / AddElement: a new value that is registered with
  ///   `recordTemporary`, so it is destroyed by `cleanUpTemporariesForBlock`.
  /// - Concrete: the stored value itself, returned as-is and not recorded.
  /// If `val` carries debug info, a `debug_value` is emitted for the result.
  SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) {
    assert(val.getType().isObject());
    LLVM_DEBUG(getADDebugStream()
               << "Materializing adjoint for " << val << '\n');
    SILValue result;
    switch (val.getKind()) {
    case AdjointValueKind::Zero:
      // Emit a zero of the adjoint's Swift type.
      result = recordTemporary(builder.emitZero(loc, val.getSwiftType()));
      break;
    case AdjointValueKind::Aggregate: {
      // Materialize each element recursively and copy it. The copy
      // presumably gives the tuple/struct construction an owned operand,
      // while the materialized element stays owned by its own
      // temporary/concrete source -- confirm.
      SmallVector<SILValue, 8> elements;
      for (auto i : range(val.getNumAggregateElements())) {
        auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
        elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
      }
      if (val.getType().is<TupleType>())
        result = recordTemporary(
            builder.createTuple(loc, val.getType(), elements));
      else
        result = recordTemporary(
            builder.createStruct(loc, val.getType(), elements));
      break;
    }
    case AdjointValueKind::Concrete:
      result = val.getConcreteValue();
      break;
    case AdjointValueKind::AddElement: {
      // Materialize into a temporary stack buffer (the in-place adds happen in
      // `materializeAdjointIndirect`), then take the value out and free the
      // buffer.
      auto adjointSILType = val.getAddElementValue()->baseAdjoint.getType();
      auto *baseAdjAlloc = builder.createAllocStack(loc, adjointSILType);
      materializeAdjointIndirect(val, baseAdjAlloc, loc);

      auto baseAdjConcrete = recordTemporary(builder.emitLoadValueOperation(
          loc, baseAdjAlloc, LoadOwnershipQualifier::Take));

      builder.createDeallocStack(loc, baseAdjAlloc);

      result = baseAdjConcrete;
      break;
    }
    }
    if (auto debugInfo = val.getDebugInfo())
      builder.createDebugValue(
          debugInfo->first.getLocation(), result, debugInfo->second);
    return result;
  }
387
388
  /// Materializes an adjoint value indirectly to a SIL buffer.
  ///
  /// `destAddress` is treated as uninitialized memory: zeros are emitted as an
  /// initialization and concrete values are stored with `Init`. On return the
  /// buffer holds the materialized adjoint.
  void materializeAdjointIndirect(AdjointValue val, SILValue destAddress,
                                  SILLocation loc) {
    assert(destAddress->getType().isAddress());
    switch (val.getKind()) {
    /// If adjoint value is a symbolic zero, emit a zero directly into the
    /// buffer via `emitZeroIntoBuffer`.
    case AdjointValueKind::Zero:
      builder.emitZeroIntoBuffer(loc, destAddress, IsInitialization);
      break;
    /// If adjoint value is a symbolic aggregate (tuple or struct), recursively
    /// materialize each element into the corresponding element address of the
    /// buffer.
    case AdjointValueKind::Aggregate: {
      if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
        for (auto idx : range(val.getNumAggregateElements())) {
          auto eltTy = SILType::getPrimitiveAddressType(
              tupTy->getElementType(idx)->getCanonicalType());
          auto *eltBuf =
              builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
          materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc);
        }
      } else if (auto *structDecl =
                     val.getSwiftType()->getStructOrBoundGenericStruct()) {
        // Aggregate element `i` corresponds to the i-th stored property.
        auto fieldIt = structDecl->getStoredProperties().begin();
        for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
             ++fieldIt, ++i) {
          auto eltBuf =
              builder.createStructElementAddr(loc, destAddress, *fieldIt);
          materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc);
        }
      } else {
        llvm_unreachable("Not an aggregate type");
      }
      break;
    }
    /// If adjoint value is concrete, it is already materialized. Store a copy
    /// of it in the destination address (the original value is left
    /// untouched).
    case AdjointValueKind::Concrete: {
      auto concreteVal = val.getConcreteValue();
      auto copyOfConcreteVal = builder.emitCopyValueOperation(loc, concreteVal);
      builder.emitStoreValueOperation(loc, copyOfConcreteVal, destAddress,
                                      StoreOwnershipQualifier::Init);
      break;
    }
    case AdjointValueKind::AddElement: {
      auto baseAdjoint = val;
      auto baseAdjointType = baseAdjoint.getType();

      // Current adjoint may be made up of layers of `AddElement` adjoints.
      // We can iteratively gather the list of elements to add instead of making
      // recursive calls to `materializeAdjointIndirect`. After the loop,
      // `baseAdjoint` is the innermost non-`AddElement` base, and
      // `addEltAdjValues` is ordered from the outermost layer inwards.
      SmallVector<AddElementValue *, 4> addEltAdjValues;

      do {
        auto addElementValue = baseAdjoint.getAddElementValue();
        addEltAdjValues.push_back(addElementValue);
        baseAdjoint = addElementValue->baseAdjoint;
        assert(baseAdjointType == baseAdjoint.getType());
      } while (baseAdjoint.getKind() == AdjointValueKind::AddElement);

      // Initialize the buffer with the innermost base adjoint first.
      materializeAdjointIndirect(baseAdjoint, destAddress, loc);

      // Then add each element in place into its field of the buffer, going
      // from the outermost layer inwards.
      for (auto *addElementValue : addEltAdjValues) {
        auto eltToAdd = addElementValue->eltToAdd;

        // All layers share the base type (asserted above), so the field
        // address kind is decided by the innermost base's type.
        SILValue baseAdjEltAddr;
        if (baseAdjoint.getType().is<TupleType>()) {
          baseAdjEltAddr = builder.createTupleElementAddr(
              loc, destAddress, addElementValue->getFieldIndex());
        } else {
          baseAdjEltAddr = builder.createStructElementAddr(
              loc, destAddress, addElementValue->getFieldDecl());
        }

        auto eltToAddMaterialized = materializeAdjointDirect(eltToAdd, loc);
        // Copy `eltToAddMaterialized` so we have a value with owned ownership
        // semantics, required for using `eltToAddMaterialized` in a `store`
        // instruction.
        auto eltToAddMaterializedCopy =
            builder.emitCopyValueOperation(loc, eltToAddMaterialized);
        // `emitInPlaceAdd` takes the addend by address, so spill it to a
        // temporary stack slot that is destroyed and freed right after.
        auto *eltToAddAlloc = builder.createAllocStack(loc, eltToAdd.getType());
        builder.emitStoreValueOperation(loc, eltToAddMaterializedCopy,
                                        eltToAddAlloc,
                                        StoreOwnershipQualifier::Init);

        builder.emitInPlaceAdd(loc, baseAdjEltAddr, eltToAddAlloc);
        builder.createDestroyAddr(loc, eltToAddAlloc);
        builder.createDeallocStack(loc, eltToAddAlloc);
      }

      break;
    }
    }
  }
483
484
  //--------------------------------------------------------------------------//
485
  // Adjoint value mapping
486
  //--------------------------------------------------------------------------//
487
488
  /// Returns true if the given value in the original function has a
489
  /// corresponding adjoint value.
490
5.48k
  bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const {
491
5.48k
    assert(origBB->getParent() == &getOriginal());
492
0
    assert(originalValue->getType().isObject());
493
0
    return valueMap.count({origBB, originalValue});
494
5.48k
  }
495
496
  /// Initializes the adjoint value for the original value. Asserts that the
497
  /// original value does not already have an adjoint value.
498
  void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
499
11.4k
                       AdjointValue adjointValue) {
500
11.4k
    LLVM_DEBUG(getADDebugStream()
501
11.4k
               << "Setting adjoint value for " << originalValue);
502
11.4k
    assert(origBB->getParent() == &getOriginal());
503
0
    assert(originalValue->getType().isObject());
504
0
    assert(getTangentValueCategory(originalValue) == SILValueCategory::Object);
505
0
    assert(adjointValue.getType().isObject());
506
0
    assert(originalValue->getFunction() == &getOriginal());
507
    // The adjoint value must be in the tangent space.
508
0
    assert(adjointValue.getType() ==
509
11.4k
           getRemappedTangentType(originalValue->getType()));
510
    // Try to assign a debug variable.
511
11.4k
    if (auto debugInfo = findDebugLocationAndVariable(originalValue)) {
512
5.53k
      LLVM_DEBUG({
513
5.53k
        auto &s = getADDebugStream();
514
5.53k
        s << "Found debug variable: \"" << debugInfo->second.Name
515
5.53k
          << "\"\nLocation: ";
516
5.53k
        debugInfo->first.getLocation().print(s, getASTContext().SourceMgr);
517
5.53k
        s << '\n';
518
5.53k
      });
519
5.53k
      adjointValue.setDebugInfo(*debugInfo);
520
5.91k
    } else {
521
5.91k
      LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n");
522
5.91k
    }
523
    // Insert into dictionary.
524
11.4k
    auto insertion =
525
11.4k
        valueMap.try_emplace({origBB, originalValue}, adjointValue);
526
11.4k
    LLVM_DEBUG(getADDebugStream()
527
11.4k
               << "The new adjoint value, replacing the existing one, is: "
528
11.4k
               << insertion.first->getSecond() << '\n');
529
11.4k
    if (!insertion.second)
530
4.49k
      insertion.first->getSecond() = adjointValue;
531
11.4k
  }
532
533
  /// Returns the adjoint value for a value in the original function.
534
  ///
535
  /// This method first tries to find an existing entry in the adjoint value
536
  /// mapping. If no entry exists, creates a zero adjoint value.
537
18.5k
  AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) {
538
18.5k
    assert(origBB->getParent() == &getOriginal());
539
0
    assert(originalValue->getType().isObject());
540
0
    assert(getTangentValueCategory(originalValue) == SILValueCategory::Object);
541
0
    assert(originalValue->getFunction() == &getOriginal());
542
0
    auto insertion = valueMap.try_emplace(
543
18.5k
        {origBB, originalValue},
544
18.5k
        makeZeroAdjointValue(getRemappedTangentType(originalValue->getType())));
545
18.5k
    auto it = insertion.first;
546
18.5k
    return it->getSecond();
547
18.5k
  }
548
549
  /// Adds `newAdjointValue` to the adjoint value for `originalValue` and sets
  /// the sum as the new adjoint value.
  ///
  /// If no adjoint value exists yet for (`origBB`, `originalValue`),
  /// `newAdjointValue` is stored as-is. Otherwise the existing and new values
  /// are combined with `accumulateAdjointsDirect` (using `loc` for any emitted
  /// instructions), and the sum replaces the entry via `setAdjointValue`.
  void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
                       AdjointValue newAdjointValue, SILLocation loc) {
    assert(origBB->getParent() == &getOriginal());
    assert(originalValue->getType().isObject());
    assert(newAdjointValue.getType().isObject());
    assert(originalValue->getFunction() == &getOriginal());
    LLVM_DEBUG(getADDebugStream()
               << "Adding adjoint value for " << originalValue);
    // The adjoint value must be in the tangent space.
    assert(newAdjointValue.getType() ==
           getRemappedTangentType(originalValue->getType()));
    // Try to assign a debug variable.
    if (auto debugInfo = findDebugLocationAndVariable(originalValue)) {
      LLVM_DEBUG({
        auto &s = getADDebugStream();
        s << "Found debug variable: \"" << debugInfo->second.Name
          << "\"\nLocation: ";
        debugInfo->first.getLocation().print(s, getASTContext().SourceMgr);
        s << '\n';
      });
      newAdjointValue.setDebugInfo(*debugInfo);
    } else {
      LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n");
    }
    auto insertion =
        valueMap.try_emplace({origBB, originalValue}, newAdjointValue);
    auto inserted = insertion.second;
    if (inserted)
      return;
    // If adjoint already exists, accumulate the adjoint onto the existing
    // adjoint.
    auto it = insertion.first;
    auto existingValue = it->getSecond();
    // NOTE(review): the entry is erased before accumulating and re-added by
    // `setAdjointValue` below. While the helpers below run, the value has no
    // entry in `valueMap`, so a `getAdjointValue` call from them would see a
    // fresh zero. Confirm this is intended.
    valueMap.erase(it);
    auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue, loc);
    // If the original value is the `Array` result of an
    // `array.uninitialized_intrinsic` application, accumulate adjoint buffers
    // for the array element addresses.
    accumulateArrayLiteralElementAddressAdjoints(origBB, originalValue, adjVal,
                                                 loc);
    setAdjointValue(origBB, originalValue, adjVal);
  }
593
594
  /// Get the pullback block argument corresponding to the given original block
595
  /// and active value.
596
  SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB,
597
3.81k
                                                   SILValue activeValue) {
598
3.81k
    assert(getTangentValueCategory(activeValue) == SILValueCategory::Object);
599
0
    assert(origBB->getParent() == &getOriginal());
600
0
    auto pullbackBBArg =
601
3.81k
        activeValuePullbackBBArgumentMap[{origBB, activeValue}];
602
3.81k
    assert(pullbackBBArg);
603
0
    assert(pullbackBBArg->getParent() == getPullbackBlock(origBB));
604
0
    return pullbackBBArg;
605
3.81k
  }
606
607
  //--------------------------------------------------------------------------//
608
  // Adjoint value accumulation
609
  //--------------------------------------------------------------------------//
610
611
  /// Given two adjoint values, accumulates them and returns their sum.
  ///
  /// Defined out of line. `addAdjointValue` uses it to combine an existing
  /// adjoint with a newly added one; `loc` is the location for any emitted
  /// instructions.
  AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs,
                                        SILLocation loc);
614
615
  //--------------------------------------------------------------------------//
616
  // Adjoint buffer mapping
617
  //--------------------------------------------------------------------------//
618
619
  /// If the given original value is an address projection, returns a
  /// corresponding adjoint projection to be used as its adjoint buffer.
  /// Otherwise returns a null `SILValue` (callers test the result for
  /// truthiness).
  ///
  /// Helper function for `getAdjointBuffer`. Defined out of line.
  SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue);
624
625
  /// Returns the adjoint buffer for the original value.
626
  ///
627
  /// This method first tries to find an existing entry in the adjoint buffer
628
  /// mapping. If no entry exists, creates a zero adjoint buffer.
629
40.5k
  SILValue getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) {
630
40.5k
    assert(getTangentValueCategory(originalValue) == SILValueCategory::Address);
631
0
    assert(originalValue->getFunction() == &getOriginal());
632
0
    auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue());
633
40.5k
    if (!insertion.second) // not inserted
634
23.7k
      return insertion.first->getSecond();
635
636
    // If the original buffer is a projection, return a corresponding projection
637
    // into the adjoint buffer.
638
16.7k
    if (auto adjProj = getAdjointProjection(origBB, originalValue))
639
6.65k
      return (bufferMap[{origBB, originalValue}] = adjProj);
640
641
10.0k
    LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for "
642
10.0k
               << originalValue
643
10.0k
               << "in bb" << origBB->getDebugID() << '\n');
644
645
10.0k
    auto bufType = getRemappedTangentType(originalValue->getType());
646
    // Set insertion point for local allocation builder: before the last local
647
    // allocation, or at the start of the pullback function's entry if no local
648
    // allocations exist yet.
649
10.0k
    auto debugInfo = findDebugLocationAndVariable(originalValue);
650
10.0k
    SILLocation loc = debugInfo ? debugInfo->first.getLocation()
651
10.0k
                                : RegularLocation::getAutoGeneratedLocation();
652
10.0k
    llvm::SmallString<32> adjName;
653
10.0k
    auto *newBuf = createFunctionLocalAllocation(
654
10.0k
        bufType, loc, /*zeroInitialize*/ true,
655
10.0k
        swift::transform(debugInfo,
656
10.0k
          [&](AdjointValue::DebugInfo di) {
657
5.32k
            llvm::raw_svector_ostream adjNameStream(adjName);
658
5.32k
            SILDebugVariable &dv = di.second;
659
5.32k
            dv.ArgNo = 0;
660
5.32k
            adjNameStream << "derivative of '" << dv.Name << "'";
661
5.32k
            if (SILDebugLocation origBBLoc = origBB->front().getDebugLocation()) {
662
5.32k
              adjNameStream << " in scope at ";
663
5.32k
              origBBLoc.getLocation().print(adjNameStream, getASTContext().SourceMgr);
664
5.32k
            }
665
5.32k
            adjNameStream << " (scope #" << origBB->getDebugID() << ")";
666
5.32k
            dv.Name = adjName;
667
5.32k
            return dv;
668
5.32k
          }));
669
10.0k
    return (insertion.first->getSecond() = newBuf);
670
16.7k
  }
671
672
  /// Initializes the adjoint buffer for the original value. Asserts that the
673
  /// original value does not already have an adjoint buffer.
674
  void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue,
675
1.74k
                        SILValue adjointBuffer) {
676
1.74k
    assert(getTangentValueCategory(originalValue) == SILValueCategory::Address);
677
0
    auto insertion =
678
1.74k
        bufferMap.try_emplace({origBB, originalValue}, adjointBuffer);
679
1.74k
    assert(insertion.second && "Adjoint buffer already exists");
680
0
    (void)insertion;
681
1.74k
  }
682
683
  /// Accumulates `rhsAddress` into the adjoint buffer corresponding to the
684
  /// original value.
685
  void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue,
686
7.48k
                          SILValue rhsAddress, SILLocation loc) {
687
7.48k
    assert(getTangentValueCategory(originalValue) ==
688
7.48k
               SILValueCategory::Address &&
689
7.48k
           rhsAddress->getType().isAddress());
690
0
    assert(originalValue->getFunction() == &getOriginal());
691
0
    assert(rhsAddress->getFunction() == &getPullback());
692
0
    auto adjointBuffer = getAdjointBuffer(origBB, originalValue);
693
694
7.48k
    LLVM_DEBUG(getADDebugStream() << "Adding"
695
7.48k
               << rhsAddress << "to adjoint ("
696
7.48k
               << adjointBuffer << ") of "
697
7.48k
               << originalValue
698
7.48k
               << "in bb" << origBB->getDebugID() << '\n');
699
700
7.48k
    builder.emitInPlaceAdd(loc, adjointBuffer, rhsAddress);
701
7.48k
  }
702
703
  /// Returns a next insertion point for creating a local allocation: either
704
  /// before the previous local allocation, or at the start of the pullback
705
  /// entry if no local allocations exist.
706
  ///
707
  /// Helper for `createFunctionLocalAllocation`.
708
17.2k
  SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() {
709
    // If there are no local allocations, insert at the pullback entry start.
710
17.2k
    if (functionLocalAllocations.empty())
711
7.60k
      return getPullback().getEntryBlock()->begin();
712
    // Otherwise, insert before the last local allocation. Inserting before
713
    // rather than after ensures that allocation and zero initialization
714
    // instructions are grouped together.
715
9.61k
    auto lastLocalAlloc = functionLocalAllocations.back();
716
9.61k
    return lastLocalAlloc->getDefiningInstruction()->getIterator();
717
17.2k
  }
718
719
  /// Creates and returns a local allocation with the given type.
720
  ///
721
  /// Local allocations are created uninitialized in the pullback entry and
722
  /// deallocated in the pullback exit. All local allocations not in
723
  /// `destroyedLocalAllocations` are also destroyed in the pullback exit.
724
  ///
725
  /// Helper for `getAdjointBuffer`.
726
  AllocStackInst *createFunctionLocalAllocation(
727
      SILType type, SILLocation loc, bool zeroInitialize = false,
728
12.1k
      llvm::Optional<SILDebugVariable> varInfo = llvm::None) {
729
    // Set insertion point for local allocation builder: before the last local
730
    // allocation, or at the start of the pullback function's entry if no local
731
    // allocations exist yet.
732
12.1k
    localAllocBuilder.setInsertionPoint(
733
12.1k
        getPullback().getEntryBlock(),
734
12.1k
        getNextFunctionLocalAllocationInsertionPoint());
735
    // Create and return local allocation.
736
12.1k
    auto *alloc = localAllocBuilder.createAllocStack(loc, type, varInfo);
737
12.1k
    functionLocalAllocations.push_back(alloc);
738
    // Zero-initialize if requested.
739
12.1k
    if (zeroInitialize)
740
10.6k
      localAllocBuilder.emitZeroIntoBuffer(loc, alloc, IsInitialization);
741
12.1k
    return alloc;
742
12.1k
  }
743
744
  //--------------------------------------------------------------------------//
745
  // Optional differentiation
746
  //--------------------------------------------------------------------------//
747
748
  /// Given a `wrappedAdjoint` value of type `T.TangentVector` and `Optional<T>`
749
  /// type, creates an `Optional<T>.TangentVector` buffer from it.
750
  ///
751
  /// `wrappedAdjoint` may be an object or address value, both cases are
752
  /// handled.
753
  AllocStackInst *createOptionalAdjoint(SILBasicBlock *bb,
754
                                        SILValue wrappedAdjoint,
755
                                        SILType optionalTy);
756
757
  /// Accumulate optional buffer from `wrappedAdjoint`.
758
  void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb,
759
                                          SILValue optionalBuffer,
760
                                          SILValue wrappedAdjoint);
761
762
  /// Set optional value from `wrappedAdjoint`.
763
  void setAdjointValueForOptional(SILBasicBlock *bb, SILValue optionalValue,
764
                                  SILValue wrappedAdjoint);
765
766
  //--------------------------------------------------------------------------//
767
  // Array literal initialization differentiation
768
  //--------------------------------------------------------------------------//
769
770
  /// Given the adjoint value of an array initialized from an
771
  /// `array.uninitialized_intrinsic` application and an array element index,
772
  /// returns an `alloc_stack` containing the adjoint value of the array element
773
  /// at the given index by applying `Array.TangentVector.subscript`.
774
  AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint,
775
                                               int eltIndex, SILLocation loc);
776
777
  /// Given the adjoint value of an array initialized from an
778
  /// `array.uninitialized_intrinsic` application, accumulates the adjoint
779
  /// value's elements into the adjoint buffers of its element addresses.
780
  void accumulateArrayLiteralElementAddressAdjoints(
781
      SILBasicBlock *origBB, SILValue originalValue,
782
      AdjointValue arrayAdjointValue, SILLocation loc);
783
784
  //--------------------------------------------------------------------------//
785
  // CFG mapping
786
  //--------------------------------------------------------------------------//
787
788
18.1k
  SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) {
789
18.1k
    return pullbackBBMap.lookup(originalBlock);
790
18.1k
  }
791
792
  SILBasicBlock *getPullbackTrampolineBlock(SILBasicBlock *originalBlock,
793
2.35k
                                            SILBasicBlock *successorBlock) {
794
2.35k
    return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock});
795
2.35k
  }
796
797
  //--------------------------------------------------------------------------//
798
  // Debug info
799
  //--------------------------------------------------------------------------//
800
801
60.5k
  const SILDebugScope *remapScope(const SILDebugScope *DS) {
802
60.5k
    return scopeCloner.getOrCreateClonedScope(DS);
803
60.5k
  }
804
805
  //--------------------------------------------------------------------------//
806
  // Debugging utilities
807
  //--------------------------------------------------------------------------//
808
809
0
  void printAdjointValueMapping() {
810
0
    // Group original/adjoint values by basic block.
811
0
    llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, AdjointValue>> tmp;
812
0
    for (auto pair : valueMap) {
813
0
      auto origPair = pair.first;
814
0
      auto *origBB = origPair.first;
815
0
      auto origValue = origPair.second;
816
0
      auto adjValue = pair.second;
817
0
      tmp[origBB].insert({origValue, adjValue});
818
0
    }
819
0
    // Print original/adjoint values per basic block.
820
0
    auto &s = getADDebugStream() << "Adjoint value mapping:\n";
821
0
    for (auto &origBB : getOriginal()) {
822
0
      if (!pullbackBBMap.count(&origBB))
823
0
        continue;
824
0
      auto bbValueMap = tmp[&origBB];
825
0
      s << "bb" << origBB.getDebugID();
826
0
      s << " (size " << bbValueMap.size() << "):\n";
827
0
      for (auto valuePair : bbValueMap) {
828
0
        auto origValue = valuePair.first;
829
0
        auto adjValue = valuePair.second;
830
0
        s << "ORIG: " << origValue;
831
0
        s << "ADJ: " << adjValue << '\n';
832
0
      }
833
0
      s << '\n';
834
0
    }
835
0
  }
836
837
0
  void printAdjointBufferMapping() {
838
0
    // Group original/adjoint buffers by basic block.
839
0
    llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, SILValue>> tmp;
840
0
    for (auto pair : bufferMap) {
841
0
      auto origPair = pair.first;
842
0
      auto *origBB = origPair.first;
843
0
      auto origBuf = origPair.second;
844
0
      auto adjBuf = pair.second;
845
0
      tmp[origBB][origBuf] = adjBuf;
846
0
    }
847
0
    // Print original/adjoint buffers per basic block.
848
0
    auto &s = getADDebugStream() << "Adjoint buffer mapping:\n";
849
0
    for (auto &origBB : getOriginal()) {
850
0
      if (!pullbackBBMap.count(&origBB))
851
0
        continue;
852
0
      auto bbBufferMap = tmp[&origBB];
853
0
      s << "bb" << origBB.getDebugID();
854
0
      s << " (size " << bbBufferMap.size() << "):\n";
855
0
      for (auto valuePair : bbBufferMap) {
856
0
        auto origBuf = valuePair.first;
857
0
        auto adjBuf = valuePair.second;
858
0
        s << "ORIG: " << origBuf;
859
0
        s << "ADJ: " << adjBuf << '\n';
860
0
      }
861
0
      s << '\n';
862
0
    }
863
0
  }
864
865
public:
866
  //--------------------------------------------------------------------------//
867
  // Entry point
868
  //--------------------------------------------------------------------------//
869
870
  /// Performs pullback generation on the empty pullback function. Returns true
871
  /// if any error occurs.
872
  bool run();
873
874
  /// Performs pullback generation on the empty pullback function, given that
875
  /// the original function is a "semantic member accessor".
876
  ///
877
  /// "Semantic member accessors" are attached to member properties that have a
878
  /// corresponding tangent stored property in the parent `TangentVector` type.
879
  /// These accessors have special-case pullback generation based on their
880
  /// semantic behavior.
881
  ///
882
  /// Returns true if any error occurs.
883
  bool runForSemanticMemberAccessor();
884
  bool runForSemanticMemberGetter();
885
  bool runForSemanticMemberSetter();
886
887
  /// If original result is non-varied, it will always have a zero derivative.
888
  /// Skip full pullback generation and simply emit zero derivatives for wrt
889
  /// parameters.
890
  void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult);
891
892
  /// Public helper so that our users can get the underlying newly created
893
  /// function.
894
615k
  SILFunction &getPullback() const { return vjpCloner.getPullback(); }
895
896
  using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>;
897
898
  /// Determines the pullback successor block for a given original block and one
899
  /// of its predecessors. When a trampoline block is necessary, emits code into
900
  /// the trampoline block to trampoline the original block's active value's
901
  /// adjoint values.
902
  ///
903
  /// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint
904
  /// values to the pullback successor blocks in which they are used. This
905
  /// allows us to release those values in pullback successor blocks that do not
906
  /// use them.
907
  SILBasicBlock *
908
  buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB,
909
                         llvm::SmallDenseMap<SILValue, TrampolineBlockSet>
910
                             &pullbackTrampolineBlockMap);
911
912
  /// Emits pullback code in the corresponding pullback block.
913
  void visitSILBasicBlock(SILBasicBlock *bb);
914
915
34.6k
  void visit(SILInstruction *inst) {
916
34.6k
    if (errorOccurred)
917
0
      return;
918
919
34.6k
    LLVM_DEBUG(getADDebugStream()
920
34.6k
               << "PullbackCloner visited:\n[ORIG]" << *inst);
921
34.6k
#ifndef NDEBUG
922
34.6k
    auto beforeInsertion = std::prev(builder.getInsertionPoint());
923
34.6k
#endif
924
34.6k
    SILInstructionVisitor::visit(inst);
925
34.6k
    LLVM_DEBUG({
926
34.6k
      auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback (pb bb" <<
927
34.6k
        builder.getInsertionBB()->getDebugID() << "):\n";
928
34.6k
      auto afterInsertion = builder.getInsertionPoint();
929
34.6k
      for (auto it = ++beforeInsertion; it != afterInsertion; ++it)
930
34.6k
        s << *it;
931
34.6k
    });
932
34.6k
  }
933
934
  /// Fallback instruction visitor for unhandled instructions.
935
  /// Emit a general non-differentiability diagnostic.
936
20
  void visitSILInstruction(SILInstruction *inst) {
937
20
    LLVM_DEBUG(getADDebugStream()
938
20
               << "Unhandled instruction in PullbackCloner: " << *inst);
939
20
    getContext().emitNondifferentiabilityError(
940
20
        inst, getInvoker(), diag::autodiff_expression_not_differentiable_note);
941
20
    errorOccurred = true;
942
20
  }
943
944
  /// Handle `apply` instruction.
945
  ///   Original: (y0, y1, ...) = apply @fn (x0, x1, ...)
946
  ///    Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...)
947
6.07k
  void visitApplyInst(ApplyInst *ai) {
948
6.07k
    assert(getPullbackInfo().shouldDifferentiateApplySite(ai));
949
950
    // Skip `array.uninitialized_intrinsic` applications, which have special
951
    // `store` and `copy_addr` support.
952
6.07k
    if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC))
953
208
      return;
954
5.86k
    auto loc = ai->getLoc();
955
5.86k
    auto *bb = ai->getParent();
956
    // Handle `array.finalize_intrinsic` applications.
957
    // `array.finalize_intrinsic` semantically behaves like an identity
958
    // function.
959
5.86k
    if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) {
960
208
      assert(ai->getNumArguments() == 1 &&
961
208
             "Expected intrinsic to have one operand");
962
      // Accumulate result's adjoint into argument's adjoint.
963
0
      auto adjResult = getAdjointValue(bb, ai);
964
208
      auto origArg = ai->getArgumentsWithoutIndirectResults().front();
965
208
      addAdjointValue(bb, origArg, adjResult, loc);
966
208
      return;
967
208
    }
968
    // Replace a call to a function with a call to its pullback.
969
5.65k
    auto &nestedApplyInfo = getContext().getNestedApplyInfo();
970
5.65k
    auto applyInfoLookup = nestedApplyInfo.find(ai);
971
    // If no `NestedApplyInfo` was found, then this task doesn't need to be
972
    // differentiated.
973
5.65k
    if (applyInfoLookup == nestedApplyInfo.end()) {
974
      // Must not be active.
975
0
      assert(!getActivityInfo().isActive(ai, getConfig()));
976
0
      return;
977
0
    }
978
5.65k
    auto applyInfo = applyInfoLookup->getSecond();
979
980
    // Get the original result of the `apply` instruction.
981
5.65k
    SmallVector<SILValue, 8> origDirectResults;
982
5.65k
    forEachApplyDirectResult(ai, [&](SILValue directResult) {
983
3.36k
      origDirectResults.push_back(directResult);
984
3.36k
    });
985
5.65k
    SmallVector<SILValue, 8> origAllResults;
986
5.65k
    collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults);
987
    // Append semantic result arguments after original results.
988
8.75k
    for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) {
989
8.75k
      auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg(
990
8.75k
          ai->getNumIndirectResults() + paramIdx);
991
8.75k
      if (!paramInfo.isAutoDiffSemanticResult())
992
8.24k
        continue;
993
512
      origAllResults.push_back(
994
512
          ai->getArgumentsWithoutIndirectResults()[paramIdx]);
995
512
    }
996
997
    // Get callee pullback arguments.
998
5.65k
    SmallVector<SILValue, 8> args;
999
1000
    // Handle callee pullback indirect results.
1001
    // Create local allocations for these and destroy them after the call.
1002
5.65k
    auto pullback = getPullbackTupleElement(ai);
1003
5.65k
    auto pullbackType =
1004
5.65k
        remapType(pullback->getType()).castTo<SILFunctionType>();
1005
1006
5.65k
    auto actualPullbackType = applyInfo.originalPullbackType
1007
5.65k
                                  ? *applyInfo.originalPullbackType
1008
5.65k
                                  : pullbackType;
1009
5.65k
    actualPullbackType = actualPullbackType->getUnsubstitutedType(getModule());
1010
5.65k
    SmallVector<AllocStackInst *, 4> pullbackIndirectResults;
1011
5.65k
    for (auto indRes : actualPullbackType->getIndirectFormalResults()) {
1012
2.72k
      auto *alloc = builder.createAllocStack(
1013
2.72k
          loc, remapType(indRes.getSILStorageInterfaceType()));
1014
2.72k
      pullbackIndirectResults.push_back(alloc);
1015
2.72k
      args.push_back(alloc);
1016
2.72k
    }
1017
1018
    // Collect callee pullback formal arguments.
1019
5.75k
    for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) {
1020
5.75k
      assert(resultIndex < origAllResults.size());
1021
0
      auto origResult = origAllResults[resultIndex];
1022
      // Get the seed (i.e. adjoint value of the original result).
1023
5.75k
      SILValue seed;
1024
5.75k
      switch (getTangentValueCategory(origResult)) {
1025
3.26k
      case SILValueCategory::Object:
1026
3.26k
        seed = materializeAdjointDirect(getAdjointValue(bb, origResult), loc);
1027
3.26k
        break;
1028
2.48k
      case SILValueCategory::Address:
1029
2.48k
        seed = getAdjointBuffer(bb, origResult);
1030
2.48k
        break;
1031
5.75k
      }
1032
5.75k
      args.push_back(seed);
1033
5.75k
    }
1034
1035
    // If callee pullback was reabstracted in VJP, reabstract callee pullback.
1036
5.65k
    if (applyInfo.originalPullbackType) {
1037
1.46k
      SILOptFunctionBuilder fb(getContext().getTransform());
1038
1.46k
      pullback = reabstractFunction(
1039
1.46k
          builder, fb, loc, pullback, *applyInfo.originalPullbackType,
1040
1.46k
          [this](SubstitutionMap subs) -> SubstitutionMap {
1041
1.46k
            return this->remapSubstitutionMap(subs);
1042
1.46k
          });
1043
1.46k
    }
1044
1045
    // Call the callee pullback.
1046
5.65k
    auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
1047
5.65k
                                             args);
1048
5.65k
    builder.emitDestroyValueOperation(loc, pullback);
1049
1050
    // Extract all results from `pullbackCall`.
1051
5.65k
    SmallVector<SILValue, 8> dirResults;
1052
5.65k
    extractAllElements(pullbackCall, builder, dirResults);
1053
    // Get all results in type-defined order.
1054
5.65k
    SmallVector<SILValue, 8> allResults;
1055
5.65k
    collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults);
1056
1057
5.65k
    LLVM_DEBUG({
1058
5.65k
      auto &s = getADDebugStream();
1059
5.65k
      s << "All results of the nested pullback call:\n";
1060
5.65k
      llvm::for_each(allResults, [&](SILValue v) { s << v; });
1061
5.65k
    });
1062
1063
    // Accumulate adjoints for original differentiation parameters.
1064
5.65k
    auto allResultsIt = allResults.begin();
1065
8.75k
    for (unsigned i : applyInfo.config.parameterIndices->getIndices()) {
1066
8.75k
      auto origArg = ai->getArgument(ai->getNumIndirectResults() + i);
1067
      // Skip adjoint accumulation for semantic results arguments.
1068
8.75k
      auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg(
1069
8.75k
          ai->getNumIndirectResults() + i);
1070
8.75k
      if (paramInfo.isAutoDiffSemanticResult())
1071
512
        continue;
1072
8.24k
      auto tan = *allResultsIt++;
1073
8.24k
      if (tan->getType().isAddress()) {
1074
2.71k
        addToAdjointBuffer(bb, origArg, tan, loc);
1075
5.52k
      } else {
1076
5.52k
        if (origArg->getType().isAddress()) {
1077
0
          auto *tmpBuf = builder.createAllocStack(loc, tan->getType());
1078
0
          builder.emitStoreValueOperation(loc, tan, tmpBuf,
1079
0
                                          StoreOwnershipQualifier::Init);
1080
0
          addToAdjointBuffer(bb, origArg, tmpBuf, loc);
1081
0
          builder.emitDestroyAddrAndFold(loc, tmpBuf);
1082
0
          builder.createDeallocStack(loc, tmpBuf);
1083
5.52k
        } else {
1084
5.52k
          recordTemporary(tan);
1085
5.52k
          addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc);
1086
5.52k
        }
1087
5.52k
      }
1088
8.24k
    }
1089
    // Destroy unused pullback direct results. Needed for pullback results from
1090
    // VJPs extracted from `@differentiable` function callees, where the
1091
    // `@differentiable` function's differentiation parameter indices are a
1092
    // superset of the active `apply` parameter indices.
1093
5.67k
    while (allResultsIt != allResults.end()) {
1094
16
      auto unusedPullbackDirectResult = *allResultsIt++;
1095
16
      if (unusedPullbackDirectResult->getType().isAddress())
1096
4
        continue;
1097
12
      builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult);
1098
12
    }
1099
    // Destroy and deallocate pullback indirect results.
1100
5.65k
    for (auto *alloc : llvm::reverse(pullbackIndirectResults)) {
1101
2.72k
      builder.emitDestroyAddrAndFold(loc, alloc);
1102
2.72k
      builder.createDeallocStack(loc, alloc);
1103
2.72k
    }
1104
5.65k
  }
1105
1106
32
  void visitBeginApplyInst(BeginApplyInst *bai) {
1107
    // Diagnose `begin_apply` instructions.
1108
    // Coroutine differentiation is not yet supported.
1109
32
    getContext().emitNondifferentiabilityError(
1110
32
        bai, getInvoker(), diag::autodiff_coroutines_not_supported);
1111
32
    errorOccurred = true;
1112
32
    return;
1113
32
  }
1114
1115
  /// Handle `struct` instruction.
1116
  ///   Original: y = struct (x0, x1, x2, ...)
1117
  ///    Adjoint: adj[x0] += struct_extract adj[y], #x0
1118
  ///             adj[x1] += struct_extract adj[y], #x1
1119
  ///             adj[x2] += struct_extract adj[y], #x2
1120
  ///             ...
1121
60
  void visitStructInst(StructInst *si) {
1122
60
    auto *bb = si->getParent();
1123
60
    auto loc = si->getLoc();
1124
60
    auto *structDecl = si->getStructDecl();
1125
60
    switch (getTangentValueCategory(si)) {
1126
60
    case SILValueCategory::Object: {
1127
60
      auto av = getAdjointValue(bb, si);
1128
60
      switch (av.getKind()) {
1129
0
      case AdjointValueKind::Zero: {
1130
0
        for (auto *field : structDecl->getStoredProperties()) {
1131
0
          auto fv = si->getFieldValue(field);
1132
0
          addAdjointValue(
1133
0
              bb, fv,
1134
0
              makeZeroAdjointValue(getRemappedTangentType(fv->getType())), loc);
1135
0
        }
1136
0
        break;
1137
0
      }
1138
60
      case AdjointValueKind::Concrete: {
1139
60
        auto adjStruct = materializeAdjointDirect(std::move(av), loc);
1140
60
        auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct);
1141
1142
        // Find the struct `TangentVector` type.
1143
60
        auto structTy = remapType(si->getType()).getASTType();
1144
60
#ifndef NDEBUG
1145
60
        auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
1146
60
        assert(!getTypeLowering(tangentVectorTy).isAddressOnly());
1147
0
        assert(tangentVectorTy->getStructOrBoundGenericStruct());
1148
0
#endif
1149
1150
        // Accumulate adjoints for the fields of the `struct` operand.
1151
0
        unsigned fieldIndex = 0;
1152
60
        for (auto it = structDecl->getStoredProperties().begin();
1153
164
             it != structDecl->getStoredProperties().end();
1154
104
             ++it, ++fieldIndex) {
1155
104
          VarDecl *field = *it;
1156
104
          if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
1157
0
            continue;
1158
          // Find the corresponding field in the tangent space.
1159
104
          auto *tanField = getTangentStoredProperty(
1160
104
              getContext(), field, structTy, loc, getInvoker());
1161
104
          if (!tanField) {
1162
0
            errorOccurred = true;
1163
0
            return;
1164
0
          }
1165
104
          auto tanElt = dti->getResult(fieldIndex);
1166
104
          addAdjointValue(bb, si->getFieldValue(field),
1167
104
                          makeConcreteAdjointValue(tanElt), si->getLoc());
1168
104
        }
1169
60
        break;
1170
60
      }
1171
60
      case AdjointValueKind::Aggregate: {
1172
        // Note: All user-called initializations go through the calls to the
1173
        // initializer, and synthesized initializers only have one level of
1174
        // struct formation which will not result into any aggregate adjoint
1175
        // values.
1176
0
        llvm_unreachable(
1177
0
            "Aggregate adjoint values should not occur for `struct` "
1178
0
            "instructions");
1179
0
      }
1180
0
      case AdjointValueKind::AddElement: {
1181
0
        llvm_unreachable(
1182
0
            "Adjoint of `StructInst` cannot be of kind `AddElement`");
1183
0
      }
1184
60
      }
1185
60
      break;
1186
60
    }
1187
60
    case SILValueCategory::Address: {
1188
0
      auto adjBuf = getAdjointBuffer(bb, si);
1189
      // Find the struct `TangentVector` type.
1190
0
      auto structTy = remapType(si->getType()).getASTType();
1191
      // Accumulate adjoints for the fields of the `struct` operand.
1192
0
      unsigned fieldIndex = 0;
1193
0
      for (auto it = structDecl->getStoredProperties().begin();
1194
0
           it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) {
1195
0
        VarDecl *field = *it;
1196
0
        if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
1197
0
          continue;
1198
        // Find the corresponding field in the tangent space.
1199
0
        auto *tanField = getTangentStoredProperty(getContext(), field, structTy,
1200
0
                                                  loc, getInvoker());
1201
0
        if (!tanField) {
1202
0
          errorOccurred = true;
1203
0
          return;
1204
0
        }
1205
0
        auto *adjFieldBuf =
1206
0
            builder.createStructElementAddr(loc, adjBuf, tanField);
1207
0
        auto fieldValue = si->getFieldValue(field);
1208
0
        switch (getTangentValueCategory(fieldValue)) {
1209
0
        case SILValueCategory::Object: {
1210
0
          auto adjField = builder.emitLoadValueOperation(
1211
0
              loc, adjFieldBuf, LoadOwnershipQualifier::Copy);
1212
0
          recordTemporary(adjField);
1213
0
          addAdjointValue(bb, fieldValue, makeConcreteAdjointValue(adjField),
1214
0
                          loc);
1215
0
          break;
1216
0
        }
1217
0
        case SILValueCategory::Address: {
1218
0
          addToAdjointBuffer(bb, fieldValue, adjFieldBuf, loc);
1219
0
          break;
1220
0
        }
1221
0
        }
1222
0
      }
1223
0
    } break;
1224
60
    }
1225
60
  }
1226
1227
  /// Handle `struct_extract` instruction.
1228
  ///   Original: y = struct_extract x, #field
1229
  ///    Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
1230
  ///                                       ^~~~~~~
1231
  ///                     field in tangent space corresponding to #field
1232
508
  void visitStructExtractInst(StructExtractInst *sei) {
1233
508
    auto *bb = sei->getParent();
1234
508
    auto loc = getValidLocation(sei);
1235
    // Find the corresponding field in the tangent space.
1236
508
    auto structTy = remapType(sei->getOperand()->getType()).getASTType();
1237
508
    auto *tanField =
1238
508
        getTangentStoredProperty(getContext(), sei, structTy, getInvoker());
1239
508
    assert(tanField && "Invalid projections should have been diagnosed");
1240
    // Check the `struct_extract` operand's value tangent category.
1241
0
    switch (getTangentValueCategory(sei->getOperand())) {
1242
508
    case SILValueCategory::Object: {
1243
508
      auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
1244
508
      auto tangentVectorSILTy =
1245
508
          SILType::getPrimitiveObjectType(tangentVectorTy);
1246
508
      auto eltAdj = getAdjointValue(bb, sei);
1247
1248
508
      switch (eltAdj.getKind()) {
1249
0
      case AdjointValueKind::Zero: {
1250
0
        addAdjointValue(bb, sei->getOperand(),
1251
0
                        makeZeroAdjointValue(tangentVectorSILTy), loc);
1252
0
        break;
1253
0
      }
1254
0
      case AdjointValueKind::Aggregate:
1255
504
      case AdjointValueKind::Concrete:
1256
508
      case AdjointValueKind::AddElement: {
1257
508
        auto baseAdj = makeZeroAdjointValue(tangentVectorSILTy);
1258
508
        addAdjointValue(bb, sei->getOperand(),
1259
508
                        makeAddElementAdjointValue(baseAdj, eltAdj, tanField),
1260
508
                        loc);
1261
508
        break;
1262
504
      }
1263
508
      }
1264
508
      break;
1265
508
    }
1266
508
    case SILValueCategory::Address: {
1267
0
      auto adjBase = getAdjointBuffer(bb, sei->getOperand());
1268
0
      auto *adjBaseElt =
1269
0
          builder.createStructElementAddr(loc, adjBase, tanField);
1270
      // Check the `struct_extract`'s value tangent category.
1271
0
      switch (getTangentValueCategory(sei)) {
1272
0
      case SILValueCategory::Object: {
1273
0
        auto adjElt = getAdjointValue(bb, sei);
1274
0
        auto concreteAdjElt = materializeAdjointDirect(adjElt, loc);
1275
0
        auto concreteAdjEltCopy =
1276
0
            builder.emitCopyValueOperation(loc, concreteAdjElt);
1277
0
        auto *alloc = builder.createAllocStack(loc, adjElt.getType());
1278
0
        builder.emitStoreValueOperation(loc, concreteAdjEltCopy, alloc,
1279
0
                                        StoreOwnershipQualifier::Init);
1280
0
        builder.emitInPlaceAdd(loc, adjBaseElt, alloc);
1281
0
        builder.createDestroyAddr(loc, alloc);
1282
0
        builder.createDeallocStack(loc, alloc);
1283
0
        break;
1284
0
      }
1285
0
      case SILValueCategory::Address: {
1286
0
        auto adjElt = getAdjointBuffer(bb, sei);
1287
0
        builder.emitInPlaceAdd(loc, adjBaseElt, adjElt);
1288
0
        break;
1289
0
      }
1290
0
      }
1291
0
      break;
1292
0
    }
1293
508
    }
1294
508
  }
1295
1296
  /// Handle `ref_element_addr` instruction.
1297
  ///   Original: y = ref_element_addr x, <n>
1298
  ///    Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
1299
  ///                                       ^~~~~~~
1300
  ///                     field in tangent space corresponding to #field
1301
128
  void visitRefElementAddrInst(RefElementAddrInst *reai) {
1302
128
    auto *bb = reai->getParent();
1303
128
    auto loc = reai->getLoc();
1304
128
    auto adjBuf = getAdjointBuffer(bb, reai);
1305
128
    auto classOperand = reai->getOperand();
1306
128
    auto classType = remapType(reai->getOperand()->getType()).getASTType();
1307
128
    auto *tanField =
1308
128
        getTangentStoredProperty(getContext(), reai, classType, getInvoker());
1309
128
    assert(tanField && "Invalid projections should have been diagnosed");
1310
0
    switch (getTangentValueCategory(classOperand)) {
1311
36
    case SILValueCategory::Object: {
1312
36
      auto classTy = remapType(classOperand->getType()).getASTType();
1313
36
      auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType();
1314
36
      auto tangentVectorSILTy =
1315
36
          SILType::getPrimitiveObjectType(tangentVectorTy);
1316
36
      auto *tangentVectorDecl =
1317
36
          tangentVectorTy->getStructOrBoundGenericStruct();
1318
      // Accumulate adjoint for the `ref_element_addr` operand.
1319
36
      SmallVector<AdjointValue, 8> eltVals;
1320
36
      for (auto *field : tangentVectorDecl->getStoredProperties()) {
1321
36
        if (field == tanField) {
1322
36
          auto adjElt = builder.emitLoadValueOperation(
1323
36
              reai->getLoc(), adjBuf, LoadOwnershipQualifier::Copy);
1324
36
          eltVals.push_back(makeConcreteAdjointValue(adjElt));
1325
36
          recordTemporary(adjElt);
1326
36
        } else {
1327
0
          auto substMap = tangentVectorTy->getMemberSubstitutionMap(
1328
0
              field->getModuleContext(), field);
1329
0
          auto fieldTy = field->getInterfaceType().subst(substMap);
1330
0
          auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
1331
0
          assert(fieldSILTy.isObject());
1332
0
          eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
1333
0
        }
1334
36
      }
1335
36
      addAdjointValue(bb, classOperand,
1336
36
                      makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
1337
36
                      loc);
1338
36
      break;
1339
0
    }
1340
92
    case SILValueCategory::Address: {
1341
92
      auto adjBufClass = getAdjointBuffer(bb, classOperand);
1342
92
      auto adjBufElt =
1343
92
          builder.createStructElementAddr(loc, adjBufClass, tanField);
1344
92
      builder.emitInPlaceAdd(loc, adjBufElt, adjBuf);
1345
92
      break;
1346
0
    }
1347
128
    }
1348
128
  }
1349
1350
  /// Handle `tuple` instruction.
1351
  ///   Original: y = tuple (x0, x1, x2, ...)
1352
  ///    Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y]
1353
  ///                                         ^~~
1354
  ///                         excluding non-differentiable elements
1355
68
  void visitTupleInst(TupleInst *ti) {
1356
68
    auto *bb = ti->getParent();
1357
68
    auto loc = ti->getLoc();
1358
68
    switch (getTangentValueCategory(ti)) {
1359
68
    case SILValueCategory::Object: {
1360
68
      auto av = getAdjointValue(bb, ti);
1361
68
      switch (av.getKind()) {
1362
0
      case AdjointValueKind::Zero:
1363
0
        for (auto elt : ti->getElements()) {
1364
0
          if (!getTangentSpace(elt->getType().getASTType()))
1365
0
            continue;
1366
0
          addAdjointValue(
1367
0
              bb, elt,
1368
0
              makeZeroAdjointValue(getRemappedTangentType(elt->getType())),
1369
0
              loc);
1370
0
        }
1371
0
        break;
1372
0
      case AdjointValueKind::Concrete: {
1373
0
        auto adjVal = av.getConcreteValue();
1374
0
        auto adjValCopy = builder.emitCopyValueOperation(loc, adjVal);
1375
0
        SmallVector<SILValue, 4> adjElts;
1376
0
        if (!adjVal->getType().getAs<TupleType>()) {
1377
0
          recordTemporary(adjValCopy);
1378
0
          adjElts.push_back(adjValCopy);
1379
0
        } else {
1380
0
          auto *dti = builder.createDestructureTuple(loc, adjValCopy);
1381
0
          for (auto adjElt : dti->getResults())
1382
0
            recordTemporary(adjElt);
1383
0
          adjElts.append(dti->getResults().begin(), dti->getResults().end());
1384
0
        }
1385
        // Accumulate adjoints for `tuple` operands, skipping the
1386
        // non-`Differentiable` ones.
1387
0
        unsigned adjIndex = 0;
1388
0
        for (auto i : range(ti->getNumOperands())) {
1389
0
          if (!getTangentSpace(ti->getOperand(i)->getType().getASTType()))
1390
0
            continue;
1391
0
          auto adjElt = adjElts[adjIndex++];
1392
0
          addAdjointValue(bb, ti->getOperand(i),
1393
0
                          makeConcreteAdjointValue(adjElt), loc);
1394
0
        }
1395
0
        break;
1396
0
      }
1397
68
      case AdjointValueKind::Aggregate: {
1398
68
        unsigned adjIndex = 0;
1399
136
        for (auto i : range(ti->getElements().size())) {
1400
136
          if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
1401
0
            continue;
1402
136
          addAdjointValue(bb, ti->getElement(i),
1403
136
                          av.getAggregateElement(adjIndex++), loc);
1404
136
        }
1405
68
        break;
1406
0
      }
1407
0
      case AdjointValueKind::AddElement: {
1408
0
        llvm_unreachable(
1409
0
            "Adjoint of `TupleInst` cannot be of kind `AddElement`");
1410
0
      }
1411
68
      }
1412
68
      break;
1413
68
    }
1414
68
    case SILValueCategory::Address: {
1415
0
      auto adjBuf = getAdjointBuffer(bb, ti);
1416
      // Accumulate adjoints for `tuple` operands, skipping the
1417
      // non-`Differentiable` ones.
1418
0
      unsigned adjIndex = 0;
1419
0
      for (auto i : range(ti->getNumOperands())) {
1420
0
        if (!getTangentSpace(ti->getOperand(i)->getType().getASTType()))
1421
0
          continue;
1422
0
        auto adjBufElt =
1423
0
            builder.createTupleElementAddr(loc, adjBuf, adjIndex++);
1424
0
        auto adjElt = getAdjointBuffer(bb, ti->getOperand(i));
1425
0
        builder.emitInPlaceAdd(loc, adjElt, adjBufElt);
1426
0
      }
1427
0
      break;
1428
68
    }
1429
68
    }
1430
68
  }
1431
1432
  /// Handle `tuple_extract` instruction.
1433
  ///   Original: y = tuple_extract x, <n>
1434
  ///    Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0)
1435
  ///                                         ^~~~~~
1436
  ///                            n'-th element, where n' is tuple tangent space
1437
  ///                            index corresponding to n
1438
16
  void visitTupleExtractInst(TupleExtractInst *tei) {
1439
16
    auto *bb = tei->getParent();
1440
16
    auto loc = tei->getLoc();
1441
16
    auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
1442
16
    auto eltAdj = getAdjointValue(bb, tei);
1443
16
    switch (eltAdj.getKind()) {
1444
0
    case AdjointValueKind::Zero: {
1445
0
      addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
1446
0
                      loc);
1447
0
      break;
1448
0
    }
1449
0
    case AdjointValueKind::Aggregate:
1450
12
    case AdjointValueKind::Concrete:
1451
16
    case AdjointValueKind::AddElement: {
1452
16
      auto tupleTy = tei->getTupleType();
1453
16
      auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
1454
16
      if (!tupleTanTupleTy) {
1455
0
        addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
1456
0
        break;
1457
0
      }
1458
1459
16
      unsigned elements = 0;
1460
32
      for (unsigned i : range(tupleTy->getNumElements())) {
1461
32
        if (!getTangentSpace(
1462
32
                tupleTy->getElement(i).getType()->getCanonicalType()))
1463
0
          continue;
1464
32
        elements++;
1465
32
      }
1466
1467
16
      if (elements == 1) {
1468
0
        addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
1469
16
      } else {
1470
16
        auto baseAdj = makeZeroAdjointValue(tupleTanTy);
1471
16
        addAdjointValue(
1472
16
            bb, tei->getOperand(),
1473
16
            makeAddElementAdjointValue(baseAdj, eltAdj, tei->getFieldIndex()),
1474
16
            loc);
1475
16
      }
1476
16
      break;
1477
16
    }
1478
16
    }
1479
16
  }
1480
1481
  /// Handle `destructure_tuple` instruction.
1482
  ///   Original: (y0, ..., yn) = destructure_tuple x
1483
  ///    Adjoint: adj[x].0 += adj[y0]
1484
  ///             ...
1485
  ///             adj[x].n += adj[yn]
1486
308
  void visitDestructureTupleInst(DestructureTupleInst *dti) {
1487
308
    auto *bb = dti->getParent();
1488
308
    auto loc = dti->getLoc();
1489
308
    auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType());
1490
    // Check the `destructure_tuple` operand's value tangent category.
1491
308
    switch (getTangentValueCategory(dti->getOperand())) {
1492
308
    case SILValueCategory::Object: {
1493
308
      SmallVector<AdjointValue, 8> adjValues;
1494
616
      for (auto origElt : dti->getResults()) {
1495
        // Skip non-`Differentiable` tuple elements.
1496
616
        if (!getTangentSpace(remapType(origElt->getType()).getASTType()))
1497
208
          continue;
1498
408
        adjValues.push_back(getAdjointValue(bb, origElt));
1499
408
      }
1500
      // Handle tuple tangent type.
1501
      // Add adjoints for every tuple element that has a tangent space.
1502
308
      if (tupleTanTy.is<TupleType>()) {
1503
100
        assert(adjValues.size() > 1);
1504
0
        addAdjointValue(bb, dti->getOperand(),
1505
100
                        makeAggregateAdjointValue(tupleTanTy, adjValues), loc);
1506
100
      }
1507
      // Handle non-tuple tangent type.
1508
      // Add adjoint for the single tuple element that has a tangent space.
1509
208
      else {
1510
208
        assert(adjValues.size() == 1);
1511
0
        addAdjointValue(bb, dti->getOperand(), adjValues.front(), loc);
1512
208
      }
1513
0
      break;
1514
0
    }
1515
0
    case SILValueCategory::Address: {
1516
0
      auto adjBuf = getAdjointBuffer(bb, dti->getOperand());
1517
0
      unsigned adjIndex = 0;
1518
0
      for (auto origElt : dti->getResults()) {
1519
        // Skip non-`Differentiable` tuple elements.
1520
0
        if (!getTangentSpace(remapType(origElt->getType()).getASTType()))
1521
0
          continue;
1522
        // Handle tuple tangent type.
1523
        // Add adjoints for every tuple element that has a tangent space.
1524
0
        if (tupleTanTy.is<TupleType>()) {
1525
0
          auto adjEltBuf = getAdjointBuffer(bb, origElt);
1526
0
          auto adjBufElt =
1527
0
              builder.createTupleElementAddr(loc, adjBuf, adjIndex);
1528
0
          builder.emitInPlaceAdd(loc, adjBufElt, adjEltBuf);
1529
0
        }
1530
        // Handle non-tuple tangent type.
1531
        // Add adjoint for the single tuple element that has a tangent space.
1532
0
        else {
1533
0
          auto adjEltBuf = getAdjointBuffer(bb, origElt);
1534
0
          addToAdjointBuffer(bb, dti->getOperand(), adjEltBuf, loc);
1535
0
        }
1536
0
        ++adjIndex;
1537
0
      }
1538
0
      break;
1539
0
    }
1540
308
    }
1541
308
  }
1542
1543
  /// Handle `load` or `load_borrow` instruction
1544
  ///   Original: y = load/load_borrow x
1545
  ///    Adjoint: adj[x] += adj[y]
1546
2.35k
  void visitLoadOperation(SingleValueInstruction *inst) {
1547
2.35k
    assert(isa<LoadInst>(inst) || isa<LoadBorrowInst>(inst));
1548
0
    auto *bb = inst->getParent();
1549
2.35k
    auto loc = inst->getLoc();
1550
2.35k
    switch (getTangentValueCategory(inst)) {
1551
2.30k
    case SILValueCategory::Object: {
1552
2.30k
      auto adjVal = materializeAdjointDirect(getAdjointValue(bb, inst), loc);
1553
      // Allocate a local buffer and store the adjoint value. This buffer will
1554
      // be used for accumulation into the adjoint buffer.
1555
2.30k
      auto adjBuf = builder.createAllocStack(
1556
2.30k
          loc, adjVal->getType(), SILDebugVariable());
1557
2.30k
      auto copy = builder.emitCopyValueOperation(loc, adjVal);
1558
2.30k
      builder.emitStoreValueOperation(loc, copy, adjBuf,
1559
2.30k
                                      StoreOwnershipQualifier::Init);
1560
      // Accumulate the adjoint value in the local buffer into the adjoint
1561
      // buffer.
1562
2.30k
      addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc);
1563
2.30k
      builder.emitDestroyAddr(loc, adjBuf);
1564
2.30k
      builder.createDeallocStack(loc, adjBuf);
1565
2.30k
      break;
1566
0
    }
1567
52
    case SILValueCategory::Address: {
1568
52
      auto adjBuf = getAdjointBuffer(bb, inst);
1569
52
      addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc);
1570
52
      break;
1571
0
    }
1572
2.35k
    }
1573
2.35k
  }
1574
2.29k
  void visitLoadInst(LoadInst *li) { visitLoadOperation(li); }
1575
64
  void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); }
1576
1577
  /// Handle `store` or `store_borrow` instruction.
1578
  ///   Original: store/store_borrow x to y
1579
  ///    Adjoint: adj[x] += load adj[y]; adj[y] = 0
1580
  void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc,
1581
2.69k
                           SILValue origDest) {
1582
2.69k
    auto adjBuf = getAdjointBuffer(bb, origDest);
1583
2.69k
    switch (getTangentValueCategory(origSrc)) {
1584
2.63k
    case SILValueCategory::Object: {
1585
2.63k
      auto adjVal = builder.emitLoadValueOperation(
1586
2.63k
          loc, adjBuf, LoadOwnershipQualifier::Take);
1587
2.63k
      recordTemporary(adjVal);
1588
2.63k
      addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc);
1589
2.63k
      builder.emitZeroIntoBuffer(loc, adjBuf, IsInitialization);
1590
2.63k
      break;
1591
0
    }
1592
60
    case SILValueCategory::Address: {
1593
60
      addToAdjointBuffer(bb, origSrc, adjBuf, loc);
1594
60
      builder.emitZeroIntoBuffer(loc, adjBuf, IsNotInitialization);
1595
60
      break;
1596
0
    }
1597
2.69k
    }
1598
2.69k
  }
1599
2.69k
  void visitStoreInst(StoreInst *si) {
1600
2.69k
    visitStoreOperation(si->getParent(), si->getLoc(), si->getSrc(),
1601
2.69k
                        si->getDest());
1602
2.69k
  }
1603
0
  void visitStoreBorrowInst(StoreBorrowInst *sbi) {
1604
0
    visitStoreOperation(sbi->getParent(), sbi->getLoc(), sbi->getSrc(),
1605
0
                        sbi);
1606
0
  }
1607
1608
  /// Handle `copy_addr` instruction.
1609
  ///   Original: copy_addr x to y
1610
  ///    Adjoint: adj[x] += adj[y]; adj[y] = 0
1611
2.01k
  void visitCopyAddrInst(CopyAddrInst *cai) {
1612
2.01k
    auto *bb = cai->getParent();
1613
2.01k
    auto adjDest = getAdjointBuffer(bb, cai->getDest());
1614
2.01k
    addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc());
1615
2.01k
    builder.emitZeroIntoBuffer(cai->getLoc(), adjDest, IsNotInitialization);
1616
2.01k
  }
1617
1618
  /// Handle any ownership instruction that deals with values: copy_value,
1619
  /// move_value, begin_borrow.
1620
  ///   Original: y = copy_value x
1621
  ///    Adjoint: adj[x] += adj[y]
1622
500
  void visitValueOwnershipInst(SingleValueInstruction *svi) {
1623
500
    assert(svi->getNumOperands() == 1);
1624
0
    auto *bb = svi->getParent();
1625
500
    switch (getTangentValueCategory(svi)) {
1626
348
    case SILValueCategory::Object: {
1627
348
      auto adj = getAdjointValue(bb, svi);
1628
348
      addAdjointValue(bb, svi->getOperand(0), adj, svi->getLoc());
1629
348
      break;
1630
0
    }
1631
152
    case SILValueCategory::Address: {
1632
152
      auto adjDest = getAdjointBuffer(bb, svi);
1633
152
      addToAdjointBuffer(bb, svi->getOperand(0), adjDest, svi->getLoc());
1634
152
      builder.emitZeroIntoBuffer(svi->getLoc(), adjDest, IsNotInitialization);
1635
152
      break;
1636
0
    }
1637
500
    }
1638
500
  }
1639
1640
  /// Handle `copy_value` instruction.
1641
  ///   Original: y = copy_value x
1642
  ///    Adjoint: adj[x] += adj[y]
1643
308
  void visitCopyValueInst(CopyValueInst *cvi) { visitValueOwnershipInst(cvi); }
1644
1645
  /// Handle `begin_borrow` instruction.
1646
  ///   Original: y = begin_borrow x
1647
  ///    Adjoint: adj[x] += adj[y]
1648
152
  void visitBeginBorrowInst(BeginBorrowInst *bbi) {
1649
152
    visitValueOwnershipInst(bbi);
1650
152
  }
1651
1652
  /// Handle `move_value` instruction.
1653
  ///   Original: y = move_value x
1654
  ///    Adjoint: adj[x] += adj[y]
1655
0
  void visitMoveValueInst(MoveValueInst *mvi) { visitValueOwnershipInst(mvi); }
1656
1657
40
  void visitEndInitLetRefInst(EndInitLetRefInst *eir) { visitValueOwnershipInst(eir); }
1658
1659
  /// Handle `begin_access` instruction.
1660
  ///   Original: y = begin_access x
1661
  ///    Adjoint: nothing
1662
3.13k
  void visitBeginAccessInst(BeginAccessInst *bai) {
1663
    // Check for non-differentiable writes.
1664
3.13k
    if (bai->getAccessKind() == SILAccessKind::Modify) {
1665
1.27k
      if (isa<GlobalAddrInst>(bai->getSource())) {
1666
4
        getContext().emitNondifferentiabilityError(
1667
4
            bai, getInvoker(),
1668
4
            diag::autodiff_cannot_differentiate_writes_to_global_variables);
1669
4
        errorOccurred = true;
1670
4
        return;
1671
4
      }
1672
1.27k
      if (isa<ProjectBoxInst>(bai->getSource())) {
1673
0
        getContext().emitNondifferentiabilityError(
1674
0
            bai, getInvoker(),
1675
0
            diag::autodiff_cannot_differentiate_writes_to_mutable_captures);
1676
0
        errorOccurred = true;
1677
0
        return;
1678
0
      }
1679
1.27k
    }
1680
3.13k
  }
1681
1682
  /// Handle `unconditional_checked_cast_addr` instruction.
1683
  ///   Original: y = unconditional_checked_cast_addr x
1684
  ///    Adjoint: adj[x] += unconditional_checked_cast_addr adj[y]
1685
  void visitUnconditionalCheckedCastAddrInst(
1686
16
      UnconditionalCheckedCastAddrInst *uccai) {
1687
16
    auto *bb = uccai->getParent();
1688
16
    auto adjDest = getAdjointBuffer(bb, uccai->getDest());
1689
16
    auto adjSrc = getAdjointBuffer(bb, uccai->getSrc());
1690
16
    auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType());
1691
16
    builder.createUnconditionalCheckedCastAddr(
1692
16
        uccai->getLoc(), adjDest, adjDest->getType().getASTType(), castBuf,
1693
16
        adjSrc->getType().getASTType());
1694
16
    addToAdjointBuffer(bb, uccai->getSrc(), castBuf, uccai->getLoc());
1695
16
    builder.emitDestroyAddrAndFold(uccai->getLoc(), castBuf);
1696
16
    builder.createDeallocStack(uccai->getLoc(), castBuf);
1697
16
    builder.emitZeroIntoBuffer(uccai->getLoc(), adjDest, IsInitialization);
1698
16
  }
1699
1700
  /// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
1701
  /// instructions.
1702
  ///
1703
  /// Original: y = init_enum_data_addr x
1704
  ///           inject_enum_addr y
1705
  ///
1706
  ///  Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
1707
8
  void visitInjectEnumAddrInst(InjectEnumAddrInst *inject) {
1708
8
    SILBasicBlock *bb = inject->getParent();
1709
8
    SILValue origEnum = inject->getOperand();
1710
1711
    // Only `Optional`-typed operands are supported for now. Diagnose all other
1712
    // enum operand types.
1713
8
    auto *optionalEnumDecl = getASTContext().getOptionalDecl();
1714
8
    if (origEnum->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
1715
0
      LLVM_DEBUG(getADDebugStream()
1716
0
                 << "Unsupported enum type in PullbackCloner: " << *inject);
1717
0
      getContext().emitNondifferentiabilityError(
1718
0
          inject, getInvoker(),
1719
0
          diag::autodiff_expression_not_differentiable_note);
1720
0
      errorOccurred = true;
1721
0
      return;
1722
0
    }
1723
1724
8
    InitEnumDataAddrInst *origData = nullptr;
1725
16
    for (auto use : origEnum->getUses()) {
1726
16
      if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser())) {
1727
        // We need a more complicated analysis when init_enum_data_addr and
1728
        // inject_enum_addr are in different blocks, or there is more than one
1729
        // such instruction. Bail out for now.
1730
8
        if (origData || init->getParent() != bb) {
1731
0
          LLVM_DEBUG(getADDebugStream()
1732
0
                     << "Could not find a matching init_enum_data_addr for: "
1733
0
                     << *inject);
1734
0
          getContext().emitNondifferentiabilityError(
1735
0
              inject, getInvoker(),
1736
0
              diag::autodiff_expression_not_differentiable_note);
1737
0
          errorOccurred = true;
1738
0
          return;
1739
0
        }
1740
1741
8
        origData = init;
1742
8
      }
1743
16
    }
1744
1745
8
    SILValue adjStruct = getAdjointBuffer(bb, origEnum);
1746
8
    StructDecl *adjStructDecl =
1747
8
        adjStruct->getType().getStructOrBoundGenericStruct();
1748
1749
8
    VarDecl *adjOptVar = nullptr;
1750
8
    if (adjStructDecl) {
1751
8
      ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
1752
8
      adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
1753
8
    }
1754
1755
8
    EnumDecl *adjOptDecl =
1756
8
        adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
1757
8
                  : nullptr;
1758
1759
    // Optional<T>.TangentVector should be a struct with a single
1760
    // Optional<T.TangentVector> property. This is an implementation detail of
1761
    // OptionalDifferentiation.swift
1762
8
    if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1763
0
      llvm_unreachable("Unexpected type of Optional.TangentVector");
1764
1765
8
    SILLocation loc = origData->getLoc();
1766
8
    StructElementAddrInst *adjOpt =
1767
8
        builder.createStructElementAddr(loc, adjStruct, adjOptVar);
1768
1769
    // unchecked_take_enum_data_addr is destructive, so copy
1770
    // Optional<T.TangentVector> to a new alloca.
1771
8
    AllocStackInst *adjOptCopy =
1772
8
        createFunctionLocalAllocation(adjOpt->getType(), loc);
1773
8
    builder.createCopyAddr(loc, adjOpt, adjOptCopy, IsNotTake,
1774
8
                           IsInitialization);
1775
1776
8
    EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
1777
8
    UncheckedTakeEnumDataAddrInst *adjData =
1778
8
        builder.createUncheckedTakeEnumDataAddr(loc, adjOptCopy, someElemDecl);
1779
1780
8
    setAdjointBuffer(bb, origData, adjData);
1781
1782
    // The Optional copy is invalidated, do not attempt to destroy it at the end
1783
    // of the pullback. The value returned from unchecked_take_enum_data_addr is
1784
    // destroyed in visitInitEnumDataAddrInst.
1785
8
    destroyedLocalAllocations.insert(adjOptCopy);
1786
8
  }
1787
1788
  /// Handle `init_enum_data_addr` instruction.
1789
  /// Destroy the value returned from `unchecked_take_enum_data_addr`.
1790
8
  void visitInitEnumDataAddrInst(InitEnumDataAddrInst *init) {
1791
8
    auto bufIt = bufferMap.find({init->getParent(), SILValue(init)});
1792
8
    if (bufIt == bufferMap.end())
1793
0
      return;
1794
8
    SILValue adjData = bufIt->second;
1795
8
    builder.emitDestroyAddr(init->getLoc(), adjData);
1796
8
  }
1797
1798
  /// Handle `unchecked_ref_cast` instruction.
1799
  ///   Original: y = unchecked_ref_cast x
1800
  ///    Adjoint: adj[x] += adj[y]
1801
  ///             (assuming adj[x] and adj[y] have the same type)
1802
8
  void visitUncheckedRefCastInst(UncheckedRefCastInst *urci) {
1803
8
    auto *bb = urci->getParent();
1804
8
    assert(urci->getOperand()->getType().isObject());
1805
0
    assert(getRemappedTangentType(urci->getOperand()->getType()) ==
1806
8
               getRemappedTangentType(urci->getType()) &&
1807
8
           "Operand/result must have the same `TangentVector` type");
1808
0
    switch (getTangentValueCategory(urci)) {
1809
0
    case SILValueCategory::Object: {
1810
0
      auto adj = getAdjointValue(bb, urci);
1811
0
      addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc());
1812
0
      break;
1813
0
    }
1814
8
    case SILValueCategory::Address: {
1815
8
      auto adjDest = getAdjointBuffer(bb, urci);
1816
8
      addToAdjointBuffer(bb, urci->getOperand(), adjDest, urci->getLoc());
1817
8
      builder.emitZeroIntoBuffer(urci->getLoc(), adjDest, IsNotInitialization);
1818
8
      break;
1819
0
    }
1820
8
    }
1821
8
  }
1822
1823
  /// Handle `upcast` instruction.
1824
  ///   Original: y = upcast x
1825
  ///    Adjoint: adj[x] += adj[y]
1826
  ///             (assuming adj[x] and adj[y] have the same type)
1827
24
  void visitUpcastInst(UpcastInst *ui) {
1828
24
    auto *bb = ui->getParent();
1829
24
    assert(ui->getOperand()->getType().isObject());
1830
0
    assert(getRemappedTangentType(ui->getOperand()->getType()) ==
1831
24
               getRemappedTangentType(ui->getType()) &&
1832
24
           "Operand/result must have the same `TangentVector` type");
1833
0
    switch (getTangentValueCategory(ui)) {
1834
8
    case SILValueCategory::Object: {
1835
8
      auto adj = getAdjointValue(bb, ui);
1836
8
      addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc());
1837
8
      break;
1838
0
    }
1839
16
    case SILValueCategory::Address: {
1840
16
      auto adjDest = getAdjointBuffer(bb, ui);
1841
16
      addToAdjointBuffer(bb, ui->getOperand(), adjDest, ui->getLoc());
1842
16
      builder.emitZeroIntoBuffer(ui->getLoc(), adjDest, IsNotInitialization);
1843
16
      break;
1844
0
    }
1845
24
    }
1846
24
  }
1847
1848
  /// Handle `unchecked_take_enum_data_addr` instruction.
1849
  /// Currently, only `Optional`-typed operands are supported.
1850
  ///   Original: y = unchecked_take_enum_data_addr x : $*Enum, #Enum.Case
1851
  ///    Adjoint: adj[x] += $Enum.TangentVector(adj[y])
1852
  void
1853
112
  visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *utedai) {
1854
112
    auto *bb = utedai->getParent();
1855
112
    auto adjDest = getAdjointBuffer(bb, utedai);
1856
112
    auto enumTy = utedai->getOperand()->getType();
1857
112
    auto *optionalEnumDecl = getASTContext().getOptionalDecl();
1858
    // Only `Optional`-typed operands are supported for now. Diagnose all other
1859
    // enum operand types.
1860
112
    if (enumTy.getASTType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
1861
0
      LLVM_DEBUG(getADDebugStream()
1862
0
                 << "Unhandled instruction in PullbackCloner: " << *utedai);
1863
0
      getContext().emitNondifferentiabilityError(
1864
0
          utedai, getInvoker(),
1865
0
          diag::autodiff_expression_not_differentiable_note);
1866
0
      errorOccurred = true;
1867
0
      return;
1868
0
    }
1869
112
    accumulateAdjointForOptionalBuffer(bb, utedai->getOperand(), adjDest);
1870
112
    builder.emitZeroIntoBuffer(utedai->getLoc(), adjDest, IsNotInitialization);
1871
112
  }
1872
1873
#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);
1874
#undef NOT_DIFFERENTIABLE
1875
1876
#define NO_ADJOINT(INST)                                                       \
1877
16.5k
  void visit##INST##Inst(INST##Inst *inst) {}
_ZN5swift8autodiff14PullbackCloner14Implementation19visitAllocStackInstEPNS_14AllocStackInstE
Line
Count
Source
1877
4.24k
  void visit##INST##Inst(INST##Inst *inst) {}
_ZN5swift8autodiff14PullbackCloner14Implementation18visitIndexAddrInstEPNS_13IndexAddrInstE
Line
Count
Source
1877
124
  void visit##INST##Inst(INST##Inst *inst) {}
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation25visitPointerToAddressInstEPNS_20PointerToAddressInstE
_ZN5swift8autodiff14PullbackCloner14Implementation25visitTupleElementAddrInstEPNS_20TupleElementAddrInstE
Line
Count
Source
1877
1.22k
  void visit##INST##Inst(INST##Inst *inst) {}
_ZN5swift8autodiff14PullbackCloner14Implementation26visitStructElementAddrInstEPNS_21StructElementAddrInstE
Line
Count
Source
1877
964
  void visit##INST##Inst(INST##Inst *inst) {}
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation15visitReturnInstEPNS_10ReturnInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation15visitBranchInstEPNS_10BranchInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation19visitCondBranchInstEPNS_14CondBranchInstE
_ZN5swift8autodiff14PullbackCloner14Implementation21visitDeallocStackInstEPNS_16DeallocStackInstE
Line
Count
Source
1877
4.70k
  void visit##INST##Inst(INST##Inst *inst) {}
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation21visitStrongRetainInstEPNS_16StrongRetainInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation22visitStrongReleaseInstEPNS_17StrongReleaseInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation28visitStrongRetainUnownedInstEPNS_23StrongRetainUnownedInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation22visitUnownedRetainInstEPNS_17UnownedRetainInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation23visitUnownedReleaseInstEPNS_18UnownedReleaseInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation20visitRetainValueInstEPNS_15RetainValueInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation24visitRetainValueAddrInstEPNS_19RetainValueAddrInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation21visitReleaseValueInstEPNS_16ReleaseValueInstE
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation25visitReleaseValueAddrInstEPNS_20ReleaseValueAddrInstE
_ZN5swift8autodiff14PullbackCloner14Implementation21visitDestroyValueInstEPNS_16DestroyValueInstE
Line
Count
Source
1877
492
  void visit##INST##Inst(INST##Inst *inst) {}
_ZN5swift8autodiff14PullbackCloner14Implementation18visitEndBorrowInstEPNS_13EndBorrowInstE
Line
Count
Source
1877
228
  void visit##INST##Inst(INST##Inst *inst) {}
_ZN5swift8autodiff14PullbackCloner14Implementation18visitEndAccessInstEPNS_13EndAccessInstE
Line
Count
Source
1877
3.16k
  void visit##INST##Inst(INST##Inst *inst) {}
Unexecuted instantiation: _ZN5swift8autodiff14PullbackCloner14Implementation19visitDebugValueInstEPNS_14DebugValueInstE
_ZN5swift8autodiff14PullbackCloner14Implementation20visitDestroyAddrInstEPNS_15DestroyAddrInstE
Line
Count
Source
1877
1.44k
  void visit##INST##Inst(INST##Inst *inst) {}
1878
  // Terminators.
1879
  NO_ADJOINT(Return)
1880
  NO_ADJOINT(Branch)
1881
  NO_ADJOINT(CondBranch)
1882
1883
  // Address projections.
1884
  NO_ADJOINT(StructElementAddr)
1885
  NO_ADJOINT(TupleElementAddr)
1886
1887
  // Array literal initialization address projections.
1888
  NO_ADJOINT(PointerToAddress)
1889
  NO_ADJOINT(IndexAddr)
1890
1891
  // Memory allocation/access.
1892
  NO_ADJOINT(AllocStack)
1893
  NO_ADJOINT(DeallocStack)
1894
  NO_ADJOINT(EndAccess)
1895
1896
  // Debugging/reference counting instructions.
1897
  NO_ADJOINT(DebugValue)
1898
  NO_ADJOINT(RetainValue)
1899
  NO_ADJOINT(RetainValueAddr)
1900
  NO_ADJOINT(ReleaseValue)
1901
  NO_ADJOINT(ReleaseValueAddr)
1902
  NO_ADJOINT(StrongRetain)
1903
  NO_ADJOINT(StrongRelease)
1904
  NO_ADJOINT(UnownedRetain)
1905
  NO_ADJOINT(UnownedRelease)
1906
  NO_ADJOINT(StrongRetainUnowned)
1907
  NO_ADJOINT(DestroyValue)
1908
  NO_ADJOINT(DestroyAddr)
1909
1910
  // Value ownership.
1911
  NO_ADJOINT(EndBorrow)
1912
#undef NO_ADJOINT
1913
};
1914
1915
PullbackCloner::Implementation::Implementation(VJPCloner &vjpCloner)
1916
    : vjpCloner(vjpCloner), scopeCloner(getPullback()),
1917
      builder(getPullback(), getContext()),
1918
5.23k
      localAllocBuilder(getPullback(), getContext()) {
1919
  // Get dominance and post-order info for the original function.
1920
5.23k
  auto &passManager = getContext().getPassManager();
1921
5.23k
  auto *domAnalysis = passManager.getAnalysis<DominanceAnalysis>();
1922
5.23k
  auto *postDomAnalysis = passManager.getAnalysis<PostDominanceAnalysis>();
1923
5.23k
  auto *postOrderAnalysis = passManager.getAnalysis<PostOrderAnalysis>();
1924
5.23k
  auto *original = &vjpCloner.getOriginal();
1925
5.23k
  domInfo = domAnalysis->get(original);
1926
5.23k
  postDomInfo = postDomAnalysis->get(original);
1927
5.23k
  postOrderInfo = postOrderAnalysis->get(original);
1928
  // Initialize `originalExitBlock`.
1929
5.23k
  auto origExitIt = original->findReturnBB();
1930
5.23k
  assert(origExitIt != original->end() &&
1931
5.23k
         "Functions without returns must have been diagnosed");
1932
0
  originalExitBlock = &*origExitIt;
1933
5.23k
  localAllocBuilder.setCurrentDebugScope(
1934
5.23k
       remapScope(originalExitBlock->getTerminator()->getDebugScope()));
1935
5.23k
}
1936
1937
PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)
1938
5.23k
    : impl(*new Implementation(vjpCloner)) {}
1939
1940
5.23k
PullbackCloner::~PullbackCloner() { delete &impl; }
1941
1942
//--------------------------------------------------------------------------//
1943
// Entry point
1944
//--------------------------------------------------------------------------//
1945
1946
5.23k
bool PullbackCloner::run() {
1947
5.23k
  bool foundError = impl.run();
1948
5.23k
#ifndef NDEBUG
1949
5.23k
  if (!foundError)
1950
5.10k
    impl.getPullback().verify();
1951
5.23k
#endif
1952
5.23k
  return foundError;
1953
5.23k
}
1954
1955
5.23k
bool PullbackCloner::Implementation::run() {
1956
5.23k
  PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal());
1957
5.23k
  auto &original = getOriginal();
1958
5.23k
  auto &pullback = getPullback();
1959
5.23k
  auto pbLoc = getPullback().getLocation();
1960
5.23k
  LLVM_DEBUG(getADDebugStream() << "Running PullbackCloner on\n" << original);
1961
1962
  // Collect original formal results.
1963
5.23k
  SmallVector<SILValue, 8> origFormalResults;
1964
5.23k
  collectAllFormalResultsInTypeOrder(original, origFormalResults);
1965
5.32k
  for (auto resultIndex : getConfig().resultIndices->getIndices()) {
1966
5.32k
    auto origResult = origFormalResults[resultIndex];
1967
    // If original result is non-varied, it will always have a zero derivative.
1968
    // Skip full pullback generation and simply emit zero derivatives for wrt
1969
    // parameters.
1970
    //
1971
    // NOTE(TF-876): This shortcut is currently necessary for functions
1972
    // returning non-varied result with >1 basic block where some basic blocks
1973
    // have no dominated active values; control flow differentiation does not
1974
    // handle this case. See TF-876 for context.
1975
5.32k
    if (!getActivityInfo().isVaried(origResult, getConfig().parameterIndices)) {
1976
112
      emitZeroDerivativesForNonvariedResult(origResult);
1977
112
      return false;
1978
112
    }
1979
5.32k
  }
1980
1981
  // Collect dominated active values in original basic blocks.
1982
  // Adjoint values of dominated active values are passed as pullback block
1983
  // arguments.
1984
5.12k
  DominanceOrder domOrder(original.getEntryBlock(), domInfo);
1985
  // Keep track of visited values.
1986
5.12k
  SmallPtrSet<SILValue, 8> visited;
1987
12.0k
  while (auto *bb = domOrder.getNext()) {
1988
6.96k
    auto &bbActiveValues = activeValues[bb];
1989
    // If the current block has an immediate dominator, append the immediate
1990
    // dominator block's active values to the current block's active values.
1991
6.96k
    if (auto *domNode = domInfo->getNode(bb)->getIDom()) {
1992
1.83k
      auto &domBBActiveValues = activeValues[domNode->getBlock()];
1993
1.83k
      bbActiveValues.append(domBBActiveValues.begin(), domBBActiveValues.end());
1994
1.83k
    }
1995
    // If `v` is active and has not been visited, records it as an active value
1996
    // in the original basic block.
1997
    // For active values unsupported by differentiation, emits a diagnostic and
1998
    // returns true. Otherwise, returns false.
1999
146k
    auto recordValueIfActive = [&](SILValue v) -> bool {
2000
      // If value is not active, skip.
2001
146k
      if (!getActivityInfo().isActive(v, getConfig()))
2002
64.3k
        return false;
2003
      // If active value has already been visited, skip.
2004
81.8k
      if (visited.count(v))
2005
54.4k
        return false;
2006
      // Mark active value as visited.
2007
27.4k
      visited.insert(v);
2008
2009
      // Diagnose unsupported active values.
2010
27.4k
      auto type = v->getType();
2011
      // Do not emit remaining activity-related diagnostics for semantic member
2012
      // accessors, which have special-case pullback generation.
2013
27.4k
      if (isSemanticMemberAccessor(&original))
2014
1.32k
        return false;
2015
      // Diagnose active enum values. Differentiation of enum values requires
2016
      // special adjoint value handling and is not yet supported. Diagnose
2017
      // only the first active enum value to prevent too many diagnostics.
2018
      //
2019
      // Do not diagnose `Optional`-typed values, which will have special-case
2020
      // differentiation support.
2021
26.1k
      if (auto *enumDecl = type.getEnumOrBoundGenericEnum()) {
2022
940
        if (!type.getASTType()->isOptional()) {
2023
40
          getContext().emitNondifferentiabilityError(
2024
40
              v, getInvoker(), diag::autodiff_enums_unsupported);
2025
40
          errorOccurred = true;
2026
40
          return true;
2027
40
        }
2028
940
      }
2029
      // Diagnose unsupported stored property projections.
2030
26.1k
      if (isa<StructExtractInst>(v) || isa<RefElementAddrInst>(v) ||
2031
26.1k
          isa<StructElementAddrInst>(v)) {
2032
1.63k
        auto *inst = cast<SingleValueInstruction>(v);
2033
1.63k
        assert(inst->getNumOperands() == 1);
2034
0
        auto baseType = remapType(inst->getOperand(0)->getType()).getASTType();
2035
1.63k
        if (!getTangentStoredProperty(getContext(), inst, baseType,
2036
1.63k
                                      getInvoker())) {
2037
32
          errorOccurred = true;
2038
32
          return true;
2039
32
        }
2040
1.63k
      }
2041
      // Skip address projections.
2042
      // Address projections do not need their own adjoint buffers; they
2043
      // become projections into their adjoint base buffer.
2044
26.0k
      if (Projection::isAddressProjection(v))
2045
2.55k
        return false;
2046
2047
      // Check that active values are differentiable. Otherwise we may crash
2048
      // later when tangent space is required, but not available.
2049
23.5k
      if (!getTangentSpace(remapType(type).getASTType())) {
2050
4
        getContext().emitNondifferentiabilityError(
2051
4
            v, getInvoker(), diag::autodiff_expression_not_differentiable_note);
2052
4
        errorOccurred = true;
2053
4
        return true;
2054
4
      }
2055
2056
      // Record active value.
2057
23.5k
      bbActiveValues.push_back(v);
2058
23.5k
      return false;
2059
23.5k
    };
2060
    // Record all active values in the basic block.
2061
6.96k
    for (auto *arg : bb->getArguments())
2062
10.1k
      if (recordValueIfActive(arg))
2063
32
        return true;
2064
83.4k
    for (auto &inst : *bb) {
2065
83.4k
      for (auto op : inst.getOperandValues())
2066
87.2k
        if (recordValueIfActive(op))
2067
0
          return true;
2068
83.4k
      for (auto result : inst.getResults())
2069
48.9k
        if (recordValueIfActive(result))
2070
44
          return true;
2071
83.4k
    }
2072
6.88k
    domOrder.pushChildren(bb);
2073
6.88k
  }
2074
2075
  // Create pullback blocks and arguments, visiting original blocks using BFS
2076
  // starting from the original exit block. Unvisited original basic blocks
2077
  // (e.g unreachable blocks) are not relevant for pullback generation and thus
2078
  // ignored.
2079
  // The original blocks in traversal order for pullback generation.
2080
5.04k
  SmallVector<SILBasicBlock *, 8> originalBlocks;
2081
  // The workqueue used for bookkeeping during the breadth-first traversal.
2082
5.04k
  BasicBlockWorkqueue workqueue = {originalExitBlock};
2083
2084
  // Perform BFS from the original exit block.
2085
5.04k
  {
2086
11.8k
    while (auto *BB = workqueue.pop()) {
2087
6.80k
      originalBlocks.push_back(BB);
2088
2089
6.80k
      for (auto *nextBB : BB->getPredecessorBlocks()) {
2090
2.36k
        workqueue.pushIfNotVisited(nextBB);
2091
2.36k
      }
2092
6.80k
    }
2093
5.04k
  }
2094
2095
6.80k
  for (auto *origBB : originalBlocks) {
2096
6.80k
    auto *pullbackBB = pullback.createBasicBlock();
2097
6.80k
    pullbackBBMap.insert({origBB, pullbackBB});
2098
6.80k
    auto pbTupleLoweredType =
2099
6.80k
        remapType(getPullbackInfo().getLinearMapTupleLoweredType(origBB));
2100
    // If the BB is the original exit, then the pullback block that we just
2101
    // created must be the pullback function's entry. For the pullback entry,
2102
    // create entry arguments and continue to the next block.
2103
6.80k
    if (origBB == originalExitBlock) {
2104
5.04k
      assert(pullbackBB->isEntry());
2105
0
      createEntryArguments(&pullback);
2106
5.04k
      auto *origTerm = originalExitBlock->getTerminator();
2107
5.04k
      builder.setCurrentDebugScope(remapScope(origTerm->getDebugScope()));
2108
5.04k
      builder.setInsertionPoint(pullbackBB);
2109
      // Obtain the context object, if any, and the top-level subcontext, i.e.
2110
      // the main pullback struct.
2111
5.04k
      if (getPullbackInfo().hasHeapAllocatedContext()) {
2112
        // The last argument is the context object (`Builtin.NativeObject`).
2113
100
        contextValue = pullbackBB->getArguments().back();
2114
100
        assert(contextValue->getType() ==
2115
100
               SILType::getNativeObjectType(getASTContext()));
2116
        // Load the pullback context.
2117
0
        auto subcontextAddr = emitProjectTopLevelSubcontext(
2118
100
            builder, pbLoc, contextValue, pbTupleLoweredType);
2119
100
        SILValue mainPullbackTuple = builder.createLoad(
2120
100
            pbLoc, subcontextAddr,
2121
100
            pbTupleLoweredType.isTrivial(getPullback()) ?
2122
88
                LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
2123
100
        auto *dsi = builder.createDestructureTuple(pbLoc, mainPullbackTuple);
2124
100
        initializePullbackTupleElements(origBB, dsi->getAllResults());
2125
4.94k
      } else {
2126
        // Obtain and destructure pullback struct elements.
2127
4.94k
        unsigned numVals = pbTupleLoweredType.getAs<TupleType>()->getNumElements();
2128
4.94k
        initializePullbackTupleElements(origBB,
2129
4.94k
                                        pullbackBB->getArguments().take_back(numVals));
2130
4.94k
      }
2131
2132
0
      continue;
2133
5.04k
    }
2134
2135
    // Get all active values in the original block.
2136
    // If the original block has no active values, continue.
2137
1.75k
    auto &bbActiveValues = activeValues[origBB];
2138
1.75k
    if (bbActiveValues.empty())
2139
0
      continue;
2140
2141
    // Otherwise, if the original block has active values:
2142
    // - For each active buffer in the original block, allocate a new local
2143
    //   buffer in the pullback entry. (All adjoint buffers are allocated in
2144
    //   the pullback entry and deallocated in the pullback exit.)
2145
    // - For each active value in the original block, add adjoint value
2146
    //   arguments to the pullback block.
2147
9.11k
    for (auto activeValue : bbActiveValues) {
2148
      // Handle the active value based on its value category.
2149
9.11k
      switch (getTangentValueCategory(activeValue)) {
2150
4.76k
      case SILValueCategory::Address: {
2151
        // Allocate and zero initialize a new local buffer using
2152
        // `getAdjointBuffer`.
2153
4.76k
        builder.setCurrentDebugScope(
2154
4.76k
            remapScope(originalExitBlock->getTerminator()->getDebugScope()));
2155
4.76k
        builder.setInsertionPoint(pullback.getEntryBlock());
2156
4.76k
        getAdjointBuffer(origBB, activeValue);
2157
4.76k
        break;
2158
0
      }
2159
4.35k
      case SILValueCategory::Object: {
2160
        // Create and register pullback block argument for the active value.
2161
4.35k
        auto *pullbackArg = pullbackBB->createPhiArgument(
2162
4.35k
            getRemappedTangentType(activeValue->getType()),
2163
4.35k
            OwnershipKind::Owned);
2164
4.35k
        activeValuePullbackBBArgumentMap[{origBB, activeValue}] = pullbackArg;
2165
4.35k
        recordTemporary(pullbackArg);
2166
4.35k
        break;
2167
0
      }
2168
9.11k
      }
2169
9.11k
    }
2170
    // Add a pullback tuple argument.
2171
1.75k
    auto *pbTupleArg = pullbackBB->createPhiArgument(pbTupleLoweredType,
2172
1.75k
                                                     OwnershipKind::Owned);
2173
    // Destructure the pullback struct to get the elements.
2174
1.75k
    builder.setCurrentDebugScope(
2175
1.75k
        remapScope(origBB->getTerminator()->getDebugScope()));
2176
1.75k
    builder.setInsertionPoint(pullbackBB);
2177
1.75k
    auto *dsi = builder.createDestructureTuple(pbLoc, pbTupleArg);
2178
1.75k
    initializePullbackTupleElements(origBB, dsi->getResults());
2179
2180
    // - Create pullback trampoline blocks for each successor block of the
2181
    //   original block. Pullback trampoline blocks only have a pullback
2182
    //   struct argument. They branch from a pullback successor block to the
2183
    //   pullback original block, passing adjoint values of active values.
2184
2.40k
    for (auto *succBB : origBB->getSuccessorBlocks()) {
2185
      // Skip generating pullback block for original unreachable blocks.
2186
2.40k
      if (!workqueue.isVisited(succBB))
2187
44
        continue;
2188
2.36k
      auto *pullbackTrampolineBB = pullback.createBasicBlockBefore(pullbackBB);
2189
2.36k
      pullbackTrampolineBBMap.insert({{origBB, succBB}, pullbackTrampolineBB});
2190
      // Get the enum element type (i.e. the pullback struct type). The enum
2191
      // element type may be boxed if the enum is indirect.
2192
2.36k
      auto enumLoweredTy =
2193
2.36k
          getPullbackInfo().getBranchingTraceEnumLoweredType(succBB);
2194
2.36k
      auto *enumEltDecl =
2195
2.36k
          getPullbackInfo().lookUpBranchingTraceEnumElement(origBB, succBB);
2196
2.36k
      auto enumEltType = remapType(enumLoweredTy.getEnumElementType(
2197
2.36k
          enumEltDecl, getModule(), TypeExpansionContext::minimal()));
2198
2.36k
      pullbackTrampolineBB->createPhiArgument(enumEltType,
2199
2.36k
                                              OwnershipKind::Owned);
2200
2.36k
    }
2201
1.75k
  }
2202
2203
5.04k
  auto *pullbackEntry = pullback.getEntryBlock();
2204
5.04k
  auto pbTupleLoweredType =
2205
5.04k
    remapType(getPullbackInfo().getLinearMapTupleLoweredType(originalExitBlock));
2206
5.04k
  unsigned numVals = (getPullbackInfo().hasHeapAllocatedContext() ?
2207
4.94k
                      1 : pbTupleLoweredType.getAs<TupleType>()->getNumElements());
2208
5.04k
  (void)numVals;
2209
2210
  // The pullback function has type:
2211
  // `(seed0, seed1, ..., (exit_pb_tuple_el0, ..., )|context_obj) -> (d_arg0, ..., d_argn)`.
2212
5.04k
  auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults();
2213
5.04k
  assert(getConfig().resultIndices->getNumIndices() == pbParamArgs.size() - numVals &&
2214
5.04k
         pbParamArgs.size() >= 1);
2215
  // Assign adjoints for original result.
2216
0
  builder.setCurrentDebugScope(
2217
5.04k
      remapScope(originalExitBlock->getTerminator()->getDebugScope()));
2218
5.04k
  builder.setInsertionPoint(pullbackEntry,
2219
5.04k
                            getNextFunctionLocalAllocationInsertionPoint());
2220
5.04k
  unsigned seedIndex = 0;
2221
5.14k
  for (auto resultIndex : getConfig().resultIndices->getIndices()) {
2222
5.14k
    auto origResult = origFormalResults[resultIndex];
2223
5.14k
    auto *seed = pbParamArgs[seedIndex];
2224
5.14k
    if (seed->getType().isAddress()) {
2225
      // If the seed argument is an `inout` parameter, assign it directly as
2226
      // the adjoint buffer of the original result.
2227
1.63k
      auto seedParamInfo =
2228
1.63k
          pullback.getLoweredFunctionType()->getParameters()[seedIndex];
2229
2230
1.63k
      if (seedParamInfo.isIndirectInOut()) {
2231
376
        setAdjointBuffer(originalExitBlock, origResult, seed);
2232
376
      }
2233
      // Otherwise, assign a copy of the seed argument as the adjoint buffer of
2234
      // the original result.
2235
1.26k
      else {
2236
1.26k
        auto *seedBufCopy =
2237
1.26k
            createFunctionLocalAllocation(seed->getType(), pbLoc);
2238
1.26k
        builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake,
2239
1.26k
                               IsInitialization);
2240
1.26k
        setAdjointBuffer(originalExitBlock, origResult, seedBufCopy);
2241
1.26k
        LLVM_DEBUG(getADDebugStream()
2242
1.26k
                   << "Assigned seed buffer " << *seedBufCopy
2243
1.26k
                   << " as the adjoint of original indirect result "
2244
1.26k
                   << origResult);
2245
1.26k
      }
2246
3.50k
    } else {
2247
3.50k
      addAdjointValue(originalExitBlock, origResult, makeConcreteAdjointValue(seed),
2248
3.50k
                      pbLoc);
2249
3.50k
      LLVM_DEBUG(getADDebugStream()
2250
3.50k
                 << "Assigned seed " << *seed
2251
3.50k
                 << " as the adjoint of original result " << origResult);
2252
3.50k
    }
2253
5.14k
    ++seedIndex;
2254
5.14k
  }
2255
2256
  // If the original function is an accessor with special-case pullback
2257
  // generation logic, do special-case generation.
2258
5.04k
  if (isSemanticMemberAccessor(&original)) {
2259
256
    if (runForSemanticMemberAccessor())
2260
0
      return true;
2261
256
  }
2262
  // Otherwise, perform standard pullback generation.
2263
  // Visit original blocks in post-order and perform differentiation
2264
  // in corresponding pullback blocks. If errors occurred, back out.
2265
4.79k
  else {
2266
6.54k
    for (auto *bb : originalBlocks) {
2267
6.54k
      visitSILBasicBlock(bb);
2268
6.54k
      if (errorOccurred)
2269
56
        return true;
2270
6.54k
    }
2271
4.79k
  }
2272
2273
  // Prepare and emit a `return` in the pullback exit block.
2274
4.99k
  auto *origEntry = getOriginal().getEntryBlock();
2275
4.99k
  auto *pbExit = getPullbackBlock(origEntry);
2276
4.99k
  builder.setCurrentDebugScope(pbExit->back().getDebugScope());
2277
4.99k
  builder.setInsertionPoint(pbExit);
2278
2279
  // This vector will contain all the materialized return elements.
2280
4.99k
  SmallVector<SILValue, 8> retElts;
2281
  // This vector will contain all indirect parameter adjoint buffers.
2282
4.99k
  SmallVector<SILValue, 4> indParamAdjoints;
2283
  // This vector will identify the locations where initialization is needed.
2284
4.99k
  SmallBitVector outputsToInitialize;
2285
2286
4.99k
  auto conv = getOriginal().getConventions();
2287
4.99k
  auto origParams = getOriginal().getArgumentsWithoutIndirectResults();
2288
2289
  // Materializes the return element corresponding to the parameter
2290
  // `parameterIndex` into the `retElts` vector.
2291
6.64k
  auto addRetElt = [&](unsigned parameterIndex) -> void {
2292
6.64k
    auto origParam = origParams[parameterIndex];
2293
6.64k
    switch (getTangentValueCategory(origParam)) {
2294
4.84k
    case SILValueCategory::Object: {
2295
4.84k
      auto pbVal = getAdjointValue(origEntry, origParam);
2296
4.84k
      auto val = materializeAdjointDirect(pbVal, pbLoc);
2297
4.84k
      auto newVal = builder.emitCopyValueOperation(pbLoc, val);
2298
4.84k
      retElts.push_back(newVal);
2299
4.84k
      break;
2300
0
    }
2301
1.80k
    case SILValueCategory::Address: {
2302
1.80k
      auto adjBuf = getAdjointBuffer(origEntry, origParam);
2303
1.80k
      indParamAdjoints.push_back(adjBuf);
2304
1.80k
      outputsToInitialize.push_back(
2305
1.80k
        !conv.getParameters()[parameterIndex].isIndirectMutating());
2306
1.80k
      break;
2307
0
    }
2308
6.64k
    }
2309
6.64k
  };
2310
4.99k
  SmallVector<SILArgument *, 4> pullbackIndirectResults(
2311
4.99k
        getPullback().getIndirectResults().begin(),
2312
4.99k
        getPullback().getIndirectResults().end());
2313
2314
  // Collect differentiation parameter adjoints.
2315
  // Do a first pass to collect non-inout values.
2316
7.00k
  for (auto i : getConfig().parameterIndices->getIndices()) {
2317
7.00k
    if (!conv.getParameters()[i].isAutoDiffSemanticResult()) {
2318
6.62k
       addRetElt(i);
2319
6.62k
     }
2320
7.00k
  }
2321
2322
  // Do a second pass for all inout parameters, however this is only necessary
2323
  // for functions with multiple basic blocks.  For functions with a single
2324
  // basic block adjoint accumulation for those parameters is already done by
2325
  // per-instruction visitors.
2326
4.99k
  if (getOriginal().size() > 1) {
2327
448
    const auto &pullbackConv = pullback.getConventions();
2328
448
    SmallVector<SILArgument *, 1> pullbackInOutArgs;
2329
912
    for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) {
2330
912
      if (pullbackConv.getParameters()[pullbackArg.index()].isAutoDiffSemanticResult())
2331
20
        pullbackInOutArgs.push_back(pullbackArg.value());
2332
912
    }
2333
2334
448
    unsigned pullbackInoutArgumentIdx = 0;
2335
532
    for (auto i : getConfig().parameterIndices->getIndices()) {
2336
      // Skip non-inout parameters.
2337
532
      if (!conv.getParameters()[i].isAutoDiffSemanticResult())
2338
512
        continue;
2339
2340
      // For functions with multiple basic blocks, accumulation is needed
2341
      // for `inout` parameters because pullback basic blocks have different
2342
      // adjoint buffers.
2343
20
      pullbackIndirectResults.push_back(pullbackInOutArgs[pullbackInoutArgumentIdx++]);
2344
20
      addRetElt(i);
2345
20
    }
2346
448
  }
2347
2348
  // Copy them to adjoint indirect results.
2349
4.99k
  assert(indParamAdjoints.size() == pullbackIndirectResults.size() &&
2350
4.99k
         "Indirect parameter adjoint count mismatch");
2351
0
  unsigned currentIndex = 0;
2352
4.99k
  for (auto pair : zip(indParamAdjoints, pullbackIndirectResults)) {
2353
1.80k
    auto source = std::get<0>(pair);
2354
1.80k
    auto *dest = std::get<1>(pair);
2355
1.80k
    if (outputsToInitialize[currentIndex]) {
2356
1.78k
      builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
2357
1.78k
    } else {
2358
20
      builder.createCopyAddr(pbLoc, source, dest, IsTake, IsNotInitialization);
2359
20
    }
2360
1.80k
    currentIndex++;
2361
    // Prevent source buffer from being deallocated, since the underlying
2362
    // value is moved.
2363
1.80k
    destroyedLocalAllocations.insert(source);
2364
1.80k
  }
2365
2366
  // Emit cleanups for all local values.
2367
4.99k
  cleanUpTemporariesForBlock(pbExit, pbLoc);
2368
  // Deallocate local allocations.
2369
11.9k
  for (auto alloc : functionLocalAllocations) {
2370
    // Assert that local allocations have at least one use.
2371
    // Buffers should not be allocated needlessly.
2372
11.9k
    assert(!alloc->use_empty());
2373
11.9k
    if (!destroyedLocalAllocations.count(alloc)) {
2374
10.0k
      builder.emitDestroyAddrAndFold(pbLoc, alloc);
2375
10.0k
      destroyedLocalAllocations.insert(alloc);
2376
10.0k
    }
2377
11.9k
    builder.createDeallocStack(pbLoc, alloc);
2378
11.9k
  }
2379
4.99k
  builder.createReturn(pbLoc, joinElements(retElts, builder, pbLoc));
2380
2381
4.99k
#ifndef NDEBUG
2382
4.99k
  bool leakFound = false;
2383
  // Ensure all temporaries have been cleaned up.
2384
9.02k
  for (auto &bb : pullback) {
2385
9.02k
    for (auto temp : blockTemporaries[&bb]) {
2386
0
      if (blockTemporaries[&bb].count(temp)) {
2387
0
        leakFound = true;
2388
0
        getADDebugStream() << "Found leaked temporary:\n" << temp;
2389
0
      }
2390
0
    }
2391
9.02k
  }
2392
  // Ensure all local allocations have been cleaned up.
2393
11.9k
  for (auto localAlloc : functionLocalAllocations) {
2394
11.9k
    if (!destroyedLocalAllocations.count(localAlloc)) {
2395
0
      leakFound = true;
2396
0
      getADDebugStream() << "Found leaked local buffer:\n" << localAlloc;
2397
0
    }
2398
11.9k
  }
2399
4.99k
  assert(!leakFound && "Leaks found!");
2400
0
#endif
2401
2402
4.99k
  LLVM_DEBUG(getADDebugStream()
2403
4.99k
             << "Generated pullback for " << original.getName() << ":\n"
2404
4.99k
             << pullback);
2405
4.99k
  return errorOccurred;
2406
5.04k
}
2407
2408
void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult(
2409
112
    SILValue origNonvariedResult) {
2410
112
  auto &pullback = getPullback();
2411
112
  auto pbLoc = getPullback().getLocation();
2412
  /*
2413
  // TODO(TF-788): Re-enable non-varied result warning.
2414
  // Emit fixit if original non-varied result has a valid source location.
2415
  auto startLoc = origNonvariedResult.getLoc().getStartSourceLoc();
2416
  auto endLoc = origNonvariedResult.getLoc().getEndSourceLoc();
2417
  if (startLoc.isValid() && endLoc.isValid()) {
2418
    getContext().diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
2419
        .fixItInsert(startLoc, "withoutDerivative(at:")
2420
        .fixItInsertAfter(endLoc, ")");
2421
  }
2422
  */
2423
112
  LLVM_DEBUG(getADDebugStream() << getOriginal().getName()
2424
112
                                << " has non-varied result, returning zero"
2425
112
                                   " for all pullback results\n");
2426
112
  auto *pullbackEntry = pullback.createBasicBlock();
2427
112
  createEntryArguments(&pullback);
2428
112
  builder.setCurrentDebugScope(
2429
112
      remapScope(originalExitBlock->getTerminator()->getDebugScope()));
2430
112
  builder.setInsertionPoint(pullbackEntry);
2431
  // Destroy all owned arguments.
2432
112
  for (auto *arg : pullbackEntry->getArguments())
2433
172
    if (arg->getOwnershipKind() == OwnershipKind::Owned)
2434
0
      builder.emitDestroyOperation(pbLoc, arg);
2435
  // Return zero for each result.
2436
112
  SmallVector<SILValue, 4> directResults;
2437
112
  auto indirectResultIt = pullback.getIndirectResults().begin();
2438
132
  for (auto resultInfo : pullback.getLoweredFunctionType()->getResults()) {
2439
132
    auto resultType =
2440
132
        pullback.mapTypeIntoContext(resultInfo.getInterfaceType())
2441
132
            ->getCanonicalType();
2442
132
    if (resultInfo.isFormalDirect())
2443
88
      directResults.push_back(builder.emitZero(pbLoc, resultType));
2444
44
    else
2445
44
      builder.emitZeroIntoBuffer(pbLoc, *indirectResultIt++, IsInitialization);
2446
132
  }
2447
112
  builder.createReturn(pbLoc, joinElements(directResults, builder, pbLoc));
2448
112
  LLVM_DEBUG(getADDebugStream()
2449
112
             << "Generated pullback for " << getOriginal().getName() << ":\n"
2450
112
             << pullback);
2451
112
}
2452
2453
AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint(
2454
264
    SILBasicBlock *bb, SILValue wrappedAdjoint, SILType optionalTy) {
2455
264
  auto pbLoc = getPullback().getLocation();
2456
  // `Optional<T>`
2457
264
  optionalTy = remapType(optionalTy);
2458
264
  assert(optionalTy.getASTType()->isOptional());
2459
  // `T`
2460
0
  auto wrappedType = optionalTy.getOptionalObjectType();
2461
  // `T.TangentVector`
2462
264
  auto wrappedTanType = remapType(wrappedAdjoint->getType());
2463
  // `Optional<T.TangentVector>`
2464
264
  auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType);
2465
  // `Optional<T>.TangentVector`
2466
264
  auto optionalTanTy = getRemappedTangentType(optionalTy);
2467
264
  auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal();
2468
  // Look up the `Optional<T>.TangentVector.init` declaration.
2469
264
  auto initLookup =
2470
264
      optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
2471
264
  ConstructorDecl *constructorDecl = nullptr;
2472
264
  for (auto *candidate : initLookup) {
2473
264
    auto candidateModule = candidate->getModuleContext();
2474
264
    if (candidateModule->getName() ==
2475
264
            builder.getASTContext().Id_Differentiation ||
2476
264
        candidateModule->isStdlibModule()) {
2477
264
      assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
2478
0
      constructorDecl = cast<ConstructorDecl>(candidate);
2479
#ifdef NDEBUG
2480
      break;
2481
#endif
2482
264
    }
2483
264
  }
2484
264
  assert(constructorDecl && "No `Optional.TangentVector.init`");
2485
2486
  // Allocate a local buffer for the `Optional` adjoint value.
2487
0
  auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy);
2488
  // Find `Optional<T.TangentVector>.some` EnumElementDecl.
2489
264
  auto someEltDecl = builder.getASTContext().getOptionalSomeDecl();
2490
2491
  // Initialize an `Optional<T.TangentVector>` buffer from `wrappedAdjoint` as
2492
  // the input for `Optional<T>.TangentVector.init`.
2493
264
  auto *optArgBuf = builder.createAllocStack(pbLoc, optionalOfWrappedTanType);
2494
264
  if (optionalOfWrappedTanType.isLoadableOrOpaque(builder.getFunction())) {
2495
    // %enum = enum $Optional<T.TangentVector>, #Optional.some!enumelt,
2496
    //         %wrappedAdjoint : $T
2497
152
    auto *enumInst = builder.createEnum(pbLoc, wrappedAdjoint, someEltDecl,
2498
152
                                        optionalOfWrappedTanType);
2499
    // store %enum to %optArgBuf
2500
152
    builder.emitStoreValueOperation(pbLoc, enumInst, optArgBuf,
2501
152
                                    StoreOwnershipQualifier::Init);
2502
152
  } else {
2503
    // %enumAddr = init_enum_data_addr %optArgBuf $Optional<T.TangentVector>,
2504
    //                                 #Optional.some!enumelt
2505
112
    auto *enumAddr = builder.createInitEnumDataAddr(
2506
112
        pbLoc, optArgBuf, someEltDecl, wrappedTanType.getAddressType());
2507
    // copy_addr %wrappedAdjoint to [init] %enumAddr
2508
112
    builder.createCopyAddr(pbLoc, wrappedAdjoint, enumAddr, IsNotTake,
2509
112
                           IsInitialization);
2510
    // inject_enum_addr %optArgBuf : $*Optional<T.TangentVector>,
2511
    //                  #Optional.some!enumelt
2512
112
    builder.createInjectEnumAddr(pbLoc, optArgBuf, someEltDecl);
2513
112
  }
2514
2515
  // Apply `Optional<T>.TangentVector.init`.
2516
264
  SILOptFunctionBuilder fb(getContext().getTransform());
2517
  // %init_fn = function_ref @Optional<T>.TangentVector.init
2518
264
  auto *initFn = fb.getOrCreateFunction(pbLoc, SILDeclRef(constructorDecl),
2519
264
                                        NotForDefinition);
2520
264
  auto *initFnRef = builder.createFunctionRef(pbLoc, initFn);
2521
264
  auto *diffProto =
2522
264
      builder.getASTContext().getProtocol(KnownProtocolKind::Differentiable);
2523
264
  auto *swiftModule = getModule().getSwiftModule();
2524
264
  auto diffConf =
2525
264
      swiftModule->lookupConformance(wrappedType.getASTType(), diffProto);
2526
264
  assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
2527
0
  auto subMap = SubstitutionMap::get(
2528
264
      initFn->getLoweredFunctionType()->getSubstGenericSignature(),
2529
264
      ArrayRef<Type>(wrappedType.getASTType()), {diffConf});
2530
  // %metatype = metatype $Optional<T>.TangentVector.Type
2531
264
  auto metatypeType = CanMetatypeType::get(optionalTanTy.getASTType(),
2532
264
                                           MetatypeRepresentation::Thin);
2533
264
  auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType);
2534
264
  auto metatype = builder.createMetatype(pbLoc, metatypeSILType);
2535
  // apply %init_fn(%optTanAdjBuf, %optArgBuf, %metatype)
2536
264
  builder.createApply(pbLoc, initFnRef, subMap,
2537
264
                      {optTanAdjBuf, optArgBuf, metatype});
2538
264
  builder.createDeallocStack(pbLoc, optArgBuf);
2539
264
  return optTanAdjBuf;
2540
264
}
2541
2542
// Accumulate adjoint for the incoming `Optional` buffer.
2543
void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer(
2544
112
    SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint) {
2545
112
  assert(getTangentValueCategory(optionalBuffer) == SILValueCategory::Address);
2546
0
  auto pbLoc = getPullback().getLocation();
2547
2548
  // Allocate and initialize Optional<Wrapped>.TangentVector from
2549
  // Wrapped.TangentVector
2550
112
  AllocStackInst *optTanAdjBuf =
2551
112
      createOptionalAdjoint(bb, wrappedAdjoint, optionalBuffer->getType());
2552
2553
  // Accumulate into optionalBuffer
2554
112
  addToAdjointBuffer(bb, optionalBuffer, optTanAdjBuf, pbLoc);
2555
112
  builder.emitDestroyAddr(pbLoc, optTanAdjBuf);
2556
112
  builder.createDeallocStack(pbLoc, optTanAdjBuf);
2557
112
}
2558
2559
// Set the adjoint value for the incoming `Optional` value.
2560
void PullbackCloner::Implementation::setAdjointValueForOptional(
2561
152
    SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
2562
152
  assert(getTangentValueCategory(optionalValue) == SILValueCategory::Object);
2563
0
  auto pbLoc = getPullback().getLocation();
2564
2565
  // Allocate and initialize Optional<Wrapped>.TangentVector from
2566
  // Wrapped.TangentVector
2567
152
  AllocStackInst *optTanAdjBuf =
2568
152
      createOptionalAdjoint(bb, wrappedAdjoint, optionalValue->getType());
2569
2570
152
  auto optTanAdjVal = builder.emitLoadValueOperation(
2571
152
      pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
2572
152
  recordTemporary(optTanAdjVal);
2573
152
  builder.createDeallocStack(pbLoc, optTanAdjBuf);
2574
2575
152
  setAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal));
2576
152
}
2577
2578
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
2579
    SILBasicBlock *origBB, SILBasicBlock *origPredBB,
2580
2.35k
    SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) {
2581
  // Get the pullback block and optional pullback trampoline block of the
2582
  // predecessor block.
2583
2.35k
  auto *pullbackBB = getPullbackBlock(origPredBB);
2584
2.35k
  auto *pullbackTrampolineBB = getPullbackTrampolineBlock(origPredBB, origBB);
2585
  // If the predecessor block does not have a corresponding pullback
2586
  // trampoline block, then the pullback successor is the pullback block.
2587
2.35k
  if (!pullbackTrampolineBB)
2588
0
    return pullbackBB;
2589
2590
  // Otherwise, the pullback successor is the pullback trampoline block,
2591
  // which branches to the pullback block and propagates adjoint values of
2592
  // active values.
2593
2.35k
  assert(pullbackTrampolineBB->getNumArguments() == 1);
2594
0
  auto loc = origBB->getParent()->getLocation();
2595
2.35k
  SmallVector<SILValue, 8> trampolineArguments;
2596
2597
  // Propagate adjoint values/buffers of active values/buffers to
2598
  // predecessor blocks.
2599
2.35k
  auto &predBBActiveValues = activeValues[origPredBB];
2600
2.35k
  llvm::SmallSet<std::pair<SILValue, SILValue>, 32> propagatedAdjoints;
2601
11.6k
  for (auto activeValue : predBBActiveValues) {
2602
11.6k
    LLVM_DEBUG(getADDebugStream()
2603
11.6k
               << "Propagating adjoint of active value " << activeValue
2604
11.6k
               << "from bb" << origBB->getDebugID()
2605
11.6k
               << " to predecessors' (bb" << origPredBB->getDebugID()
2606
11.6k
               << ") pullback blocks\n");
2607
11.6k
    switch (getTangentValueCategory(activeValue)) {
2608
5.48k
    case SILValueCategory::Object: {
2609
5.48k
      auto activeValueAdj = getAdjointValue(origBB, activeValue);
2610
5.48k
      auto concreteActiveValueAdj =
2611
5.48k
          materializeAdjointDirect(activeValueAdj, loc);
2612
2613
5.48k
      if (!pullbackTrampolineBlockMap.count(concreteActiveValueAdj)) {
2614
4.35k
        concreteActiveValueAdj =
2615
4.35k
            builder.emitCopyValueOperation(loc, concreteActiveValueAdj);
2616
4.35k
        setAdjointValue(origBB, activeValue,
2617
4.35k
                        makeConcreteAdjointValue(concreteActiveValueAdj));
2618
4.35k
      }
2619
5.48k
      auto insertion = pullbackTrampolineBlockMap.try_emplace(
2620
5.48k
          concreteActiveValueAdj, TrampolineBlockSet());
2621
5.48k
      auto &blockSet = insertion.first->getSecond();
2622
5.48k
      blockSet.insert(pullbackTrampolineBB);
2623
5.48k
      trampolineArguments.push_back(concreteActiveValueAdj);
2624
2625
      // If the pullback block does not yet have a registered adjoint
2626
      // value for the active value, set the adjoint value to the
2627
      // forwarded adjoint value argument.
2628
      // TODO: Hoist this logic out of loop over predecessor blocks to
2629
      // remove the `hasAdjointValue` check.
2630
5.48k
      if (!hasAdjointValue(origPredBB, activeValue)) {
2631
3.81k
        auto *pullbackBBArg =
2632
3.81k
            getActiveValuePullbackBlockArgument(origPredBB, activeValue);
2633
3.81k
        auto forwardedArgAdj = makeConcreteAdjointValue(pullbackBBArg);
2634
3.81k
        setAdjointValue(origPredBB, activeValue, forwardedArgAdj);
2635
3.81k
      }
2636
5.48k
      break;
2637
0
    }
2638
6.17k
    case SILValueCategory::Address: {
2639
      // Propagate adjoint buffers using `copy_addr`.
2640
6.17k
      auto adjBuf = getAdjointBuffer(origBB, activeValue);
2641
6.17k
      auto predAdjBuf = getAdjointBuffer(origPredBB, activeValue);
2642
6.17k
      if (propagatedAdjoints.insert({adjBuf, predAdjBuf}).second)
2643
5.12k
        builder.createCopyAddr(loc, adjBuf, predAdjBuf, IsNotTake,
2644
5.12k
                               IsNotInitialization);
2645
6.17k
      break;
2646
0
    }
2647
11.6k
    }
2648
11.6k
  }
2649
2650
  // Propagate pullback struct argument.
2651
2.35k
  TangentBuilder pullbackTrampolineBBBuilder(
2652
2.35k
      pullbackTrampolineBB, getContext());
2653
2.35k
  pullbackTrampolineBBBuilder.setCurrentDebugScope(
2654
2.35k
      remapScope(origPredBB->getTerminator()->getDebugScope()));
2655
2656
2.35k
  auto *pullbackTrampolineBBArg = pullbackTrampolineBB->getArguments().front();
2657
2.35k
  if (vjpCloner.getLoopInfo()->getLoopFor(origPredBB)) {
2658
376
    assert(pullbackTrampolineBBArg->getType() ==
2659
376
               SILType::getRawPointerType(getASTContext()));
2660
0
    auto pbTupleType =
2661
376
      remapType(getPullbackInfo().getLinearMapTupleLoweredType(origPredBB));
2662
376
    auto predPbTupleAddr = pullbackTrampolineBBBuilder.createPointerToAddress(
2663
376
        loc, pullbackTrampolineBBArg, pbTupleType.getAddressType(),
2664
376
        /*isStrict*/ true);
2665
376
    auto predPbStructVal = pullbackTrampolineBBBuilder.createLoad(
2666
376
        loc, predPbTupleAddr,
2667
376
        pbTupleType.isTrivial(getPullback()) ?
2668
284
            LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
2669
376
    trampolineArguments.push_back(predPbStructVal);
2670
1.98k
  } else {
2671
1.98k
    trampolineArguments.push_back(pullbackTrampolineBBArg);
2672
1.98k
  }
2673
  // Branch from pullback trampoline block to pullback block.
2674
0
  pullbackTrampolineBBBuilder.createBranch(loc, pullbackBB,
2675
2.35k
                                           trampolineArguments);
2676
2.35k
  return pullbackTrampolineBB;
2677
2.35k
}
2678
2679
6.54k
void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
2680
6.54k
  auto pbLoc = getPullback().getLocation();
2681
  // Get the corresponding pullback basic block.
2682
6.54k
  auto *pbBB = getPullbackBlock(bb);
2683
6.54k
  builder.setInsertionPoint(pbBB);
2684
2685
6.54k
  LLVM_DEBUG({
2686
6.54k
    auto &s = getADDebugStream()
2687
6.54k
              << "Original bb" + std::to_string(bb->getDebugID())
2688
6.54k
              << ": To differentiate or not to differentiate?\n";
2689
6.54k
    for (auto &inst : llvm::reverse(*bb)) {
2690
6.54k
      s << (getPullbackInfo().shouldDifferentiateInstruction(&inst) ? "[x] "
2691
6.54k
                                                                    : "[ ] ")
2692
6.54k
        << inst;
2693
6.54k
    }
2694
6.54k
  });
2695
2696
  // Visit each instruction in reverse order.
2697
79.2k
  for (auto &inst : llvm::reverse(*bb)) {
2698
79.2k
    if (!getPullbackInfo().shouldDifferentiateInstruction(&inst))
2699
44.5k
      continue;
2700
    // Differentiate instruction.
2701
34.6k
    builder.setCurrentDebugScope(remapScope(inst.getDebugScope()));
2702
34.6k
    visit(&inst);
2703
34.6k
    if (errorOccurred)
2704
56
      return;
2705
34.6k
  }
2706
2707
  // Emit a branching terminator for the block.
2708
  // If the original block is the original entry, then the pullback block is
2709
  // the pullback exit. This is handled specially in
2710
  // `PullbackCloner::Implementation::run()`, so we leave the block
2711
  // non-terminated.
2712
6.48k
  if (bb->isEntry())
2713
4.73k
    return;
2714
2715
  // Otherwise, add a `switch_enum` terminator for non-exit
2716
  // pullback blocks.
2717
  // 1. Get the pullback struct pullback block argument.
2718
  // 2. Extract the predecessor enum value from the pullback struct value.
2719
1.75k
  auto *predEnum = getPullbackInfo().getBranchingTraceDecl(bb);
2720
1.75k
  (void)predEnum;
2721
1.75k
  auto predEnumVal = getPullbackPredTupleElement(bb);
2722
2723
  // Propagate adjoint values from active basic block arguments to
2724
  // incoming values (predecessor terminator operands).
2725
1.75k
  for (auto *bbArg : bb->getArguments()) {
2726
540
    if (!getActivityInfo().isActive(bbArg, getConfig()))
2727
180
      continue;
2728
360
    LLVM_DEBUG(getADDebugStream() << "Propagating adjoint value for active bb"
2729
360
               << bb->getDebugID() << " argument: "
2730
360
               << *bbArg);
2731
2732
    // Get predecessor terminator operands.
2733
360
    SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues;
2734
360
    if (bbArg->getSingleTerminatorOperands(incomingValues)) {
2735
      // Returns true if the given terminator instruction is a `switch_enum` on
2736
      // an `Optional`-typed value. `switch_enum` instructions require
2737
      // special-case adjoint value propagation for the operand.
2738
360
      auto isSwitchEnumInstOnOptional =
2739
620
        [&ctx = getASTContext()](TermInst *termInst) {
2740
620
          if (!termInst)
2741
468
            return false;
2742
152
          if (auto *sei = dyn_cast<SwitchEnumInst>(termInst)) {
2743
152
            auto operandTy = sei->getOperand()->getType();
2744
152
            return operandTy.getASTType()->isOptional();
2745
152
          }
2746
0
          return false;
2747
152
        };
2748
2749
      // Check the tangent value category of the active basic block argument.
2750
360
      switch (getTangentValueCategory(bbArg)) {
2751
        // If argument has a loadable tangent value category: materialize adjoint
2752
        // value of the argument, create a copy, and set the copy as the adjoint
2753
        // value of incoming values.
2754
360
      case SILValueCategory::Object: {
2755
360
        auto bbArgAdj = getAdjointValue(bb, bbArg);
2756
360
        auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc);
2757
360
        auto concreteBBArgAdjCopy =
2758
360
          builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj);
2759
620
        for (auto pair : incomingValues) {
2760
620
          auto *predBB = std::get<0>(pair);
2761
620
          auto incomingValue = std::get<1>(pair);
2762
          // Handle `switch_enum` on `Optional`.
2763
620
          auto termInst = bbArg->getSingleTerminator();
2764
620
          if (isSwitchEnumInstOnOptional(termInst)) {
2765
152
            setAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy);
2766
468
          } else {
2767
468
            blockTemporaries[getPullbackBlock(predBB)].insert(
2768
468
              concreteBBArgAdjCopy);
2769
468
            setAdjointValue(predBB, incomingValue,
2770
468
                            makeConcreteAdjointValue(concreteBBArgAdjCopy));
2771
468
          }
2772
620
        }
2773
360
        break;
2774
0
      }
2775
      // If argument has an address tangent value category: materialize adjoint
2776
      // value of the argument, create a copy, and set the copy as the adjoint
2777
      // value of incoming values.
2778
0
      case SILValueCategory::Address: {
2779
0
        auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
2780
0
        for (auto pair : incomingValues) {
2781
0
          auto incomingValue = std::get<1>(pair);
2782
          // Handle `switch_enum` on `Optional`.
2783
0
          auto termInst = bbArg->getSingleTerminator();
2784
0
          if (isSwitchEnumInstOnOptional(termInst))
2785
0
            accumulateAdjointForOptionalBuffer(bb, incomingValue, bbArgAdjBuf);
2786
0
          else
2787
0
            addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
2788
0
        }
2789
0
        break;
2790
0
      }
2791
360
      }
2792
360
    } else
2793
0
      llvm::report_fatal_error("do not know how to handle this incoming bb argument");
2794
360
  }
2795
2796
  // 3. Build the pullback successor cases for the `switch_enum`
2797
  //    instruction. The pullback successors correspond to the predecessors
2798
  //    of the current original block.
2799
1.75k
  SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4>
2800
1.75k
      pullbackSuccessorCases;
2801
  // A map from active values' adjoint values to the trampoline blocks that
2802
  // are using them.
2803
1.75k
  SmallDenseMap<SILValue, TrampolineBlockSet> pullbackTrampolineBlockMap;
2804
1.75k
  SmallDenseMap<SILBasicBlock *, SILBasicBlock *> origPredpullbackSuccBBMap;
2805
2.35k
  for (auto *predBB : bb->getPredecessorBlocks()) {
2806
2.35k
    auto *pullbackSuccBB =
2807
2.35k
        buildPullbackSuccessor(bb, predBB, pullbackTrampolineBlockMap);
2808
2.35k
    origPredpullbackSuccBBMap[predBB] = pullbackSuccBB;
2809
2.35k
    auto *enumEltDecl =
2810
2.35k
        getPullbackInfo().lookUpBranchingTraceEnumElement(predBB, bb);
2811
2.35k
    pullbackSuccessorCases.push_back({enumEltDecl, pullbackSuccBB});
2812
2.35k
  }
2813
  // Values are trampolined by only a subset of pullback successor blocks.
2814
  // Other successors blocks should destroy the value.
2815
4.35k
  for (auto pair : pullbackTrampolineBlockMap) {
2816
4.35k
    auto value = pair.getFirst();
2817
    // The set of trampoline BBs that are users of `value`.
2818
4.35k
    auto &userTrampolineBBSet = pair.getSecond();
2819
    // For each pullback successor block that does not trampoline the value,
2820
    // release the value.
2821
6.75k
    for (auto origPredPbSuccPair : origPredpullbackSuccBBMap) {
2822
6.75k
      auto *origPred = origPredPbSuccPair.getFirst();
2823
6.75k
      auto *pbSucc = origPredPbSuccPair.getSecond();
2824
6.75k
      if (userTrampolineBBSet.count(pbSucc))
2825
5.48k
        continue;
2826
1.26k
      TangentBuilder pullbackSuccBuilder(pbSucc->begin(), getContext());
2827
1.26k
      pullbackSuccBuilder.setCurrentDebugScope(
2828
1.26k
          remapScope(origPred->getTerminator()->getDebugScope()));
2829
1.26k
      pullbackSuccBuilder.emitDestroyValueOperation(pbLoc, value);
2830
1.26k
    }
2831
4.35k
  }
2832
  // Emit cleanups for all block-local temporaries.
2833
1.75k
  cleanUpTemporariesForBlock(pbBB, pbLoc);
2834
  // Branch to pullback successor blocks.
2835
1.75k
  assert(pullbackSuccessorCases.size() == predEnum->getNumElements());
2836
0
  builder.createSwitchEnum(pbLoc, predEnumVal, /*DefaultBB*/ nullptr,
2837
1.75k
                           pullbackSuccessorCases, llvm::None, ProfileCounter(),
2838
1.75k
                           OwnershipKind::Owned);
2839
1.75k
}
2840
2841
//--------------------------------------------------------------------------//
2842
// Member accessor pullback generation
2843
//--------------------------------------------------------------------------//
2844
2845
256
bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
2846
256
  auto &original = getOriginal();
2847
256
  auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
2848
256
  switch (accessor->getAccessorKind()) {
2849
192
  case AccessorKind::Get:
2850
192
    return runForSemanticMemberGetter();
2851
64
  case AccessorKind::Set:
2852
64
    return runForSemanticMemberSetter();
2853
  // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
2854
0
  default:
2855
0
    llvm_unreachable("Unsupported accessor kind; inconsistent with "
2856
256
                     "`isSemanticMemberAccessor`?");
2857
256
  }
2858
256
}
2859
2860
192
bool PullbackCloner::Implementation::runForSemanticMemberGetter() {
2861
192
  auto &original = getOriginal();
2862
192
  auto &pullback = getPullback();
2863
192
  auto pbLoc = getPullback().getLocation();
2864
2865
192
  auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
2866
192
  assert(accessor->getAccessorKind() == AccessorKind::Get);
2867
2868
0
  auto *origEntry = original.getEntryBlock();
2869
192
  auto *pbEntry = pullback.getEntryBlock();
2870
192
  builder.setCurrentDebugScope(
2871
192
      remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
2872
192
  builder.setInsertionPoint(pbEntry);
2873
2874
  // Get getter argument and result values.
2875
  //   Getter type: $(Self) -> Result
2876
  // Pullback type: $(Result') -> Self'
2877
192
  assert(original.getLoweredFunctionType()->getNumParameters() == 1);
2878
0
  assert(pullback.getLoweredFunctionType()->getNumParameters() == 1);
2879
0
  assert(pullback.getLoweredFunctionType()->getNumResults() == 1);
2880
0
  SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();
2881
2882
192
  SmallVector<SILValue, 8> origFormalResults;
2883
192
  collectAllFormalResultsInTypeOrder(original, origFormalResults);
2884
192
  assert(getConfig().resultIndices->getNumIndices() == 1 &&
2885
192
         "Getter should have one semantic result");
2886
0
  auto origResult = origFormalResults[*getConfig().resultIndices->begin()];
2887
2888
192
  auto tangentVectorSILTy = pullback.getConventions().getResults().front()
2889
192
      .getSILStorageType(getModule(),
2890
192
                         pullback.getLoweredFunctionType(),
2891
192
                         TypeExpansionContext::minimal());
2892
192
  auto tangentVectorTy = tangentVectorSILTy.getASTType();
2893
192
  auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
2894
2895
  // Look up the corresponding field in the tangent space.
2896
192
  auto *origField = cast<VarDecl>(accessor->getStorage());
2897
192
  auto baseType = remapType(origSelf->getType()).getASTType();
2898
192
  auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
2899
192
                                            pbLoc, getInvoker());
2900
192
  if (!tanField) {
2901
0
    errorOccurred = true;
2902
0
    return true;
2903
0
  }
2904
2905
  // Switch based on the base tangent struct's value category.
2906
192
  switch (getTangentValueCategory(origSelf)) {
2907
88
  case SILValueCategory::Object: {
2908
88
    auto adjResult = getAdjointValue(origEntry, origResult);
2909
88
    switch (adjResult.getKind()) {
2910
0
    case AdjointValueKind::Zero:
2911
0
      addAdjointValue(origEntry, origSelf,
2912
0
                      makeZeroAdjointValue(tangentVectorSILTy), pbLoc);
2913
0
      break;
2914
88
    case AdjointValueKind::Concrete:
2915
88
    case AdjointValueKind::Aggregate: {
2916
88
      SmallVector<AdjointValue, 8> eltVals;
2917
152
      for (auto *field : tangentVectorDecl->getStoredProperties()) {
2918
152
        if (field == tanField) {
2919
88
          eltVals.push_back(adjResult);
2920
88
        } else {
2921
64
          auto substMap = tangentVectorTy->getMemberSubstitutionMap(
2922
64
              field->getModuleContext(), field);
2923
64
          auto fieldTy = field->getInterfaceType().subst(substMap);
2924
64
          auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
2925
64
          assert(fieldSILTy.isObject());
2926
0
          eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
2927
64
        }
2928
152
      }
2929
88
      addAdjointValue(origEntry, origSelf,
2930
88
                      makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
2931
88
                      pbLoc);
2932
2933
88
      break;
2934
88
    }
2935
0
    case AdjointValueKind::AddElement:
2936
0
      llvm_unreachable("Adjoint of an aggregate type's field cannot be of kind "
2937
88
                       "`AddElement`");
2938
88
    }
2939
88
    break;
2940
88
  }
2941
104
  case SILValueCategory::Address: {
2942
104
    assert(pullback.getIndirectResults().size() == 1);
2943
0
    auto pbIndRes = pullback.getIndirectResults().front();
2944
104
    auto *adjSelf = createFunctionLocalAllocation(
2945
104
        pbIndRes->getType().getObjectType(), pbLoc);
2946
104
    setAdjointBuffer(origEntry, origSelf, adjSelf);
2947
296
    for (auto *field : tangentVectorDecl->getStoredProperties()) {
2948
296
      auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, field);
2949
      // Non-tangent fields get a zero.
2950
296
      if (field != tanField) {
2951
192
        builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization);
2952
192
        continue;
2953
192
      }
2954
      // Switch based on the property's value category.
2955
104
      switch (getTangentValueCategory(origResult)) {
2956
20
      case SILValueCategory::Object: {
2957
20
        auto adjResult = getAdjointValue(origEntry, origResult);
2958
20
        auto adjResultValue = materializeAdjointDirect(adjResult, pbLoc);
2959
20
        auto adjResultValueCopy =
2960
20
            builder.emitCopyValueOperation(pbLoc, adjResultValue);
2961
20
        builder.emitStoreValueOperation(pbLoc, adjResultValueCopy, adjSelfElt,
2962
20
                                        StoreOwnershipQualifier::Init);
2963
20
        break;
2964
0
      }
2965
84
      case SILValueCategory::Address: {
2966
84
        auto adjResult = getAdjointBuffer(origEntry, origResult);
2967
84
        builder.createCopyAddr(pbLoc, adjResult, adjSelfElt, IsTake,
2968
84
                               IsInitialization);
2969
84
        destroyedLocalAllocations.insert(adjResult);
2970
84
        break;
2971
0
      }
2972
104
      }
2973
104
    }
2974
104
    break;
2975
104
  }
2976
192
  }
2977
192
  return false;
2978
192
}
2979
2980
64
bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
2981
64
  auto &original = getOriginal();
2982
64
  auto &pullback = getPullback();
2983
64
  auto pbLoc = getPullback().getLocation();
2984
2985
64
  auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
2986
64
  assert(accessor->getAccessorKind() == AccessorKind::Set);
2987
2988
0
  auto *origEntry = original.getEntryBlock();
2989
64
  auto *pbEntry = pullback.getEntryBlock();
2990
64
  builder.setCurrentDebugScope(
2991
64
      remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
2992
64
  builder.setInsertionPoint(pbEntry);
2993
2994
  // Get setter argument values.
2995
  //              Setter type: $(inout Self, Argument) -> ()
2996
  // Pullback type (wrt self): $(inout Self') -> ()
2997
  // Pullback type (wrt both): $(inout Self') -> Argument'
2998
64
  assert(original.getLoweredFunctionType()->getNumParameters() == 2);
2999
0
  assert(pullback.getLoweredFunctionType()->getNumParameters() == 1);
3000
0
  assert(pullback.getLoweredFunctionType()->getNumResults() == 0 ||
3001
64
         pullback.getLoweredFunctionType()->getNumResults() == 1);
3002
3003
0
  SILValue origArg = original.getArgumentsWithoutIndirectResults()[0];
3004
64
  SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1];
3005
3006
  // Look up the corresponding field in the tangent space.
3007
64
  auto *origField = cast<VarDecl>(accessor->getStorage());
3008
64
  auto baseType = remapType(origSelf->getType()).getASTType();
3009
64
  auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
3010
64
                                            pbLoc, getInvoker());
3011
64
  if (!tanField) {
3012
0
    errorOccurred = true;
3013
0
    return true;
3014
0
  }
3015
3016
64
  auto adjSelf = getAdjointBuffer(origEntry, origSelf);
3017
64
  auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
3018
  // Switch based on the property's value category.
3019
64
  switch (getTangentValueCategory(origArg)) {
3020
24
  case SILValueCategory::Object: {
3021
24
    auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt,
3022
24
                                                 LoadOwnershipQualifier::Take);
3023
24
    setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg));
3024
24
    blockTemporaries[pbEntry].insert(adjArg);
3025
24
    break;
3026
0
  }
3027
40
  case SILValueCategory::Address: {
3028
40
    addToAdjointBuffer(origEntry, origArg, adjSelfElt, pbLoc);
3029
40
    builder.emitDestroyOperation(pbLoc, adjSelfElt);
3030
40
    break;
3031
0
  }
3032
64
  }
3033
64
  builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization);
3034
3035
64
  return false;
3036
64
}
3037
3038
//--------------------------------------------------------------------------//
3039
// Adjoint buffer mapping
3040
//--------------------------------------------------------------------------//
3041
3042
SILValue PullbackCloner::Implementation::getAdjointProjection(
3043
16.7k
    SILBasicBlock *origBB, SILValue originalProjection) {
3044
  // Handle `struct_element_addr`.
3045
  // Adjoint projection: a `struct_element_addr` into the base adjoint buffer.
3046
16.7k
  if (auto *seai = dyn_cast<StructElementAddrInst>(originalProjection)) {
3047
920
    assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
3048
920
           "`@noDerivative` struct projections should never be active");
3049
0
    auto adjSource = getAdjointBuffer(origBB, seai->getOperand());
3050
920
    auto structType = remapType(seai->getOperand()->getType()).getASTType();
3051
920
    auto *tanField =
3052
920
        getTangentStoredProperty(getContext(), seai, structType, getInvoker());
3053
920
    assert(tanField && "Invalid projections should have been diagnosed");
3054
0
    return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField);
3055
920
  }
3056
  // Handle `tuple_element_addr`.
3057
  // Adjoint projection: a `tuple_element_addr` into the base adjoint buffer.
3058
15.8k
  if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) {
3059
1.16k
    auto source = teai->getOperand();
3060
1.16k
    auto adjSource = getAdjointBuffer(origBB, source);
3061
1.16k
    if (!adjSource->getType().is<TupleType>())
3062
200
      return adjSource;
3063
960
    auto origTupleTy = source->getType().castTo<TupleType>();
3064
960
    unsigned adjIndex = 0;
3065
960
    for (unsigned i : range(teai->getFieldIndex())) {
3066
384
      if (getTangentSpace(
3067
384
              origTupleTy->getElement(i).getType()->getCanonicalType()))
3068
328
        ++adjIndex;
3069
384
    }
3070
960
    return builder.createTupleElementAddr(teai->getLoc(), adjSource, adjIndex);
3071
1.16k
  }
3072
  // Handle `ref_element_addr`.
3073
  // Adjoint projection: a local allocation initialized with the corresponding
3074
  // field value from the class's base adjoint value.
3075
14.6k
  if (auto *reai = dyn_cast<RefElementAddrInst>(originalProjection)) {
3076
164
    assert(!reai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
3077
164
           "`@noDerivative` class projections should never be active");
3078
0
    auto loc = reai->getLoc();
3079
    // Get the class operand, stripping `begin_borrow`.
3080
164
    auto classOperand = stripBorrow(reai->getOperand());
3081
164
    auto classType = remapType(reai->getOperand()->getType()).getASTType();
3082
164
    auto *tanField =
3083
164
        getTangentStoredProperty(getContext(), reai->getField(), classType,
3084
164
                                 reai->getLoc(), getInvoker());
3085
164
    assert(tanField && "Invalid projections should have been diagnosed");
3086
    // Create a local allocation for the element adjoint buffer.
3087
0
    auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
3088
164
    auto eltTanSILType =
3089
164
        remapType(SILType::getPrimitiveAddressType(eltTanType));
3090
164
    auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc);
3091
    // Check the class operand's `TangentVector` value category.
3092
164
    switch (getTangentValueCategory(classOperand)) {
3093
56
    case SILValueCategory::Object: {
3094
      // Get the class operand's adjoint value. Currently, it must be a
3095
      // `TangentVector` struct.
3096
56
      auto adjClass =
3097
56
          materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc);
3098
56
      builder.emitScopedBorrowOperation(
3099
56
          loc, adjClass, [&](SILValue borrowedAdjClass) {
3100
            // Initialize the element adjoint buffer with the base adjoint
3101
            // value.
3102
56
            auto *adjElt =
3103
56
                builder.createStructExtract(loc, borrowedAdjClass, tanField);
3104
56
            auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt);
3105
56
            builder.emitStoreValueOperation(loc, adjEltCopy, eltAdjBuffer,
3106
56
                                            StoreOwnershipQualifier::Init);
3107
56
          });
3108
56
      return eltAdjBuffer;
3109
0
    }
3110
108
    case SILValueCategory::Address: {
3111
      // Get the class operand's adjoint buffer. Currently, it must be a
3112
      // `TangentVector` struct.
3113
108
      auto adjClass = getAdjointBuffer(origBB, classOperand);
3114
      // Initialize the element adjoint buffer with the base adjoint buffer.
3115
108
      auto *adjElt = builder.createStructElementAddr(loc, adjClass, tanField);
3116
108
      builder.createCopyAddr(loc, adjElt, eltAdjBuffer, IsNotTake,
3117
108
                             IsInitialization);
3118
108
      return eltAdjBuffer;
3119
0
    }
3120
164
    }
3121
164
  }
3122
  // Handle `begin_access`.
3123
  // Adjoint projection: the base adjoint buffer itself.
3124
14.4k
  if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) {
3125
3.92k
    auto adjBase = getAdjointBuffer(origBB, bai->getOperand());
3126
3.92k
    if (errorOccurred)
3127
0
      return (bufferMap[{origBB, originalProjection}] = SILValue());
3128
    // Return the base buffer's adjoint buffer.
3129
3.92k
    return adjBase;
3130
3.92k
  }
3131
  // Handle `array.uninitialized_intrinsic` application element addresses.
3132
  // Adjoint projection: a local allocation initialized by applying
3133
  // `Array.TangentVector.subscript` to the base array's adjoint value.
3134
10.5k
  auto *ai =
3135
10.5k
      getAllocateUninitializedArrayIntrinsicElementAddress(originalProjection);
3136
10.5k
  auto *definingInst = dyn_cast_or_null<SingleValueInstruction>(
3137
10.5k
      originalProjection->getDefiningInstruction());
3138
10.5k
  bool isAllocateUninitializedArrayIntrinsicElementAddress =
3139
10.5k
      ai && definingInst &&
3140
10.5k
      (isa<PointerToAddressInst>(definingInst) ||
3141
488
       isa<IndexAddrInst>(definingInst));
3142
10.5k
  if (isAllocateUninitializedArrayIntrinsicElementAddress) {
3143
    // Get the array element index of the result address.
3144
488
    int eltIndex = 0;
3145
488
    if (auto *iai = dyn_cast<IndexAddrInst>(definingInst)) {
3146
124
      auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
3147
124
      eltIndex = ili->getValue().getLimitedValue();
3148
124
    }
3149
    // Get the array adjoint value.
3150
488
    SILValue arrayAdjoint;
3151
488
    assert(ai && "Expected `array.uninitialized_intrinsic` application");
3152
488
    for (auto use : ai->getUses()) {
3153
488
      auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
3154
488
      if (!dti)
3155
0
        continue;
3156
488
      assert(!arrayAdjoint && "Array adjoint already found");
3157
      // The first `destructure_tuple` result is the `Array` value.
3158
0
      auto arrayValue = dti->getResult(0);
3159
488
      arrayAdjoint = materializeAdjointDirect(
3160
488
          getAdjointValue(origBB, arrayValue), definingInst->getLoc());
3161
488
    }
3162
488
    assert(arrayAdjoint && "Array does not have adjoint value");
3163
    // Apply `Array.TangentVector.subscript` to get array element adjoint value.
3164
0
    auto *eltAdjBuffer =
3165
488
        getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc());
3166
488
    return eltAdjBuffer;
3167
488
  }
3168
10.0k
  return SILValue();
3169
10.5k
}
3170
3171
//----------------------------------------------------------------------------//
3172
// Adjoint value accumulation
3173
//----------------------------------------------------------------------------//
3174
3175
AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
3176
3.06k
    AdjointValue lhs, AdjointValue rhs, SILLocation loc) {
3177
3.06k
  LLVM_DEBUG(getADDebugStream() << "Accumulating adjoint directly.\nLHS: "
3178
3.06k
                                << lhs << "\nRHS: " << rhs << '\n');
3179
3.06k
  switch (lhs.getKind()) {
3180
  // x
3181
2.63k
  case AdjointValueKind::Concrete: {
3182
2.63k
    auto lhsVal = lhs.getConcreteValue();
3183
2.63k
    switch (rhs.getKind()) {
3184
    // x + y
3185
2.34k
    case AdjointValueKind::Concrete: {
3186
2.34k
      auto rhsVal = rhs.getConcreteValue();
3187
2.34k
      auto sum = recordTemporary(builder.emitAdd(loc, lhsVal, rhsVal));
3188
2.34k
      return makeConcreteAdjointValue(sum);
3189
0
    }
3190
    // x + 0 => x
3191
152
    case AdjointValueKind::Zero:
3192
152
      return lhs;
3193
    // x + (y, z) => (x.0 + y, x.1 + z)
3194
80
    case AdjointValueKind::Aggregate: {
3195
80
      SmallVector<AdjointValue, 8> newElements;
3196
80
      auto lhsTy = lhsVal->getType().getASTType();
3197
80
      auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal);
3198
80
      if (lhsTy->is<TupleType>()) {
3199
64
        auto elts = builder.createDestructureTuple(loc, lhsValCopy);
3200
64
        llvm::for_each(elts->getResults(),
3201
128
                       [this](SILValue result) { recordTemporary(result); });
3202
128
        for (auto i : indices(elts->getResults())) {
3203
128
          auto rhsElt = rhs.getAggregateElement(i);
3204
128
          newElements.push_back(accumulateAdjointsDirect(
3205
128
              makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
3206
128
        }
3207
64
      } else if (lhsTy->getStructOrBoundGenericStruct()) {
3208
16
        auto elts =
3209
16
            builder.createDestructureStruct(lhsVal.getLoc(), lhsValCopy);
3210
16
        llvm::for_each(elts->getResults(),
3211
16
                       [this](SILValue result) { recordTemporary(result); });
3212
16
        for (unsigned i : indices(elts->getResults())) {
3213
16
          auto rhsElt = rhs.getAggregateElement(i);
3214
16
          newElements.push_back(accumulateAdjointsDirect(
3215
16
              makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
3216
16
        }
3217
16
      } else {
3218
0
        llvm_unreachable("Not an aggregate type");
3219
0
      }
3220
80
      return makeAggregateAdjointValue(lhsVal->getType(), newElements);
3221
0
    }
3222
    // x + (baseAdjoint, index, eltToAdd) => (x+baseAdjoint, index, eltToAdd)
3223
56
    case AdjointValueKind::AddElement: {
3224
56
      auto *addElementValue = rhs.getAddElementValue();
3225
56
      auto baseAdjoint = addElementValue->baseAdjoint;
3226
56
      auto eltToAdd = addElementValue->eltToAdd;
3227
3228
56
      auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
3229
56
      return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
3230
56
                                        addElementValue->fieldLocator);
3231
0
    }
3232
2.63k
    }
3233
2.63k
  }
3234
  // 0
3235
192
  case AdjointValueKind::Zero:
3236
    // 0 + x => x
3237
192
    return rhs;
3238
  // (x, y)
3239
36
  case AdjointValueKind::Aggregate: {
3240
36
    switch (rhs.getKind()) {
3241
    // (x, y) + z => (z.0 + x, z.1 + y)
3242
0
    case AdjointValueKind::Concrete:
3243
0
      return accumulateAdjointsDirect(rhs, lhs, loc);
3244
    // x + 0 => x
3245
4
    case AdjointValueKind::Zero:
3246
4
      return lhs;
3247
    // (x, y) + (z, w) => (x + z, y + w)
3248
32
    case AdjointValueKind::Aggregate: {
3249
32
      SmallVector<AdjointValue, 8> newElements;
3250
32
      for (auto i : range(lhs.getNumAggregateElements()))
3251
64
        newElements.push_back(accumulateAdjointsDirect(
3252
64
            lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc));
3253
32
      return makeAggregateAdjointValue(lhs.getType(), newElements);
3254
0
    }
3255
    // (x.0, ..., x.n) + (baseAdjoint, index, eltToAdd) => (x + baseAdjoint,
3256
    // index, eltToAdd)
3257
0
    case AdjointValueKind::AddElement: {
3258
0
      auto *addElementValue = rhs.getAddElementValue();
3259
0
      auto baseAdjoint = addElementValue->baseAdjoint;
3260
0
      auto eltToAdd = addElementValue->eltToAdd;
3261
0
      auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
3262
3263
0
      return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
3264
0
                                        addElementValue->fieldLocator);
3265
0
    }
3266
36
    }
3267
36
  }
3268
  // (baseAdjoint, index, eltToAdd)
3269
196
  case AdjointValueKind::AddElement: {
3270
196
    switch (rhs.getKind()) {
3271
36
    case AdjointValueKind::Zero:
3272
36
      return lhs;
3273
    // (baseAdjoint, index, eltToAdd) + x => (x + baseAdjoint, index, eltToAdd)
3274
20
    case AdjointValueKind::Concrete:
3275
    // (baseAdjoint, index, eltToAdd) + (x.0, ..., x.n) => (x + baseAdjoint,
3276
    // index, eltToAdd)
3277
20
    case AdjointValueKind::Aggregate:
3278
20
      return accumulateAdjointsDirect(rhs, lhs, loc);
3279
    // (baseAdjoint1, index1, eltToAdd1) + (baseAdjoint2, index2, eltToAdd2)
3280
    // => ((baseAdjoint1 + baseAdjoint2, index1, eltToAdd1), index2, eltToAdd2)
3281
140
    case AdjointValueKind::AddElement: {
3282
140
      auto *addElementValueLhs = lhs.getAddElementValue();
3283
140
      auto baseAdjointLhs = addElementValueLhs->baseAdjoint;
3284
140
      auto eltToAddLhs = addElementValueLhs->eltToAdd;
3285
3286
140
      auto *addElementValueRhs = rhs.getAddElementValue();
3287
140
      auto baseAdjointRhs = addElementValueRhs->baseAdjoint;
3288
140
      auto eltToAddRhs = addElementValueRhs->eltToAdd;
3289
3290
140
      auto sumOfBaseAdjoints =
3291
140
          accumulateAdjointsDirect(baseAdjointLhs, baseAdjointRhs, loc);
3292
140
      auto newBaseAdjoint = makeAddElementAdjointValue(
3293
140
          sumOfBaseAdjoints, eltToAddLhs, addElementValueLhs->fieldLocator);
3294
3295
140
      return makeAddElementAdjointValue(newBaseAdjoint, eltToAddRhs,
3296
140
                                        addElementValueRhs->fieldLocator);
3297
20
    }
3298
196
    }
3299
196
  }
3300
3.06k
  }
3301
0
  llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
3302
0
}
3303
3304
//----------------------------------------------------------------------------//
3305
// Array literal initialization differentiation
3306
//----------------------------------------------------------------------------//
3307
3308
void PullbackCloner::Implementation::
3309
    accumulateArrayLiteralElementAddressAdjoints(SILBasicBlock *origBB,
3310
                                                 SILValue originalValue,
3311
                                                 AdjointValue arrayAdjointValue,
3312
2.63k
                                                 SILLocation loc) {
3313
  // Return if the original value is not the `Array` result of an
3314
  // `array.uninitialized_intrinsic` application.
3315
2.63k
  auto *dti = dyn_cast_or_null<DestructureTupleInst>(
3316
2.63k
      originalValue->getDefiningInstruction());
3317
2.63k
  if (!dti)
3318
2.54k
    return;
3319
92
  if (!ArraySemanticsCall(dti->getOperand(),
3320
92
                          semantics::ARRAY_UNINITIALIZED_INTRINSIC))
3321
32
    return;
3322
60
  if (originalValue != dti->getResult(0))
3323
0
    return;
3324
  // Accumulate the array's adjoint value into the adjoint buffers of its
3325
  // element addresses: `pointer_to_address` and `index_addr` instructions.
3326
60
  LLVM_DEBUG(getADDebugStream()
3327
60
             << "Accumulating adjoint value for array literal into element "
3328
60
                "address adjoint buffers"
3329
60
             << originalValue);
3330
60
  auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc);
3331
60
  builder.setCurrentDebugScope(remapScope(dti->getDebugScope()));
3332
60
  builder.setInsertionPoint(arrayAdjoint->getParentBlock());
3333
60
  for (auto use : dti->getResult(1)->getUses()) {
3334
60
    auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser());
3335
60
    auto adjBuf = getAdjointBuffer(origBB, ptai);
3336
60
    auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc);
3337
60
    builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
3338
72
    for (auto use : ptai->getUses()) {
3339
72
      if (auto *iai = dyn_cast<IndexAddrInst>(use->getUser())) {
3340
12
        auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
3341
12
        auto eltIndex = ili->getValue().getLimitedValue();
3342
12
        auto adjBuf = getAdjointBuffer(origBB, iai);
3343
12
        auto *eltAdjBuf =
3344
12
            getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc);
3345
12
        builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
3346
12
      }
3347
72
    }
3348
60
  }
3349
60
}
3350
3351
AllocStackInst *PullbackCloner::Implementation::getArrayAdjointElementBuffer(
3352
560
    SILValue arrayAdjoint, int eltIndex, SILLocation loc) {
3353
560
  auto &ctx = builder.getASTContext();
3354
560
  auto arrayTanType = cast<StructType>(arrayAdjoint->getType().getASTType());
3355
560
  auto arrayType = arrayTanType->getParent()->castTo<BoundGenericStructType>();
3356
560
  auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType();
3357
560
  auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType));
3358
  // Get `function_ref` and generic signature of
3359
  // `Array.TangentVector.subscript.getter`.
3360
560
  auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct();
3361
560
  auto subscriptLookup =
3362
560
      arrayTanStructDecl->lookupDirect(DeclBaseName::createSubscript());
3363
560
  SubscriptDecl *subscriptDecl = nullptr;
3364
1.05k
  for (auto *candidate : subscriptLookup) {
3365
1.05k
    auto candidateModule = candidate->getModuleContext();
3366
1.05k
    if (candidateModule->getName() == ctx.Id_Differentiation ||
3367
1.05k
        candidateModule->isStdlibModule()) {
3368
560
      assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s");
3369
0
      subscriptDecl = cast<SubscriptDecl>(candidate);
3370
#ifdef NDEBUG
3371
      break;
3372
#endif
3373
560
    }
3374
1.05k
  }
3375
560
  assert(subscriptDecl && "No `Array.TangentVector.subscript`");
3376
0
  auto *subscriptGetterDecl =
3377
560
      subscriptDecl->getOpaqueAccessor(AccessorKind::Get);
3378
560
  assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter");
3379
0
  SILOptFunctionBuilder fb(getContext().getTransform());
3380
560
  auto *subscriptGetterFn = fb.getOrCreateFunction(
3381
560
      loc, SILDeclRef(subscriptGetterDecl), NotForDefinition);
3382
  // %subscript_fn = function_ref @Array.TangentVector<T>.subscript.getter
3383
560
  auto *subscriptFnRef = builder.createFunctionRef(loc, subscriptGetterFn);
3384
560
  auto subscriptFnGenSig =
3385
560
      subscriptGetterFn->getLoweredFunctionType()->getSubstGenericSignature();
3386
  // Apply `Array.TangentVector.subscript.getter` to get array element adjoint
3387
  // buffer.
3388
  // %index_literal = integer_literal $Builtin.IntXX, <index>
3389
560
  auto builtinIntType =
3390
560
      SILType::getPrimitiveObjectType(ctx.getIntDecl()
3391
560
                                          ->getStoredProperties()
3392
560
                                          .front()
3393
560
                                          ->getInterfaceType()
3394
560
                                          ->getCanonicalType());
3395
560
  auto *eltIndexLiteral =
3396
560
      builder.createIntegerLiteral(loc, builtinIntType, eltIndex);
3397
560
  auto intType = SILType::getPrimitiveObjectType(
3398
560
      ctx.getIntType()->getCanonicalType());
3399
  // %index_int = struct $Int (%index_literal)
3400
560
  auto *eltIndexInt = builder.createStruct(loc, intType, {eltIndexLiteral});
3401
560
  auto *swiftModule = getModule().getSwiftModule();
3402
560
  auto *diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
3403
560
  auto diffConf = swiftModule->lookupConformance(eltTanType, diffProto);
3404
560
  assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
3405
0
  auto *addArithProto = ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
3406
560
  auto addArithConf = swiftModule->lookupConformance(eltTanType, addArithProto);
3407
560
  assert(!addArithConf.isInvalid() &&
3408
560
         "Missing conformance to `AdditiveArithmetic`");
3409
0
  auto subMap = SubstitutionMap::get(subscriptFnGenSig, {eltTanType},
3410
560
                                     {addArithConf, diffConf});
3411
  // %elt_adj = alloc_stack $T.TangentVector
3412
  // Create and register a local allocation.
3413
560
  auto *eltAdjBuffer = createFunctionLocalAllocation(
3414
560
      eltTanSILType, loc, /*zeroInitialize*/ true);
3415
  // Immediately destroy the emitted zero value.
3416
  // NOTE: It is not efficient to emit a zero value then immediately destroy
3417
  // it. However, it was the easiest way to to avoid "lifetime mismatch in
3418
  // predecessors" memory lifetime verification errors for control flow
3419
  // differentiation.
3420
  // Perhaps we can avoid emitting a zero value if local allocations are created
3421
  // per pullback bb instead of all in the pullback entry: TF-1075.
3422
560
  builder.emitDestroyOperation(loc, eltAdjBuffer);
3423
  // apply %subscript_fn<T.TangentVector>(%elt_adj, %index_int, %array_adj)
3424
560
  builder.createApply(loc, subscriptFnRef, subMap,
3425
560
                      {eltAdjBuffer, eltIndexInt, arrayAdjoint});
3426
560
  return eltAdjBuffer;
3427
560
}
3428
3429
} // end namespace autodiff
3430
} // end namespace swift