diff --git a/unit_tests/recipes/Test_AdaptToBackend.cpp b/unit_tests/recipes/Test_AdaptToBackend.cpp index 1238d1dc448f1d6cbf10245b03351c477d063141..2face8b91661f3a860fcf12ef74c1fe3b810398b 100644 --- a/unit_tests/recipes/Test_AdaptToBackend.cpp +++ b/unit_tests/recipes/Test_AdaptToBackend.cpp @@ -9,109 +9,240 @@ * ********************************************************************************/ -#include <catch2/catch_test_macros.hpp> -#include <set> - -#include "aidge/data/Tensor.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/graph/OpArgs.hpp" -#include "aidge/operator/Conv.hpp" -#include "aidge/operator/ReLU.hpp" -#include "aidge/operator/Transpose.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/recipes/Recipes.hpp" -#include "aidge/scheduler/SequentialScheduler.hpp" - -namespace Aidge { - -//////////////////////////////////////////////////////////////////////////////// -// Create a dummy implementation -template <class Op> -class OperatorImpl_dummy : public OperatorImpl, - public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>> + #include <catch2/catch_test_macros.hpp> + #include <set> + + #include "aidge/data/Tensor.hpp" + #include "aidge/graph/GraphView.hpp" + #include "aidge/graph/OpArgs.hpp" + #include "aidge/operator/Conv.hpp" + #include "aidge/operator/ReLU.hpp" + #include "aidge/operator/Transpose.hpp" + #include "aidge/operator/Producer.hpp" + #include "aidge/recipes/Recipes.hpp" + #include "aidge/scheduler/SequentialScheduler.hpp" + #include "aidge/operator/MetaOperatorDefs.hpp" + + + namespace Aidge { + + //////////////////////////////////////////////////////////////////////////////// + // Create a dummy implementation + template <class Op> + class OperatorImpl_dummy : public OperatorImpl, + public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>> + { + public: + OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {} + + static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) { + return std::make_unique<OperatorImpl_dummy<Op>>(op); + } + + virtual std::shared_ptr<ProdConso> getProdConso() const override { + const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec())); + if (impl.prodConso(mOp)==nullptr){ + fmt::println("no prod conso created "); + } + return impl.prodConso(mOp); + } + + virtual std::vector<ImplSpec> getAvailableImplSpecs() const override { + std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys(); + return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end()); + } + + void forward() override { + fmt::println("forward: {}", mOp.type()); + } + }; + + // Register it + using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>; + REGISTRAR(Conv2D_Op_Impl_dummy, + {{ // Inputs + {DataType::Any, DataFormat::NHWC}, + {DataType::Any, DataFormat::NHWC}, + {DataType::Any, DataFormat::Default}}, + { // Outputs + {DataType::Float32, DataFormat::NHWC}}}, + {ProdConso::inPlaceModel, nullptr, nullptr}); + + + using Conv2D_Op = Conv_Op<2>; + REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create); + + using ConvRelu = MetaOperator_Op; + using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>; + REGISTRAR(ConvRelu_Op_Impl_dummy, + {{ // Inputs + {DataType::Any, DataFormat::NHWC}, + {DataType::Any, DataFormat::NHWC}, + {DataType::Any, DataFormat::Default}}, + { // Outputs + {DataType::Float32, DataFormat::NHWC}}}, + {ProdConso::inPlaceModel, nullptr, nullptr}); + REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create); + + + using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>; + REGISTRAR(ReLU_Op_Impl_dummy, + {{DataType::Any, DataFormat::Any}}, + {ProdConso::inPlaceModel, nullptr, nullptr}); + + REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create); + + using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>; + REGISTRAR(Transpose_Op_Impl_dummy, + {{DataType::Any, DataFormat::Any}}, + {ProdConso::inPlaceModel, nullptr, nullptr}); + + REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create); + + REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32})); + //////////////////////////////////////////////////////////////////////////////// + + void applyConstFold(std::shared_ptr<GraphView> &graphView) { -public: - OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {} - - static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) { - return std::make_unique<OperatorImpl_dummy<Op>>(op); - } - - virtual std::shared_ptr<ProdConso> getProdConso() const override { - const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec())); - return impl.prodConso(mOp); + for (const std::shared_ptr<Node> node : graphView->getNodes()) + { + if (node->type() == "Producer" && node->name() != "dataProvider") + { + const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator()); + producer->constant() = true; + } } + constantFolding(graphView); +} - virtual std::vector<ImplSpec> getAvailableImplSpecs() const override { - std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys(); - return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end()); + TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") { + auto g1 = Sequential({ + Producer({1, 3, 22, 22}, "dataProvider"), + Conv(3, 4, {3, 3}, "conv1"), + ReLU("relu1"), + Conv(4, 8, {3, 3}, "conv2"), + ReLU("relu2"), + Conv(8, 10, {1, 1}, "conv3") + }); + REQUIRE(g1->forwardDims()); + + g1->setBackend("dummy"); + auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator()); + REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default); + REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default); + REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default); + + g1->save("adapttobackend_before", true); + adaptToBackend(g1); + g1->save("adapttobackend_after", true); + + auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU"); + REQUIRE(matches.size() == 1); + convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator()); + auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator()); + REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC); + REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC); + REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC); + REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default); + + // TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims() + REQUIRE(g1->forwardDims()); + g1->save("adapttobackend_after_forwarddims", true); + + SequentialScheduler sched(g1); + sched.forward(); + } + + TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") { + auto g1 = Sequential({ + Producer({1, 3, 22, 22}, "dataProvider"), + Conv(3, 4, {3, 3}, "conv1"), + ReLU("relu1") + }); + g1->forwardDims(); + g1->setBackend("dummy"); + + fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU"); + g1->save("fuse_meta_op"); + for( auto n : g1->getNodes()){ + n->setName(n->createUniqueName("n")); } - - void forward() override { - fmt::println("forward: {}", mOp.type()); + adaptToBackend(g1); + for( auto n : g1->getNodes()){ + n->setName(n->createUniqueName("n")); + if (n->type() == "ConvReLU"){ + auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator()); + fmt::println("Backends avalaile for ConvRelu is {}",convReluOp->getAvailableBackends()); + REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC); + REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC); + REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC); + } } -}; - -// Register it -using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>; -REGISTRAR(Conv2D_Op_Impl_dummy, - {{ // Inputs - {DataType::Any, DataFormat::NHWC}, - {DataType::Any, DataFormat::NHWC}, - {DataType::Any, DataFormat::Default}}, - { // Outputs - {DataType::Float32, DataFormat::NHWC}}}, - {ProdConso::inPlaceModel, nullptr, nullptr}); - -using Conv2D_Op = Conv_Op<2>; -REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create); - -using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>; -REGISTRAR(ReLU_Op_Impl_dummy, - {{DataType::Any, DataFormat::Default}}, - {ProdConso::inPlaceModel, nullptr, nullptr}); - -REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create); - -REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32})); -//////////////////////////////////////////////////////////////////////////////// - - -TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") { + g1->save("adapt_to_backend"); + SequentialScheduler sched(g1); + REQUIRE_NOTHROW(sched.generateScheduling()); + REQUIRE_NOTHROW(sched.generateMemory()); + REQUIRE_NOTHROW(sched.forward()); + } + +// Interesting test because used a lot for export + TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") { auto g1 = Sequential({ Producer({1, 3, 22, 22}, "dataProvider"), Conv(3, 4, {3, 3}, "conv1"), - ReLU("relu1"), - Conv(4, 8, {3, 3}, "conv2"), - ReLU("relu2"), - Conv(8, 10, {1, 1}, "conv3") + ReLU("relu1") }); - + g1->forwardDims(); g1->setBackend("dummy"); - auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator()); - REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default); - REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default); - REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default); - - g1->save("adapttobackend_before", true); + + fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU"); + g1->save("fuse_meta_op"); + for( auto n : g1->getNodes()){ + n->setName(n->createUniqueName("n")); + + if (n->type() == "ConvReLU"){ + auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator()); + } + } adaptToBackend(g1); - g1->save("adapttobackend_after", true); - - auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU"); - REQUIRE(matches.size() == 1); - convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator()); - auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator()); - REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC); - REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC); - REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC); - REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default); - - // TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims() - //REQUIRE(g1->forwardDims()); - //g1->save("adapttobackend_after_forwarddims", true); - - //SequentialScheduler sched(g1); - //sched.forward(); -} - -} // namespace Aidge + for( auto n : g1->getNodes()){ + n->setName(n->createUniqueName("n")); + if (n->type() == "ConvReLU"){ + auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator()); + fmt::println("Backends avalaile for ConvRelu is {}",convReluOp->getAvailableBackends()); + REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC); + REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC); + REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC); + } + } + g1->forwardDims({{1, 3, 3, 3}}); + + for (const std::shared_ptr<Node> node : g1->getNodes()) + { + if (node->type() == "Producer" && node->name() != "dataProvider") + { + const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator()); + producer->constant() = true; + } + } + applyConstFold(g1); + g1->save("constant_folding_2"); + + SequentialScheduler sched(g1); + REQUIRE_NOTHROW(sched.generateScheduling()); + REQUIRE_NOTHROW(sched.generateMemory()); + REQUIRE_NOTHROW(sched.forward()); + + unsigned cpt = 0; + for( auto n : g1->getNodes()){ + if (n->type() == "Transpose"){ + cpt++; + } + } + REQUIRE(cpt == 2); + } + + + + } // namespace Aidge + \ No newline at end of file