diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index 500e1f9a3092bf232b63b23dc208cb9892f83bba..28dab05f80a64f12a59dc1f684652f66a96dc95f 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -76,17 +76,13 @@ struct Registrar { } static auto create(const registrar_key& key) { - AIDGE_ASSERT(exists(key), "missing or invalid registrar key: {} for registrable object {}\nDid you include/import the corresponding module?\nIf so, it is possible that the object is not yet supported.", key, typeid(C).name()); return C::registry().at(key); } static std::set<registrar_key> getKeys(){ std::set<registrar_key> keys; - for(const auto& keyValue : C::registry()){ - + for(const auto& keyValue : C::registry()) keys.insert(keyValue.first); - } - return keys; } }; diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 9f28667de532a5ba013d03f11a19684efd925ea1..21478a5b14d609801f232b20cda25e7e1c0d9475 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -24,11 +24,6 @@ namespace py = pybind11; namespace Aidge { void init_Recipes(py::module &m) { - m.def("constant_folding", static_cast<void(*)(std::shared_ptr<GraphView>)>(constantFolding), py::arg("graph_view"), R"mydelimiter( - Recipe to optimize graphview by repeatedly identifying nodes with constant inputs, executes them immediately, and replaces them with pre-computed constant. - :param graph_view: Graph view on which we want to apply the recipe - :type graph_view: :py:class:`aidge_core.GraphView` - )mydelimiter"); m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter( diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 1959c23f189efde43d867118f194979ba1be2648..08f5fe671c7502a6c5fe01dbdfb7ae4c9b95ac81 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -74,6 +74,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const { requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims}); } + const auto& inhAttrs = mOp.inheritedAttributes(); if (inhAttrs) { if (inhAttrs->hasAttr("impl")) { @@ -88,7 +89,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) const auto availableSpecsSet = getAvailableImplSpecs(); AIDGE_ASSERT(availableSpecsSet.size() > 0 , - "OperatorImpl::getBestMatch():o available specs found by" + "OperatorImpl::getBestMatch(): No available specs found by" "getAvailableSpecs(). " "Cannot find best implementation for required specs, aborting."); const std::vector<ImplSpec> availableSpecs(availableSpecsSet.begin(), availableSpecsSet.end()); @@ -234,7 +235,6 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& auto op = std::static_pointer_cast<OperatorTensor>(mOp.clone()); auto node = std::make_shared<Node>(op); - // node->getOperator()->setImpl(getOperator().getImpl()); auto adaptedGraph = std::make_shared<GraphView>(); adaptedGraph->add(node); diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index d150f8ba2781f51e4d82e2d6c1263ef65919a03a..315844858103cbce91049ec2195ff0a3bd7a9d81 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -1678,7 +1678,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone } } newGraph->setOrderedOutputs(newOutputNodes); - + return newGraph; } diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index 59817b7d90671427d0e986a457b976912748a3e2..bd09e9d1297ec612b08634f59bfe33f0802ef3fd 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -54,7 +54,6 @@ Aidge::Elts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) cons Aidge::Elts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbProducedData(): an implementation is required for {}!", type()); - AIDGE_ASSERT(mImpl->prodConso() != nullptr, "getNbProducedData(): no prod consumer for {}!", mImpl->getOperator().type()); return mImpl->prodConso()->getNbProducedData(outputIdx); } void Aidge::Operator::updateConsummerProducer(){ diff --git a/unit_tests/CMakeLists.txt b/unit_tests/CMakeLists.txt index 6622e76f4e82598f24243f937e6ed083f2468865..cf7e896da87c9b08c08325874e78ad7957ec27da 100644 --- a/unit_tests/CMakeLists.txt +++ b/unit_tests/CMakeLists.txt @@ -11,7 +11,7 @@ if(NOT Catch2_FOUND) FetchContent_Declare( Catch2 GIT_REPOSITORY https://github.com/catchorg/Catch2.git - GIT_TAG v3.7.1 # or a later release + GIT_TAG devel # or a later release ) FetchContent_MakeAvailable(Catch2) message(STATUS "Fetched Catch2 version ${Catch2_VERSION}") diff --git a/unit_tests/recipes/Test_AdaptToBackend.cpp b/unit_tests/recipes/Test_AdaptToBackend.cpp index 08c3de04be6b0e29dc3d0bfda097e8897dee1314..1238d1dc448f1d6cbf10245b03351c477d063141 100644 --- a/unit_tests/recipes/Test_AdaptToBackend.cpp +++ b/unit_tests/recipes/Test_AdaptToBackend.cpp @@ -21,8 +21,6 @@ #include "aidge/operator/Producer.hpp" #include "aidge/recipes/Recipes.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" -#include "aidge/operator/MetaOperatorDefs.hpp" - namespace Aidge { @@ -41,9 +39,6 @@ public: virtual std::shared_ptr<ProdConso> getProdConso() const override { const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec())); - if (impl.prodConso(mOp)==nullptr){ - fmt::println("no prod conso created "); - } return impl.prodConso(mOp); } @@ -68,37 +63,16 @@ REGISTRAR(Conv2D_Op_Impl_dummy, {DataType::Float32, DataFormat::NHWC}}}, {ProdConso::inPlaceModel, nullptr, nullptr}); - using Conv2D_Op = Conv_Op<2>; REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create); -using ConvRelu = MetaOperator_Op; -using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>; -REGISTRAR(ConvRelu_Op_Impl_dummy, - {{ // Inputs - {DataType::Any, DataFormat::NHWC}, - {DataType::Any, DataFormat::NHWC}, - {DataType::Any, DataFormat::Default}}, - { // Outputs - {DataType::Float32, DataFormat::NHWC}}}, - {ProdConso::inPlaceModel, nullptr, nullptr}); -REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create); - - using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>; REGISTRAR(ReLU_Op_Impl_dummy, - {{DataType::Any, DataFormat::Any}}, + {{DataType::Any, DataFormat::Default}}, {ProdConso::inPlaceModel, nullptr, nullptr}); REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create); -using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>; -REGISTRAR(Transpose_Op_Impl_dummy, - {{DataType::Any, DataFormat::Any}}, - {ProdConso::inPlaceModel, nullptr, nullptr}); - -REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create); - REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32})); //////////////////////////////////////////////////////////////////////////////// @@ -112,7 +86,6 @@ TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") { ReLU("relu2"), Conv(8, 10, {1, 1}, "conv3") }); - REQUIRE(g1->forwardDims()); g1->setBackend("dummy"); auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator()); @@ -134,52 +107,11 @@ TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") { REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default); // TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims() - REQUIRE(g1->forwardDims()); - g1->save("adapttobackend_after_forwarddims", true); + //REQUIRE(g1->forwardDims()); + //g1->save("adapttobackend_after_forwarddims", true); - SequentialScheduler sched(g1); - sched.forward(); + //SequentialScheduler sched(g1); + //sched.forward(); } -TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") { - auto g1 = Sequential({ - Producer({1, 3, 22, 22}, "dataProvider"), - Conv(3, 4, {3, 3}, "conv1"), - ReLU("relu1") - }); - REQUIRE(g1->forwardDims()); - g1->setBackend("dummy"); - - fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU"); - g1->save("fuse_meta_op"); - for( auto n : g1->getNodes()){ - n->setName(n->createUniqueName("n")); - - if (n->type() == "ConvReLU"){ - auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator()); - } - } - adaptToBackend(g1); - for( auto n : g1->getNodes()){ - n->setName(n->createUniqueName("n")); - - if (n->type() == "ConvReLU"){ - auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator()); - fmt::println("Backends avalaile for ConvRelu is {}",convReluOp->getAvailableBackends()); - REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC); - REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC); - REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC); - } - } - g1->save("adapt_to_backend"); - // constantFolding(g1); - SequentialScheduler sched(g1); - sched.generateScheduling(); - sched.generateMemory(); - sched.forward(); - -} - - - } // namespace Aidge diff --git a/unit_tests/recipes/Test_ExplicitTranspose.cpp b/unit_tests/recipes/Test_ExplicitTranspose.cpp index 09711768ca64a6a40ede66cdcc983c5980012926..bb89ba7952347a779e6979e7cf3c4f1bd68abf9b 100644 --- a/unit_tests/recipes/Test_ExplicitTranspose.cpp +++ b/unit_tests/recipes/Test_ExplicitTranspose.cpp @@ -19,21 +19,20 @@ using namespace Aidge; -TEST_CASE("[ExplicitTranspose] ExplicitTranspose conv") { +TEST_CASE("[ExplicitTranspose] conv") { auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv2 = Conv(32, 64, {3, 3}, "conv2"); auto conv3 = Conv(64, 10, {1, 1}, "conv3", {2, 2}); - auto producer1 = Producer({16, 3, 224, 224}, "dataProvider"); - conv2->getOperator()->setDataFormat(DataFormat::NHWC); auto g1 = Sequential({ - producer1, + Producer({16, 3, 224, 224}, "dataProvider"), conv1, conv2, conv3 }); g1->setDataFormat(DataFormat::NCHW); + conv2->getOperator()->setDataFormat(DataFormat::NHWC); g1->save("explicitTranspose_before"); REQUIRE(g1->getNodes().size() == 10); @@ -42,13 +41,13 @@ TEST_CASE("[ExplicitTranspose] ExplicitTranspose conv") { g1->forwardDims(); explicitTranspose(g1); - // // Check that Transpose were inserted - // g1->save("explicitTranspose_after"); - // REQUIRE(g1->getNodes().size() == 12); + // Check that Transpose were inserted + g1->save("explicitTranspose_after"); + REQUIRE(g1->getNodes().size() == 12); - // // Check that Transpose are removed - // conv2->getOperator()->setDataFormat(DataFormat::NCHW); - // explicitTranspose(g1); + // Check that Transpose are removed + conv2->getOperator()->setDataFormat(DataFormat::NCHW); + explicitTranspose(g1); REQUIRE(g1->getNodes().size() == 10); REQUIRE(g1->getNodes() == initialNodes);