Skip to content
Snippets Groups Projects
Commit a9f87e40 authored by Grégoire Kubler's avatar Grégoire Kubler Committed by Olivier BICHLER
Browse files

feat : removeConstantOfShape

parent f5d51a18
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!195Feat operator constantofshape
Pipeline #54356 passed
...@@ -50,6 +50,13 @@ void matMulToFC(std::shared_ptr<GraphView> graphView); ...@@ -50,6 +50,13 @@ void matMulToFC(std::shared_ptr<GraphView> graphView);
*/ */
size_t removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers = false); size_t removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers = false);
/**
* @brief Fuses constant => Generic | constantOfShape and transforms it into a Producer
* @param graph Graph to manipulate
* @return size_t Number of replacement
*/
size_t removeConstantOfShape(std::shared_ptr<GraphView> graph_view);
/** /**
* @brief Remove ``Dropout`` Node. * @brief Remove ``Dropout`` Node.
* *
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <cstddef> #include <cstddef>
#include <string> #include <string>
#include "aidge/graph/GraphView.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
...@@ -78,12 +79,12 @@ void init_Recipes(py::module &m) ...@@ -78,12 +79,12 @@ void init_Recipes(py::module &m)
:type graph_view: :py:class:`aidge_core.GraphView` :type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter"); )mydelimiter");
// m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( m.def("remove_constantOfShape", static_cast<size_t(*)(std::shared_ptr<GraphView>)>(removeConstantOfShape), py::arg("graph_view"), R"mydelimiter(
// Recipe to remove a flatten operator. Fuses constant => Generic | constantOfShape and transforms it into a Producer
// :param nodes: The flatten operator to remove. :param graph_view: Graph view on which we want to apply the recipe.
// :type nodes: list of :py:class:`aidge_core.Node` :type graph_view: :py:class:`aidge_core.GraphView`
// )mydelimiter"); )mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
Recipe to remove a flatten operator. Recipe to remove a flatten operator.
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/recipes/Recipes.hpp"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/ConstantOfShape.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
// Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
namespace Aidge {
size_t removeConstantOfShape(std::shared_ptr<GraphView> graph_view) {
const auto matches =
SinglePassGraphMatching(graph_view).match("Producer->ConstantOfShape");
size_t nbReplaced = 0;
for (const auto &match : matches) {
const auto prod_node = match.graph->rootNode();
const auto prod_op =
std::static_pointer_cast<Producer_Op>(prod_node->getOperator());
const NodePtr constantofshape_node =
prod_node->getOrderedChildren().at(0).at(0);
const auto constantofshape_op =
std::static_pointer_cast<ConstantOfShape_Op>(
constantofshape_node->getOperator());
if (prod_op->getOutput(0)->nbDims() != 1) {
Log::debug("{} : Producer output dimension number is {} != 1 and {} "
"input has to have 1 dim, skipping match.",
__func__, prod_op->getOutput(0)->nbDims(),
ConstantOfShape_Op::Type);
continue;
}
if (!prod_op->constant()) {
Log::debug("{} : Producer is not constant, skipping match.", __func__);
continue;
}
if (prod_op->getOutput(0)->dataType() != DataType::Int64) {
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"{} : Producer output dtype is {} != int64 and {} "
"input type is restricted to int64_t, this is an error."
"Fix your network. skipping match.",
__func__, prod_op->getOutput(0)->dataType(),
ConstantOfShape_Op::Type);
continue;
}
auto graph_to_replace = std::make_shared<GraphView>();
auto new_graph = std::make_shared<GraphView>();
graph_to_replace->add(constantofshape_node);
if (prod_node->getChildren().size() == 1) {
graph_to_replace->add(prod_node);
} else {
Log::debug("{} : Producer node has multiple children, only"
"replacing the {} node.",
__func__, ConstantOfShape_Op::Type);
}
prod_node->forward();
std::shared_ptr<Tensor> prod_output = prod_op->getOutput(0);
std::vector<DimSize_t> new_input_dims;
new_input_dims.reserve(prod_output->dims()[0]);
for (DimSize_t i = 0; i < prod_output->size(); ++i) {
new_input_dims.push_back(prod_output->get<int64_t>(i));
}
auto new_input = std::make_shared<Tensor>(new_input_dims);
new_input->setBackend(prod_op->backend() == "" ? "cpu"
: prod_op->backend());
new_input->setDataType(constantofshape_op->value().dataType());
for (std::size_t i = 0; i < new_input->size(); ++i) {
new_input->getImpl()->copy(
constantofshape_op->value().getImpl()->rawPtr(), 1, i);
}
auto new_prod =
Producer(new_input, prod_node->name() + "_constant_of_shape", true);
new_graph->add(new_prod);
const auto success = GraphView::replace(graph_to_replace, new_graph);
if (!success) {
Log::warn("Could not replace Producer({})->ConstantOfShape({}) with"
"Producer",
prod_node->name(), constantofshape_node->name());
} else {
++nbReplaced;
}
}
Log::info("Replaced {} (out of {}) matching Producer->ConstantOfShape with "
"Producers",
nbReplaced, matches.size());
return nbReplaced;
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Identity.hpp"
#include "aidge/recipes/Recipes.hpp"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>
#include <catch2/catch_test_macros.hpp>
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/ConstantOfShape.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/utils/ArrayHelpers.hpp"
#include "aidge/utils/Types.h"
using namespace Aidge;
TEST_CASE("[cpu/recipies] removeConstantOfShape",
"[ConstantOfShape][removeConstantOfShape][recipies]") {
auto input_T = std::make_shared<Tensor>(Array1D<int64_t, 4>({1, 1, 3, 3}));
auto model = std::make_shared<GraphView>();
SECTION("Sequential model") {
model = Sequential({Producer(input_T, "prod_0", true),
ConstantOfShape(3, "constantOfShape_0"),
Conv(1, 1, {3, 3}, "Conv_0"), ReLU("ReLU_1")});
model->save("test_removeConstantOfShape_model_before_1");
CHECK(removeConstantOfShape(model) == 1);
CHECK(model->forwardDims());
model->save("test_removeConstantOfShape_model_after_1");
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment