Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
/**
 * @brief Fold nodes of @p graph whose connected inputs are all constant Producers.
 *
 * Each pass collects, as candidates, the children of the Producers of @p graph
 * that also belong to @p graph. A candidate is foldable when every connected
 * input is a Producer whose "Constant" attribute is true.
 *
 * A foldable node is run once with forward(). The node and its input Producers
 * are then replaced, via GraphView::replace, by one new constant Producer per
 * output. Each new Producer is named "<node name>_<output index>".
 *
 * Passes repeat until one of them folds nothing, so chains of constant
 * computations collapse over successive passes.
 *
 * @param graph GraphView simplified in place.
 */
void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
    bool folded{false};
    do {
        folded = false;

        // Candidates are the children of Producers. getChildren() is not limited
        // to this view, so children outside `graph` are skipped: they must not be
        // rewritten when folding a sub-view. Null entries are skipped as well,
        // since they would be dereferenced below.
        std::set<std::shared_ptr<Node>> candidates;
        for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) {
            if (nodePtr->type() == Producer_Op::Type) {
                for (const auto& child : nodePtr->getChildren()) {
                    if (child && graph->inView(child)) {
                        candidates.insert(child);
                    }
                }
            }
        }

        for (const auto& node : candidates) {
            bool foldable = true;
            // Collects the node plus its input Producers: the subgraph to replace.
            auto replaceGraph = std::make_shared<GraphView>();
            for (const auto& input : node->inputs()) {
                if (input.first) {  // unconnected inputs are ignored
                    if (input.first->type() != Producer_Op::Type) {
                        foldable = false;
                        break;
                    }
                    const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
                    if (!producer->getAttr<bool>("Constant")) {
                        Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant",
                            node->name(), node->type(), input.first->name());
                        foldable = false;
                        break;
                    }
                    replaceGraph->add(input.first, false);
                }
            }

            if (!foldable) {
                continue;
            }

            Log::info("Folding node {} (of type {})", node->name(), node->type());
            replaceGraph->add(node, false);
            // NOTE(review): assumes the node already has an implementation/backend
            // and valid output dims so forward() can compute its outputs — confirm.
            node->forward();

            // Snapshot every computed output into its own constant Producer.
            auto prodGraph = std::make_shared<GraphView>();
            const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
            for (IOIndex_t output = 0; output < node->nbOutputs(); ++output) {
                const auto computedOutput = std::make_shared<Tensor>(op->getOutput(output)->clone());
                const auto newProd = Producer(computedOutput, node->name() + "_" + std::to_string(output), true);
                // Add output in right order
                prodGraph->add(newProd);
            }

            if (GraphView::replace(replaceGraph, prodGraph)) {
                folded = true;
            }
            else {
                Log::warn("Error with replace when folding node {} (of type {})",
                    node->name(), node->type());
            }
        }
    }
    while (folded);
}