/Volumes/compiler/apple/swift/lib/SILOptimizer/Differentiation/AdjointValue.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===--- AdjointValue.cpp - Helper class for differentiation --*- C++ -*---===// |
2 | | // |
3 | | // This source file is part of the Swift.org open source project |
4 | | // |
5 | | // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
6 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
7 | | // |
8 | | // See https://swift.org/LICENSE.txt for license information |
9 | | // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | // |
13 | | // AdjointValue - a symbolic representation for adjoint values enabling |
14 | | // efficient differentiation by avoiding zero materialization. |
15 | | // |
16 | | //===----------------------------------------------------------------------===// |
17 | | |
18 | | #define DEBUG_TYPE "differentiation" |
19 | | |
20 | | #include "swift/SILOptimizer/Differentiation/AdjointValue.h" |
21 | | |
22 | 5.58k | void swift::autodiff::AdjointValue::print(llvm::raw_ostream &s) const { |
23 | 5.58k | switch (getKind()) { |
24 | 924 | case AdjointValueKind::Zero: |
25 | 924 | s << "Zero[" << getType() << ']'; |
26 | 924 | break; |
27 | 68 | case AdjointValueKind::Aggregate: |
28 | 68 | s << "Aggregate[" << getType() << "]("; |
29 | 68 | if (auto *decl = getType().getASTType()->getStructOrBoundGenericStruct()) { |
30 | 4 | interleave( |
31 | 4 | llvm::zip(decl->getStoredProperties(), getAggregateElements()), |
32 | 4 | [&s](std::tuple<VarDecl *, const AdjointValue &> elt) { |
33 | 4 | s << std::get<0>(elt)->getName() << ": "; |
34 | 4 | std::get<1>(elt).print(s); |
35 | 4 | }, |
36 | 4 | [&s] { s << ", "; }); |
37 | 64 | } else if (getType().is<TupleType>()) { |
38 | 64 | interleave( |
39 | 64 | getAggregateElements(), |
40 | 128 | [&s](const AdjointValue &elt) { elt.print(s); }, [&s] { s << ", "; }); |
41 | 64 | } else { |
42 | 0 | llvm_unreachable("Invalid aggregate"); |
43 | 0 | } |
44 | 68 | s << ')'; |
45 | 68 | break; |
46 | 4.35k | case AdjointValueKind::Concrete: |
47 | 4.35k | s << "Concrete[" << getType() << "](" << base->value.concrete << ')'; |
48 | 4.35k | break; |
49 | 240 | case AdjointValueKind::AddElement: |
50 | 240 | auto *addElementValue = getAddElementValue(); |
51 | 240 | auto baseAdjoint = addElementValue->baseAdjoint; |
52 | 240 | auto eltToAdd = addElementValue->eltToAdd; |
53 | | |
54 | 240 | s << "AddElement["; |
55 | 240 | baseAdjoint.print(s); |
56 | | |
57 | 240 | s << ", Field("; |
58 | 240 | if (addElementValue->isTupleAdjoint()) { |
59 | 0 | s << addElementValue->getFieldIndex(); |
60 | 240 | } else { |
61 | 240 | s << addElementValue->getFieldDecl()->getNameStr(); |
62 | 240 | } |
63 | 240 | s << "), "; |
64 | | |
65 | 240 | eltToAdd.print(s); |
66 | | |
67 | 240 | s << "]"; |
68 | 240 | break; |
69 | 5.58k | } |
70 | 5.58k | } |