Skip to content
Snippets Groups Projects
Commit 3e1a526d authored by Maxence Naud's avatar Maxence Naud
Browse files

Update 'removeConstOfShape' recipe

- split the function in two for more comprehesive structure
- change name from 'removeConstantOFShape' to 'foldConstantOfShape' since it describe way better what the function does
- enhance function structure, Producer has got only one child Node by construction. No need to check
parent 9a6798c4
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!357[upd] 'removeConstantOfShape' recipe
Pipeline #66481 failed
...@@ -47,7 +47,7 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { ...@@ -47,7 +47,7 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) {
if(getInput(0)->dataFormat() == Aidge::DataFormat::NHWC){ if(getInput(0)->dataFormat() == Aidge::DataFormat::NHWC){
AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) &&
(getInput(0)->template dims<DIM+2>()[DIM+1] == inChannels()), (getInput(0)->template dims<DIM+2>()[DIM+1] == inChannels()),
"Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), fmt::join(std::vector<std::string>(DIM, "x"), ", "), inChannels()); "Wrong input channel size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), fmt::join(std::vector<std::string>(DIM, "x"), ", "), inChannels());
} }
else{ //For dataFormat in NCHW or Default Format else{ //For dataFormat in NCHW or Default Format
AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) &&
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/recipes/Recipes.hpp"
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <memory>
#include <stdexcept> // std::runtime_error
#include "aidge/data/DataType.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/ConstantOfShape.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
static bool foldIndividualConstantOfShape(const SinglePassGraphMatching::MatchingResult match, std::size_t namingId) {
const std::shared_ptr<Node> prod_node = match.graph->rootNode();
const std::shared_ptr<Producer_Op> prod_op =
std::static_pointer_cast<Producer_Op>(prod_node->getOperator());
prod_op->forward(); // is this REALLY needed if Producer is constant?
const std::shared_ptr<Tensor> shape = prod_op->getOutput(0);
if (!prod_op->constant()) {
Log::debug("{} - Producer is not constant. Skipping match.", __func__);
return false;
}
if (shape->nbDims() != 1) {
Log::debug("[{}] - ConstantOfShape 'shape' input Tensor has {} != 1 dimensions. Skipping match.",
__func__, shape->nbDims());
return false;
}
if (shape->dataType() != DataType::Int64) {
Log::debug(
"Producer output data type is {} != {} required for ConstantOfShape "
"'shape' input. "
"Skipping match.",
prod_op->getOutput(0)->dataType(),
DataType::Int64
);
return false;
}
const std::shared_ptr<Node> constantofshape_node =
prod_node->getOrderedChildren().at(0).at(0);
const std::shared_ptr<ConstantOfShape_Op> constantofshape_op =
std::static_pointer_cast<ConstantOfShape_Op>(
constantofshape_node->getOperator());
std::shared_ptr<GraphView> graph_to_replace = std::make_shared<GraphView>();
graph_to_replace->add({constantofshape_node, prod_node});
constantofshape_op->forwardDims(true); // ConstantOfShape forwardDims is data dependent
std::string original_backend = constantofshape_op->backend().empty() ? "cpu" : constantofshape_op->backend();
graph_to_replace->setBackend("cpu"); // set backend to 'cpu' since speed is not the main focus here
constantofshape_op->forward();
const std::shared_ptr<Tensor> newInputTensor = constantofshape_op->getOutput(0);
newInputTensor->setBackend(original_backend); // set back original backend
const std::shared_ptr<Node> newProducer =
Producer(newInputTensor,
"constantOfShape_" + constantofshape_node->name() + "_folded_" + std::to_string(namingId),
true // remains constant
);
std::shared_ptr<GraphView> new_graph = std::make_shared<GraphView>();
new_graph->add(newProducer);
return GraphView::replace(graph_to_replace, new_graph);
}
std::size_t foldConstantOfShape(std::shared_ptr<GraphView> view) {
// this query guarantes that any Producer part of a returned match is ONLY followed by a ConstantOfShape
const auto matches = SinglePassGraphMatching(view).match("Producer->ConstantOfShape");
std::size_t nbReplaced = 0;
if (!Registrar<ConstantOfShape_Op>::exists("cpu")) {
Log::error("'cpu' backend not loaded. Impossible to run and fold any ConstantOfShape Operator.");
} else {
for (const auto &match : matches) {
if (!foldIndividualConstantOfShape(match, nbReplaced)) {
// TODO: this warning should be handled by GraphView::replace
Log::warn("Could not replace match with Producer");
} else {
++nbReplaced;
}
}
}
Log::info("Removed [\033[1m\033[3m{}/{}\033[0m] ConstantOfShape Nodes",
nbReplaced, matches.size());
return nbReplaced;
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/recipes/Recipes.hpp"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/ConstantOfShape.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
size_t removeConstantOfShape(std::shared_ptr<GraphView> graph_view) {
const auto matches =
SinglePassGraphMatching(graph_view).match("Producer->ConstantOfShape");
size_t nbReplaced = 0;
for (const auto &match : matches) {
const auto prod_node = match.graph->rootNode();
const auto prod_op =
std::static_pointer_cast<Producer_Op>(prod_node->getOperator());
const NodePtr constantofshape_node =
prod_node->getOrderedChildren().at(0).at(0);
const auto constantofshape_op =
std::static_pointer_cast<ConstantOfShape_Op>(
constantofshape_node->getOperator());
if (prod_op->getOutput(0)->nbDims() != 1) {
Log::debug("{} : Producer output dimension number is {} != 1 and {} "
"input has to have 1 dim, skipping match.",
__func__, prod_op->getOutput(0)->nbDims(),
ConstantOfShape_Op::Type);
continue;
}
if (!prod_op->constant()) {
Log::debug("{} : Producer is not constant, skipping match.", __func__);
continue;
}
if (prod_op->getOutput(0)->dataType() != DataType::Int64) {
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"{} : Producer output dtype is {} != int64 and {} "
"input type is restricted to int64_t, this is an error."
"Fix your network. skipping match.",
__func__, prod_op->getOutput(0)->dataType(),
ConstantOfShape_Op::Type);
continue;
}
auto graph_to_replace = std::make_shared<GraphView>();
auto new_graph = std::make_shared<GraphView>();
graph_to_replace->add(constantofshape_node);
if (prod_node->getChildren().size() == 1) {
graph_to_replace->add(prod_node);
} else {
Log::debug("{} : Producer node has multiple children, only"
"replacing the {} node.",
__func__, ConstantOfShape_Op::Type);
}
prod_node->forward();
std::shared_ptr<Tensor> prod_output = prod_op->getOutput(0);
std::vector<DimSize_t> new_input_dims;
new_input_dims.reserve(prod_output->dims()[0]);
for (DimSize_t i = 0; i < prod_output->size(); ++i) {
new_input_dims.push_back(prod_output->get<int64_t>(i));
}
auto new_input = std::make_shared<Tensor>(new_input_dims);
new_input->setBackend(prod_op->backend() == "" ? "cpu"
: prod_op->backend());
new_input->setDataType(constantofshape_op->value().dataType());
for (std::size_t i = 0; i < new_input->size(); ++i) {
new_input->getImpl()->copy(
constantofshape_op->value().getImpl()->rawPtr(), 1, i);
}
auto new_prod =
Producer(new_input, prod_node->name() + "_constant_of_shape", true);
new_graph->add(new_prod);
const auto success = GraphView::replace(graph_to_replace, new_graph);
if (!success) {
Log::warn("Could not replace Producer({})->ConstantOfShape({}) with"
"Producer",
prod_node->name(), constantofshape_node->name());
} else {
++nbReplaced;
}
}
Log::info("Replaced {} (out of {}) matching Producer->ConstantOfShape with "
"Producers",
nbReplaced, matches.size());
return nbReplaced;
}
} // namespace Aidge
...@@ -13,38 +13,38 @@ ...@@ -13,38 +13,38 @@
#include "aidge/operator/Identity.hpp" #include "aidge/operator/Identity.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
#include <cstddef> #include <cstdint> // std::int64_t
#include <cstdint>
#include <memory> #include <memory>
#include <vector>
#include <catch2/catch_test_macros.hpp> #include <catch2/catch_test_macros.hpp>
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/ConstantOfShape.hpp" #include "aidge/operator/ConstantOfShape.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReLU.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/ArrayHelpers.hpp" #include "aidge/utils/ArrayHelpers.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
using namespace Aidge; namespace Aidge {
TEST_CASE("[cpu/recipes] removeConstantOfShape", TEST_CASE("[cpu/recipes] foldConstantOfShape",
"[ConstantOfShape][removeConstantOfShape][recipes]") { "[ConstantOfShape][foldConstantOfShape][recipes]") {
auto input_T = std::make_shared<Tensor>(Array1D<int64_t, 4>({1, 1, 3, 3})); auto input_T = std::make_shared<Tensor>(Array1D<std::int64_t, 4>({1, 1, 3, 3}));
auto model = std::make_shared<GraphView>(); auto model = std::make_shared<GraphView>();
SECTION("Sequential model") { SECTION("Sequential model") {
model = Sequential({Producer(input_T, "prod_0", true), model = Sequential({
ConstantOfShape(3, "constantOfShape_0"), Producer(input_T, "prod_0", true),
Conv(1, 1, {3, 3}, "Conv_0"), ReLU("ReLU_1")}); ConstantOfShape(3, "constantOfShape_0"),
model->save("test_removeConstantOfShape_model_before_1"); Conv(1, 1, {3, 3}, "Conv_0"),
CHECK(removeConstantOfShape(model) == 1); ReLU("ReLU_1")
CHECK(model->forwardDims()); });
model->save("test_removeConstantOfShape_model_after_1"); model->save("test_foldConstantOfShape_model_before_1");
// aidge_backend_cpu not loaded. Recipe should not work
REQUIRE(foldConstantOfShape(model) == 0);
} }
} }
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment