Skip to content
Snippets Groups Projects
Commit cfc1e933 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Update ConstantFolding recipes with arg to consider Shape as constant.

parent c5492e23
No related branches found
No related tags found
1 merge request!297Reshape forward dims
......@@ -22,7 +22,13 @@
namespace Aidge {
void constantFolding(std::shared_ptr<GraphView> graph);
/**
* @brief Retrieve part of the graph that can be pre-computed and replace them by a Producer.
*
* @param graph Graph to fold the constant
* @param constant_shape If true Shape operators are considered to be constant
*/
void constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false);
// FUSE MATMUL + ADD -> FC
......
......@@ -25,6 +25,14 @@ namespace Aidge {
void init_Recipes(py::module &m)
{
m.def("constant_folding", static_cast<void(*)(std::shared_ptr<GraphView>, bool)>(constantFolding), py::arg("graph_view"), py::arg("constant_shape") = false, R"mydelimiter(
Retrieve part of the graph that can be pre-computed and replace them by a Producer.
:param graph_view: Graph view on which we want to apply the recipe
:type graph_view: :py:class:`aidge_core.GraphView`
:param constant_shape: If true, ``Shape`` operator are considered constant, default=False
:type constant_shape: bool, optional
)mydelimiter");
m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter(
Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
......
......@@ -17,17 +17,18 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Shape.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape) {
bool folded;
do {
folded = false;
std::set<std::shared_ptr<Node>> candidates;
for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) {
if (nodePtr->type() == Producer_Op::Type) {
if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() != Shape_Op::Type))) {
const auto& childs = nodePtr->getChildren();
candidates.insert(childs.begin(), childs.end());
}
......@@ -39,17 +40,18 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
size_t i = 0;
for (const auto& input : node->inputs()) {
if (input.first) {
if (input.first->type() != Producer_Op::Type) {
if (input.first->type() != Producer_Op::Type || (constantShape && (input.first->type() != Shape_Op::Type))) {
foldable = false;
break;
}
const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
if (!producer->constant()) {
Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant",
node->name(), node->type(), input.first->name());
foldable = false;
break;
if (input.first->type() == Producer_Op::Type){
const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
if (!producer->constant()) {
Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant",
node->name(), node->type(), input.first->name());
foldable = false;
break;
}
}
replaceGraph->add(input.first, false);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment