Skip to content
Snippets Groups Projects
Commit 959c1655 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added constantFolding recipe

parent fbad49a1
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,8 @@
namespace Aidge {
void constantFolding(std::shared_ptr<GraphView> graph);
// FUSE MATMUL + ADD -> FC
/**
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
bool folded;
do {
folded = false;
std::set<std::shared_ptr<Node>> candidates;
for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) {
if (nodePtr->type() == Producer_Op::Type) {
const auto& childs = nodePtr->getChildren();
candidates.insert(childs.begin(), childs.end());
}
}
for (const auto& node : candidates) {
bool foldable = true;
auto replaceGraph = std::make_shared<GraphView>();
for (const auto& input : node->inputs()) {
if (input.first) {
if (input.first->type() != Producer_Op::Type) {
foldable = false;
break;
}
const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
if (!producer->getAttr<bool>("Constant")) {
Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant",
node->name(), node->type(), input.first->name());
foldable = false;
break;
}
replaceGraph->add(input.first, false);
}
}
if (foldable) {
Log::info("Folding node {} (of type {})", node->name(), node->type());
replaceGraph->add(node, false);
node->forward();
auto prodGraph = std::make_shared<GraphView>();
const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
for (IOIndex_t output = 0; output < node->nbOutputs(); ++output) {
const auto computedOutput = std::make_shared<Tensor>(op->getOutput(output)->clone());
const auto newProd = Producer(computedOutput, node->name() + "_" + std::to_string(output), true);
// Add output in right order
prodGraph->add(newProd);
}
if (GraphView::replace(replaceGraph, prodGraph)) {
folded = true;
}
else {
Log::warn("Error with replace when folding node {} (of type {})",
node->name(), node->type());
}
}
}
}
while (folded);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment