diff --git a/src/operator/Conv.cpp b/src/operator/Conv.cpp index 8e23e8b3c6ce5fa3ac179bb058e7e04d5905b6b2..d69aad616bcdaedd7ffa9cdb04d02802bb998f5a 100644 --- a/src/operator/Conv.cpp +++ b/src/operator/Conv.cpp @@ -47,7 +47,7 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { if(getInput(0)->dataFormat() == Aidge::DataFormat::NHWC){ AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && (getInput(0)->template dims<DIM+2>()[DIM+1] == inChannels()), - "Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), fmt::join(std::vector<std::string>(DIM, "x"), ", "), inChannels()); + "Wrong input channel size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), fmt::join(std::vector<std::string>(DIM, "x"), ", "), inChannels()); } else{ //For dataFormat in NCHW or Default Format AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && diff --git a/src/recipes/FoldConstantOfShape.cpp b/src/recipes/FoldConstantOfShape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..631c751920a41b79442d820097bd9a3f0e4e3b35 --- /dev/null +++ b/src/recipes/FoldConstantOfShape.cpp @@ -0,0 +1,112 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#include "aidge/recipes/Recipes.hpp" + +#include <cstddef> // std::size_t +#include <cstdint> // std::int64_t +#include <memory> +#include <stdexcept> // std::runtime_error + +#include "aidge/data/DataType.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Matching.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/ConstantOfShape.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Log.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +static bool foldIndividualConstantOfShape(const SinglePassGraphMatching::MatchingResult match, std::size_t namingId) { + const std::shared_ptr<Node> prod_node = match.graph->rootNode(); + const std::shared_ptr<Producer_Op> prod_op = + std::static_pointer_cast<Producer_Op>(prod_node->getOperator()); + prod_op->forward(); // is this REALLY needed if Producer is constant? + const std::shared_ptr<Tensor> shape = prod_op->getOutput(0); + + if (!prod_op->constant()) { + Log::debug("{} - Producer is not constant. Skipping match.", __func__); + return false; + } + if (shape->nbDims() != 1) { + Log::debug("[{}] - ConstantOfShape 'shape' input Tensor has {} != 1 dimensions. Skipping match.", + __func__, shape->nbDims()); + return false; + } + + if (shape->dataType() != DataType::Int64) { + Log::debug( + "Producer output data type is {} != {} required for ConstantOfShape " + "'shape' input. " + "Skipping match.", + prod_op->getOutput(0)->dataType(), + DataType::Int64 + ); + return false; + } + + const std::shared_ptr<Node> constantofshape_node = + prod_node->getOrderedChildren().at(0).at(0); + const std::shared_ptr<ConstantOfShape_Op> constantofshape_op = + std::static_pointer_cast<ConstantOfShape_Op>( + constantofshape_node->getOperator()); + + std::shared_ptr<GraphView> graph_to_replace = std::make_shared<GraphView>(); + graph_to_replace->add({constantofshape_node, prod_node}); + + constantofshape_op->forwardDims(true); // ConstantOfShape forwardDims is data dependent + + std::string original_backend = constantofshape_op->backend().empty() ? "cpu" : constantofshape_op->backend(); + graph_to_replace->setBackend("cpu"); // set backend to 'cpu' since speed is not the main focus here + + constantofshape_op->forward(); + + const std::shared_ptr<Tensor> newInputTensor = constantofshape_op->getOutput(0); + newInputTensor->setBackend(original_backend); // set back original backend + const std::shared_ptr<Node> newProducer = + Producer(newInputTensor, + "constantOfShape_" + constantofshape_node->name() + "_folded_" + std::to_string(namingId), + true // remains constant + ); + + std::shared_ptr<GraphView> new_graph = std::make_shared<GraphView>(); + new_graph->add(newProducer); + + return GraphView::replace(graph_to_replace, new_graph); +} + +std::size_t foldConstantOfShape(std::shared_ptr<GraphView> view) { + // this query guarantes that any Producer part of a returned match is ONLY followed by a ConstantOfShape + const auto matches = SinglePassGraphMatching(view).match("Producer->ConstantOfShape"); + + std::size_t nbReplaced = 0; + if (!Registrar<ConstantOfShape_Op>::exists("cpu")) { + Log::error("'cpu' backend not loaded. Impossible to run and fold any ConstantOfShape Operator."); + } else { + for (const auto &match : matches) { + if (!foldIndividualConstantOfShape(match, nbReplaced)) { + // TODO: this warning should be handled by GraphView::replace + Log::warn("Could not replace match with Producer"); + } else { + ++nbReplaced; + } + } + } + + Log::info("Removed [\033[1m\033[3m{}/{}\033[0m] ConstantOfShape Nodes", + nbReplaced, matches.size()); + return nbReplaced; +} + +} // namespace Aidge diff --git a/src/recipes/removeConstantOfShape.cpp b/src/recipes/removeConstantOfShape.cpp deleted file mode 100644 index e743050c2c0f13513f639a0690943e0d934f947d..0000000000000000000000000000000000000000 --- a/src/recipes/removeConstantOfShape.cpp +++ /dev/null @@ -1,125 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ -#include "aidge/recipes/Recipes.hpp" - -#include <algorithm> -#include <cassert> -#include <cstddef> -#include <cstdint> -#include <cstdio> -#include <functional> -#include <memory> -#include <numeric> -#include <set> -#include <stdexcept> -#include <string> - -#include "aidge/data/Data.hpp" -#include "aidge/data/Tensor.hpp" -#include "aidge/filler/Filler.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Matching.hpp" -#include "aidge/graph/Node.hpp" -#include "aidge/operator/ConstantOfShape.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/utils/ErrorHandling.hpp" -#include "aidge/utils/Types.h" - -namespace Aidge { - -size_t removeConstantOfShape(std::shared_ptr<GraphView> graph_view) { - const auto matches = - SinglePassGraphMatching(graph_view).match("Producer->ConstantOfShape"); - - size_t nbReplaced = 0; - for (const auto &match : matches) { - const auto prod_node = match.graph->rootNode(); - const auto prod_op = - std::static_pointer_cast<Producer_Op>(prod_node->getOperator()); - - const NodePtr constantofshape_node = - prod_node->getOrderedChildren().at(0).at(0); - - const auto constantofshape_op = - std::static_pointer_cast<ConstantOfShape_Op>( - constantofshape_node->getOperator()); - - if (prod_op->getOutput(0)->nbDims() != 1) { - Log::debug("{} : Producer output dimension number is {} != 1 and {} " - "input has to have 1 dim, skipping match.", - __func__, prod_op->getOutput(0)->nbDims(), - ConstantOfShape_Op::Type); - continue; - } - if (!prod_op->constant()) { - Log::debug("{} : Producer is not constant, skipping match.", __func__); - continue; - } - if (prod_op->getOutput(0)->dataType() != DataType::Int64) { - AIDGE_THROW_OR_ABORT( - std::runtime_error, - "{} : Producer output dtype is {} != int64 and {} " - "input type is restricted to int64_t, this is an error." - "Fix your network. skipping match.", - __func__, prod_op->getOutput(0)->dataType(), - ConstantOfShape_Op::Type); - continue; - } - - auto graph_to_replace = std::make_shared<GraphView>(); - auto new_graph = std::make_shared<GraphView>(); - graph_to_replace->add(constantofshape_node); - if (prod_node->getChildren().size() == 1) { - graph_to_replace->add(prod_node); - } else { - Log::debug("{} : Producer node has multiple children, only" - "replacing the {} node.", - __func__, ConstantOfShape_Op::Type); - } - - prod_node->forward(); - std::shared_ptr<Tensor> prod_output = prod_op->getOutput(0); - std::vector<DimSize_t> new_input_dims; - new_input_dims.reserve(prod_output->dims()[0]); - for (DimSize_t i = 0; i < prod_output->size(); ++i) { - new_input_dims.push_back(prod_output->get<int64_t>(i)); - } - - auto new_input = std::make_shared<Tensor>(new_input_dims); - new_input->setBackend(prod_op->backend() == "" ? "cpu" - : prod_op->backend()); - new_input->setDataType(constantofshape_op->value().dataType()); - for (std::size_t i = 0; i < new_input->size(); ++i) { - new_input->getImpl()->copy( - constantofshape_op->value().getImpl()->rawPtr(), 1, i); - } - auto new_prod = - Producer(new_input, prod_node->name() + "_constant_of_shape", true); - new_graph->add(new_prod); - - const auto success = GraphView::replace(graph_to_replace, new_graph); - if (!success) { - Log::warn("Could not replace Producer({})->ConstantOfShape({}) with" - "Producer", - prod_node->name(), constantofshape_node->name()); - } else { - ++nbReplaced; - } - } - - Log::info("Replaced {} (out of {}) matching Producer->ConstantOfShape with " - "Producers", - nbReplaced, matches.size()); - return nbReplaced; -} -} // namespace Aidge - diff --git a/unit_tests/recipes/Test_removeConstantOfShape.cpp b/unit_tests/recipes/Test_FoldConstantOfShape.cpp similarity index 56% rename from unit_tests/recipes/Test_removeConstantOfShape.cpp rename to unit_tests/recipes/Test_FoldConstantOfShape.cpp index b912efc640fc901f694afeda256be91d51010419..02bace3c44d208044936766ec3fb1d30334ffdf3 100644 --- a/unit_tests/recipes/Test_removeConstantOfShape.cpp +++ b/unit_tests/recipes/Test_FoldConstantOfShape.cpp @@ -13,38 +13,38 @@ #include "aidge/operator/Identity.hpp" #include "aidge/recipes/Recipes.hpp" -#include <cstddef> -#include <cstdint> +#include <cstdint> // std::int64_t #include <memory> -#include <vector> #include <catch2/catch_test_macros.hpp> #include "aidge/graph/OpArgs.hpp" -#include "aidge/operator/Add.hpp" #include "aidge/operator/ConstantOfShape.hpp" #include "aidge/operator/Conv.hpp" -#include "aidge/operator/MatMul.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/ReLU.hpp" +#include "aidge/recipes/Recipes.hpp" #include "aidge/utils/ArrayHelpers.hpp" #include "aidge/utils/Types.h" -using namespace Aidge; +namespace Aidge { -TEST_CASE("[cpu/recipes] removeConstantOfShape", - "[ConstantOfShape][removeConstantOfShape][recipes]") { - auto input_T = std::make_shared<Tensor>(Array1D<int64_t, 4>({1, 1, 3, 3})); +TEST_CASE("[cpu/recipes] foldConstantOfShape", + "[ConstantOfShape][foldConstantOfShape][recipes]") { + auto input_T = std::make_shared<Tensor>(Array1D<std::int64_t, 4>({1, 1, 3, 3})); auto model = std::make_shared<GraphView>(); SECTION("Sequential model") { - model = Sequential({Producer(input_T, "prod_0", true), - ConstantOfShape(3, "constantOfShape_0"), - Conv(1, 1, {3, 3}, "Conv_0"), ReLU("ReLU_1")}); - model->save("test_removeConstantOfShape_model_before_1"); - CHECK(removeConstantOfShape(model) == 1); - CHECK(model->forwardDims()); - model->save("test_removeConstantOfShape_model_after_1"); + model = Sequential({ + Producer(input_T, "prod_0", true), + ConstantOfShape(3, "constantOfShape_0"), + Conv(1, 1, {3, 3}, "Conv_0"), + ReLU("ReLU_1") + }); + model->save("test_foldConstantOfShape_model_before_1"); + // aidge_backend_cpu not loaded. Recipe should not work + REQUIRE(foldConstantOfShape(model) == 0); } } +} // namespace Aidge