Skip to content
Snippets Groups Projects
Commit d73de808 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Fix constant folding with shape as constant.

parent 02c2276f
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!297Reshape forward dims
...@@ -30,7 +30,7 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape ...@@ -30,7 +30,7 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
folded = false; folded = false;
std::set<std::shared_ptr<Node>> candidates; std::set<std::shared_ptr<Node>> candidates;
for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) { for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) {
if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() != Shape_Op::Type))) { if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() == Shape_Op::Type))) {
const auto& childs = nodePtr->getChildren(); const auto& childs = nodePtr->getChildren();
candidates.insert(childs.begin(), childs.end()); candidates.insert(childs.begin(), childs.end());
} }
...@@ -44,8 +44,8 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape ...@@ -44,8 +44,8 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
for (const auto& input : node->inputs()) { for (const auto& input : node->inputs()) {
if (input.first) { if (input.first) {
if (!(input.first->type() == Producer_Op::Type || (constantShape && (input.first->type() == Shape_Op::Type)))) { if (!(input.first->type() == Producer_Op::Type || (constantShape && (input.first->type() == Shape_Op::Type)))) {
Log::debug("Input {} of node {} (of type {}) not foldable, because {} (of type {}) is not a constant. With constant = {}", Log::debug("Input {} of node {} (of type {}) not foldable, because {} (of type {}) is not a constant.",
i, node->name(), node->type(), input.first->name(), input.first->type(), constantShape); i, node->name(), node->type(), input.first->name(), input.first->type());
foldable = false; foldable = false;
break; break;
} }
...@@ -98,7 +98,14 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape ...@@ -98,7 +98,14 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
// Add output in right order // Add output in right order
prodGraph->add(newProd); prodGraph->add(newProd);
} }
Log::debug("Trying to replace:");
for(auto nodeToReplace: replaceGraph->getNodes()){
Log::debug("\t- {} ({})", nodeToReplace->name(), nodeToReplace->type());
}
Log::debug("With:");
for(auto nodeReplacing: prodGraph->getNodes()){
Log::debug("\t- {} ({})", nodeReplacing->name(), nodeReplacing->type());
}
if (GraphView::replace(replaceGraph, prodGraph)) { if (GraphView::replace(replaceGraph, prodGraph)) {
folded = true; folded = true;
modified = true; modified = true;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment