From d73de8080f43100db336d55cffbdd4ce17cd7f7e Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 28 Feb 2025 15:23:02 +0000 Subject: [PATCH] Fix constant folding with shape as constant. --- src/recipes/ConstantFolding.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp index 05c92afb3..40cd30a7d 100644 --- a/src/recipes/ConstantFolding.cpp +++ b/src/recipes/ConstantFolding.cpp @@ -30,7 +30,7 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape folded = false; std::set<std::shared_ptr<Node>> candidates; for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) { - if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() != Shape_Op::Type))) { + if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() == Shape_Op::Type))) { const auto& childs = nodePtr->getChildren(); candidates.insert(childs.begin(), childs.end()); } @@ -44,8 +44,8 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape for (const auto& input : node->inputs()) { if (input.first) { if (!(input.first->type() == Producer_Op::Type || (constantShape && (input.first->type() == Shape_Op::Type)))) { - Log::debug("Input {} of node {} (of type {}) not foldable, because {} (of type {}) is not a constant. With constant = {}", - i, node->name(), node->type(), input.first->name(), input.first->type(), constantShape); + Log::debug("Input {} of node {} (of type {}) not foldable, because {} (of type {}) is not a constant.", + i, node->name(), node->type(), input.first->name(), input.first->type()); foldable = false; break; } @@ -98,7 +98,14 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape // Add output in right order prodGraph->add(newProd); } - + Log::debug("Trying to replace:"); + for(auto nodeToReplace: replaceGraph->getNodes()){ + Log::debug("\t- {} ({})", nodeToReplace->name(), nodeToReplace->type()); + } + Log::debug("With:"); + for(auto nodeReplacing: prodGraph->getNodes()){ + Log::debug("\t- {} ({})", nodeReplacing->name(), nodeReplacing->type()); + } if (GraphView::replace(replaceGraph, prodGraph)) { folded = true; modified = true; -- GitLab