Skip to content
Snippets Groups Projects
Commit ea6ac3c2 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Add toGenericOp recipie + bind + test

parent 170c5a23
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!272[Add] Possibility to create a GenericOperator from any Operator
Pipeline #60689 failed
......@@ -180,6 +180,19 @@ size_t convToMatMul(std::shared_ptr<GraphView> graph);
*/
void adaptToBackend(std::shared_ptr<GraphView> graph);
// /**
// * @brief The node passed contains an operator which input of index 1 is supposed be be weights of type Int4, Int3, Int2, binary.
// * This recipie only operates memory transformations on the weight tensor.
// * First, permutes the dimensions to match the dataformat NHWC
// * Second, compact the last dimension (Channel dimension) into int8_t
// *
// * @param node Node
// */
// void applyWeightInterleaving(std::shared_ptr<Node> node);
void toGenericOp(std::shared_ptr<Node> node);
} // namespace Aidge
#endif /* AIDGE_CORE_UTILS_RECIPES_H_ */
......@@ -144,6 +144,13 @@ void init_Recipes(py::module &m)
:param graph_view: Graph view on which we want to apply the recipe
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("to_generic_op", toGenericOp, py::arg("node"), R"mydelimiter(
Transform to a Generic Operator.
:param node: Node which Operator will turn into a Generic Operator
:type graph_view: :py:class:`aidge_core.Node`
)mydelimiter");
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/recipes/Recipes.hpp"
void Aidge::toGenericOp(std::shared_ptr<Node> node) {
auto newGenOp = {GenericOperator(node->type(), std::dynamic_pointer_cast<Aidge::OperatorTensor>(node->getOperator()), node->name())};
auto OldOp = {node};
GraphView::replace(OldOp, newGenOp);
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <memory>
#include <set>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/recipes/Recipes.hpp"
namespace Aidge {
TEST_CASE("[graph/convert] toGenericOp", "[toGenericOp][recipies]") {
// Create a convolution operator
std::shared_ptr<GraphView> g =
Sequential({
Conv(1, 3, {3, 3}, "conv1"),
ReLU(),
Conv(3, 4, {1, 1}, "conv2"),
ReLU(),
Conv(4, 3, {1, 1}, "conv3"),
ReLU(),
FC(2028, 256, false, "fc1"),
ReLU(),
FC(256, 10, false, "fc2")});
// NCHW - MNIST DATA like
g->forwardDims({{5, 1, 28, 28}});
SECTION("Test Operator to Generic Operator") {
auto convOp = g->getNode("conv2");
// Convert to GenericOperator
toGenericOp(convOp);
auto newGenOp = g->getNode("conv2");
// Ensure the conversion
REQUIRE(newGenOp->type() == "Conv2D");
REQUIRE(newGenOp->getOperator()->attributes() == convOp->getOperator()->attributes());
}
SECTION("Test MetaOperator to Generic Operator") {
const auto nbFused = fuseToMetaOps(g, "Conv2D->ReLU->FC", "ConvReLUFC");
REQUIRE(nbFused == 1);
std::shared_ptr<Node> MetaOpNode;
for (const auto& nodePtr : g->getNodes())
{
if (nodePtr->type() == "ConvReLUFC")
{
nodePtr->setName("ConvReLUFC_0");
MetaOpNode = nodePtr;
// Convert to GenericOperator
toGenericOp(nodePtr);
}
}
auto newGenOp = g->getNode("ConvReLUFC_0");
// Ensure the conversion
REQUIRE(newGenOp->type() == "ConvReLUFC");
REQUIRE(newGenOp->getOperator()->attributes() == MetaOpNode->getOperator()->attributes());
}
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment