diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp index 40cd30a7d908ef04fe1b02a55106211f6153fa38..b2fc427b3982360e2ea31fc61e45819d2d2c0232 100644 --- a/src/recipes/ConstantFolding.cpp +++ b/src/recipes/ConstantFolding.cpp @@ -40,6 +40,7 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape Log::debug("Checking if node {} (of type {}) is foldable", node->name(), node->type()); bool foldable = true; auto replaceGraph = std::make_shared<GraphView>(); + std::shared_ptr<DynamicAttributes> attributes; size_t i = 0; for (const auto& input : node->inputs()) { if (input.first) { @@ -70,6 +71,18 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape } replaceGraph->add(input.first, false); + + if (!input.first->attributes()->getAttrs().empty()) { + if (!attributes) { + attributes = input.first->attributes(); + } + else { + // We could merge attributes here, but there would be a risk of duplicates. + // For now, emit a notice to the user, and let's see if there is an actual use case. + Log::notice("Cannot propagate attributes from Producer input#{} {} because a previous input already propagated its attributes", + i, input.first->name()); + } + } } else if (node->inputCategory(i) != InputCategory::OptionalData && node->inputCategory(i) != InputCategory::OptionalParam) @@ -93,7 +106,12 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape for (IOIndex_t output = 0; output < node->nbOutputs(); ++output) { const auto computedOutput = std::make_shared<Tensor>(op->getOutput(output)->clone()); - const auto newProd = Producer(computedOutput, node->name() + "_" + std::to_string(output), true); + auto newProd = Producer(computedOutput, node->name() + "_" + std::to_string(output), true); + + // Propagate node attributes from initial Producer to the new one + if (attributes) { + newProd->attributes() = attributes; + } // Add output in right order prodGraph->add(newProd);