Autodiff Coverage for full test suite

Coverage Report

Created: 2023-11-30 18:54

/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/ADContext.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- ADContext.cpp - Differentiation Context --------------*- C++ -*---===//
2
//
3
// This source file is part of the Swift.org open source project
4
//
5
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
//
8
// See https://swift.org/LICENSE.txt for license information
9
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
//
11
//===----------------------------------------------------------------------===//
12
//
13
// Per-module contextual information for the differentiation transform.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#define DEBUG_TYPE "differentiation"
18
19
#include "swift/SILOptimizer/Differentiation/ADContext.h"
20
#include "swift/AST/DiagnosticsSIL.h"
21
#include "swift/AST/SourceFile.h"
22
#include "swift/SILOptimizer/PassManager/Transforms.h"
23
24
using llvm::DenseMap;
25
using llvm::SmallPtrSet;
26
using llvm::SmallVector;
27
28
namespace swift {
29
namespace autodiff {
30
31
//===----------------------------------------------------------------------===//
32
// Local helpers
33
//===----------------------------------------------------------------------===//
34
35
/// Given an operator name, such as '+', and a protocol, returns the '+'
36
/// operator. If the operator does not exist in the protocol, returns null.
37
static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName,
38
404
                                            ProtocolDecl *protocol) {
39
404
  assert(operatorName.isOperator());
40
  // Find the operator requirement in the given protocol declaration.
41
0
  auto opLookup = protocol->lookupDirect(operatorName);
42
404
  for (auto *decl : opLookup) {
43
404
    if (!decl->isProtocolRequirement())
44
0
      continue;
45
404
    auto *fd = dyn_cast<FuncDecl>(decl);
46
404
    if (!fd || !fd->isStatic() || !fd->isOperator())
47
0
      continue;
48
404
    return fd;
49
404
  }
50
  // Not found.
51
0
  return nullptr;
52
404
}
53
54
//===----------------------------------------------------------------------===//
55
// ADContext methods
56
//===----------------------------------------------------------------------===//
57
58
ADContext::ADContext(SILModuleTransform &transform)
59
    : transform(transform), module(*transform.getModule()),
60
24.3k
      passManager(*transform.getPassManager()) {}
61
62
/// Get the source file for the given `SILFunction`.
63
6.61k
static SourceFile &getSourceFile(SILFunction *f) {
64
6.61k
  if (f->hasLocation())
65
6.61k
    if (auto *declContext = f->getLocation().getAsDeclContext())
66
6.49k
      if (auto *parentSourceFile = declContext->getParentSourceFile())
67
6.49k
        return *parentSourceFile;
68
120
  for (auto *file : f->getModule().getSwiftModule()->getFiles())
69
120
    if (auto *sourceFile = dyn_cast<SourceFile>(file))
70
120
      return *sourceFile;
71
0
  llvm_unreachable("Could not resolve SourceFile from SILFunction");
72
0
}
73
74
SynthesizedFileUnit &
75
6.61k
ADContext::getOrCreateSynthesizedFile(SILFunction *original) {
76
6.61k
  auto &SF = getSourceFile(original);
77
6.61k
  return SF.getOrCreateSynthesizedFile();
78
6.61k
}
79
80
2.34k
FuncDecl *ADContext::getPlusDecl() const {
81
2.34k
  if (!cachedPlusFn) {
82
128
    cachedPlusFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+"),
83
128
                                              additiveArithmeticProtocol);
84
128
    assert(cachedPlusFn && "AdditiveArithmetic.+ not found");
85
128
  }
86
0
  return cachedPlusFn;
87
2.34k
}
88
89
8.19k
FuncDecl *ADContext::getPlusEqualDecl() const {
90
8.19k
  if (!cachedPlusEqualFn) {
91
276
    cachedPlusEqualFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+="),
92
276
                                                   additiveArithmeticProtocol);
93
276
    assert(cachedPlusEqualFn && "AdditiveArithmetic.+= not found");
94
276
  }
95
0
  return cachedPlusEqualFn;
96
8.19k
}
97
98
19.5k
AccessorDecl *ADContext::getAdditiveArithmeticZeroGetter() const {
99
19.5k
  if (cachedZeroGetter)
100
19.2k
    return cachedZeroGetter;
101
296
  auto zeroDeclLookup = getAdditiveArithmeticProtocol()
102
296
      ->lookupDirect(getASTContext().Id_zero);
103
296
  auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front());
104
296
  assert(zeroDecl->isProtocolRequirement());
105
0
  cachedZeroGetter = zeroDecl->getOpaqueAccessor(AccessorKind::Get);
106
296
  return cachedZeroGetter;
107
19.5k
}
108
109
24
void ADContext::cleanUp() {
110
  // Delete all references to generated functions.
111
996
  for (auto fnRef : generatedFunctionReferences) {
112
996
    if (auto *fnRefInst =
113
996
            peerThroughFunctionConversions<FunctionRefInst>(fnRef)) {
114
4
      fnRefInst->replaceAllUsesWithUndef();
115
4
      fnRefInst->eraseFromParent();
116
4
    }
117
996
  }
118
  // Delete all generated functions.
119
1.78k
  for (auto *generatedFunction : generatedFunctions) {
120
1.78k
    LLVM_DEBUG(getADDebugStream() << "Deleting generated function "
121
1.78k
                                  << generatedFunction->getName() << '\n');
122
1.78k
    generatedFunction->dropAllReferences();
123
1.78k
    transform.notifyWillDeleteFunction(generatedFunction);
124
1.78k
    module.eraseFunction(generatedFunction);
125
1.78k
  }
126
24
}
127
128
DifferentiableFunctionInst *ADContext::createDifferentiableFunction(
129
    SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
130
    IndexSubset *resultIndices, SILValue original,
131
19.0k
    llvm::Optional<std::pair<SILValue, SILValue>> derivativeFunctions) {
132
19.0k
  auto *dfi = builder.createDifferentiableFunction(
133
19.0k
      loc, parameterIndices, resultIndices, original, derivativeFunctions);
134
19.0k
  processedDifferentiableFunctionInsts.erase(dfi);
135
19.0k
  return dfi;
136
19.0k
}
137
138
LinearFunctionInst *ADContext::createLinearFunction(
139
    SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
140
12
    SILValue original, llvm::Optional<SILValue> transposeFunction) {
141
12
  auto *lfi = builder.createLinearFunction(loc, parameterIndices, original,
142
12
                                           transposeFunction);
143
12
  processedLinearFunctionInsts.erase(lfi);
144
12
  return lfi;
145
12
}
146
147
DifferentiableFunctionExpr *
148
44
ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) {
149
44
  return inst->getLoc().getAsASTNode<DifferentiableFunctionExpr>();
150
44
}
151
152
LinearFunctionExpr *
153
0
ADContext::findDifferentialOperator(LinearFunctionInst *inst) {
154
0
  return inst->getLoc().getAsASTNode<LinearFunctionExpr>();
155
0
}
156
157
} // end namespace autodiff
158
} // end namespace swift