diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index 86c722b158657633d4509c1181b1f18201d0d514..0fb405bfe5e74f159fbd5504cc199e3b29842254 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -180,6 +180,19 @@ size_t convToMatMul(std::shared_ptr<GraphView> graph); */ void adaptToBackend(std::shared_ptr<GraphView> graph); +// /** +// * @brief The node passed contains an operator which input of index 1 is supposed be be weights of type Int4, Int3, Int2, binary. +// * This recipie only operates memory transformations on the weight tensor. +// * First, permutes the dimensions to match the dataformat NHWC +// * Second, compact the last dimension (Channel dimension) into int8_t +// * +// * @param node Node +// */ +// void applyWeightInterleaving(std::shared_ptr<Node> node); + + +void toGenericOp(std::shared_ptr<Node> node); + } // namespace Aidge #endif /* AIDGE_CORE_UTILS_RECIPES_H_ */ diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 77f20b9d655c6d9f6e95b23c4884bd1bc4f9ffd6..f656af70dfa05678875afd4b4748f358437852a8 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -144,6 +144,13 @@ void init_Recipes(py::module &m) :param graph_view: Graph view on which we want to apply the recipe :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); + + m.def("to_generic_op", toGenericOp, py::arg("node"), R"mydelimiter( + Transform to a Generic Operator. + + :param node: Node which Operator will turn into a Generic Operator + :type graph_view: :py:class:`aidge_core.Node` + )mydelimiter"); } } // namespace Aidge diff --git a/src/recipes/ToGenericOp.cpp b/src/recipes/ToGenericOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f151d8904a21ca49c96dbe089fdc96cd77e7501 --- /dev/null +++ b/src/recipes/ToGenericOp.cpp @@ -0,0 +1,23 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/graph/Node.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/recipes/Recipes.hpp" + +void Aidge::toGenericOp(std::shared_ptr<Node> node) { + auto newGenOp = {GenericOperator(node->type(), std::dynamic_pointer_cast<Aidge::OperatorTensor>(node->getOperator()), node->name())}; + auto OldOp = {node}; + GraphView::replace(OldOp, newGenOp); +} diff --git a/unit_tests/recipes/Test_ToGenericOp.cpp b/unit_tests/recipes/Test_ToGenericOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53ad86e7ce7dbe0b3eb499dc76166feff642835a --- /dev/null +++ b/unit_tests/recipes/Test_ToGenericOp.cpp @@ -0,0 +1,87 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <memory> +#include <set> +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/recipes/Recipes.hpp" + +namespace Aidge { + +TEST_CASE("[graph/convert] toGenericOp", "[toGenericOp][recipies]") { + // Create a convolution operator + std::shared_ptr<GraphView> g = + Sequential({ + Conv(1, 3, {3, 3}, "conv1"), + ReLU(), + Conv(3, 4, {1, 1}, "conv2"), + ReLU(), + Conv(4, 3, {1, 1}, "conv3"), + ReLU(), + FC(2028, 256, false, "fc1"), + ReLU(), + FC(256, 10, false, "fc2")}); + + // NCHW - MNIST DATA like + g->forwardDims({{5, 1, 28, 28}}); + + SECTION("Test Operator to Generic Operator") { + auto convOp = g->getNode("conv2"); + + // Convert to GenericOperator + toGenericOp(convOp); + + auto newGenOp = g->getNode("conv2"); + + // Ensure the conversion + REQUIRE(newGenOp->type() == "Conv2D"); + + REQUIRE(newGenOp->getOperator()->attributes() == convOp->getOperator()->attributes()); + + } + + SECTION("Test MetaOperator to Generic Operator") { + + const auto nbFused = fuseToMetaOps(g, "Conv2D->ReLU->FC", "ConvReLUFC"); + + REQUIRE(nbFused == 1); + + std::shared_ptr<Node> MetaOpNode; + + for (const auto& nodePtr : g->getNodes()) + { + if (nodePtr->type() == "ConvReLUFC") + { + nodePtr->setName("ConvReLUFC_0"); + MetaOpNode = nodePtr; + // Convert to GenericOperator + toGenericOp(nodePtr); + } + } + + auto newGenOp = g->getNode("ConvReLUFC_0"); + + // Ensure the conversion + REQUIRE(newGenOp->type() == "ConvReLUFC"); + + REQUIRE(newGenOp->getOperator()->attributes() == MetaOpNode->getOperator()->attributes()); + + } + +} + +} // namespace Aidge