diff --git a/unit_tests/recipes/Test_AdaptToBackend.cpp b/unit_tests/recipes/Test_AdaptToBackend.cpp index 1238d1dc448f1d6cbf10245b03351c477d063141..34b6aa671bc92ab1a9c9376ab6315a380c66bb2a 100644 --- a/unit_tests/recipes/Test_AdaptToBackend.cpp +++ b/unit_tests/recipes/Test_AdaptToBackend.cpp @@ -21,6 +21,8 @@ #include "aidge/operator/Producer.hpp" #include "aidge/recipes/Recipes.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" + namespace Aidge { @@ -39,6 +41,9 @@ public: virtual std::shared_ptr<ProdConso> getProdConso() const override { const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec())); + if (impl.prodConso(mOp)==nullptr){ + fmt::println("no prod conso created "); + } return impl.prodConso(mOp); } @@ -63,19 +68,52 @@ REGISTRAR(Conv2D_Op_Impl_dummy, {DataType::Float32, DataFormat::NHWC}}}, {ProdConso::inPlaceModel, nullptr, nullptr}); + using Conv2D_Op = Conv_Op<2>; REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create); +using ConvRelu = MetaOperator_Op; +using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>; +REGISTRAR(ConvRelu_Op_Impl_dummy, + {{ // Inputs + {DataType::Any, DataFormat::NHWC}, + {DataType::Any, DataFormat::NHWC}, + {DataType::Any, DataFormat::Default}}, + { // Outputs + {DataType::Float32, DataFormat::NHWC}}}, + {ProdConso::inPlaceModel, nullptr, nullptr}); +REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create); + + using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>; REGISTRAR(ReLU_Op_Impl_dummy, - {{DataType::Any, DataFormat::Default}}, + {{DataType::Any, DataFormat::Any}}, {ProdConso::inPlaceModel, nullptr, nullptr}); REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create); -REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32})); -//////////////////////////////////////////////////////////////////////////////// +using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>; +REGISTRAR(Transpose_Op_Impl_dummy, + {{DataType::Any, DataFormat::Any}}, + {ProdConso::inPlaceModel, nullptr, nullptr}); +REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create); + +REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32})); + //////////////////////////////////////////////////////////////////////////////// + +void applyConstFold(std::shared_ptr<GraphView> &graphView) +{ + for (const std::shared_ptr<Node> node : graphView->getNodes()) + { + if (node->type() == "Producer" && node->name() != "dataProvider") + { + const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator()); + producer->constant() = true; + } + } + constantFolding(graphView); +} TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") { auto g1 = Sequential({ @@ -86,6 +124,7 @@ TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") { ReLU("relu2"), Conv(8, 10, {1, 1}, "conv3") }); + REQUIRE(g1->forwardDims()); g1->setBackend("dummy"); auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator()); @@ -107,11 +146,103 @@ TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") { REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default); // TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims() - //REQUIRE(g1->forwardDims()); - //g1->save("adapttobackend_after_forwarddims", true); + REQUIRE(g1->forwardDims()); + g1->save("adapttobackend_after_forwarddims", true); + + SequentialScheduler sched(g1); + sched.forward(); +} + +TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") { + auto g1 = Sequential({ + Producer({1, 3, 22, 22}, "dataProvider"), + Conv(3, 4, {3, 3}, "conv1"), + ReLU("relu1") + }); + g1->forwardDims(); + g1->setBackend("dummy"); + + fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU"); + g1->save("fuse_meta_op"); + for( auto n : g1->getNodes()){ + n->setName(n->createUniqueName("n")); + } + adaptToBackend(g1); + for( auto n : g1->getNodes()){ + n->setName(n->createUniqueName("n")); + if (n->type() == "ConvReLU"){ + auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator()); + fmt::println("Backends avalaile for ConvRelu is {}",convReluOp->getAvailableBackends()); + REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC); + REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC); + REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC); + } + } + g1->save("adapt_to_backend"); + SequentialScheduler sched(g1); + REQUIRE_NOTHROW(sched.generateScheduling()); + REQUIRE_NOTHROW(sched.generateMemory()); + REQUIRE_NOTHROW(sched.forward()); + FAIL_CHECK("This test is expected to fail due to known issues."); +} + +// Interesting test because used a lot for export +TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") { + auto g1 = Sequential({ + Producer({1, 3, 22, 22}, "dataProvider"), + Conv(3, 4, {3, 3}, "conv1"), + ReLU("relu1") + }); + g1->forwardDims(); + g1->setBackend("dummy"); + + fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU"); + g1->save("fuse_meta_op"); + for( auto n : g1->getNodes()){ + n->setName(n->createUniqueName("n")); - //SequentialScheduler sched(g1); - //sched.forward(); + if (n->type() == "ConvReLU"){ + auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator()); + } + } + adaptToBackend(g1); + for( auto n : g1->getNodes()){ + n->setName(n->createUniqueName("n")); + if (n->type() == "ConvReLU"){ + auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator()); + fmt::println("Backends avalaile for ConvRelu is {}",convReluOp->getAvailableBackends()); + REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC); + REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC); + REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC); + } + } + g1->forwardDims({{1, 3, 3, 3}}); + + for (const std::shared_ptr<Node> node : g1->getNodes()) + { + if (node->type() == "Producer" && node->name() != "dataProvider") + { + const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator()); + producer->constant() = true; + } + } + applyConstFold(g1); + g1->save("constant_folding_2"); + + SequentialScheduler sched(g1); + REQUIRE_NOTHROW(sched.generateScheduling()); + REQUIRE_NOTHROW(sched.generateMemory()); + REQUIRE_NOTHROW(sched.forward()); + + unsigned cpt = 0; + for( auto n : g1->getNodes()){ + if (n->type() == "Transpose"){ + cpt++; + } + } + REQUIRE(cpt == 2); + FAIL_CHECK("This test is expected to fail due to known issues."); } } // namespace Aidge + \ No newline at end of file