Skip to content
Snippets Groups Projects
Commit 4aab5768 authored by Wissam Boussella's avatar Wissam Boussella
Browse files

Update Test_AdaptToBackend.cpp

parent 0a62325b
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !336. Comments created here will be created in the context of that merge request.
...@@ -9,100 +9,100 @@ ...@@ -9,100 +9,100 @@
* *
********************************************************************************/ ********************************************************************************/
#include <catch2/catch_test_macros.hpp> #include <catch2/catch_test_macros.hpp>
#include <set> #include <set>
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Transpose.hpp" #include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/operator/MetaOperatorDefs.hpp"
namespace Aidge { namespace Aidge {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Create a dummy implementation // Create a dummy implementation
template <class Op> template <class Op>
class OperatorImpl_dummy : public OperatorImpl, class OperatorImpl_dummy : public OperatorImpl,
public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>> public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>>
{ {
public: public:
OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {} OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {}
static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) { static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) {
return std::make_unique<OperatorImpl_dummy<Op>>(op); return std::make_unique<OperatorImpl_dummy<Op>>(op);
} }
virtual std::shared_ptr<ProdConso> getProdConso() const override { virtual std::shared_ptr<ProdConso> getProdConso() const override {
const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec())); const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec()));
if (impl.prodConso(mOp)==nullptr){ if (impl.prodConso(mOp)==nullptr){
fmt::println("no prod conso created "); fmt::println("no prod conso created ");
} }
return impl.prodConso(mOp); return impl.prodConso(mOp);
} }
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys(); std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys();
return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end()); return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end());
} }
void forward() override { void forward() override {
fmt::println("forward: {}", mOp.type()); fmt::println("forward: {}", mOp.type());
} }
}; };
// Register it // Register it
using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>; using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>;
REGISTRAR(Conv2D_Op_Impl_dummy, REGISTRAR(Conv2D_Op_Impl_dummy,
{{ // Inputs {{ // Inputs
{DataType::Any, DataFormat::NHWC}, {DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::NHWC}, {DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::Default}}, {DataType::Any, DataFormat::Default}},
{ // Outputs { // Outputs
{DataType::Float32, DataFormat::NHWC}}}, {DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr}); {ProdConso::inPlaceModel, nullptr, nullptr});
using Conv2D_Op = Conv_Op<2>; using Conv2D_Op = Conv_Op<2>;
REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create); REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create);
using ConvRelu = MetaOperator_Op; using ConvRelu = MetaOperator_Op;
using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>; using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>;
REGISTRAR(ConvRelu_Op_Impl_dummy, REGISTRAR(ConvRelu_Op_Impl_dummy,
{{ // Inputs {{ // Inputs
{DataType::Any, DataFormat::NHWC}, {DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::NHWC}, {DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::Default}}, {DataType::Any, DataFormat::Default}},
{ // Outputs { // Outputs
{DataType::Float32, DataFormat::NHWC}}}, {DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr}); {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create); REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create);
using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>; using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>;
REGISTRAR(ReLU_Op_Impl_dummy, REGISTRAR(ReLU_Op_Impl_dummy,
{{DataType::Any, DataFormat::Any}}, {{DataType::Any, DataFormat::Any}},
{ProdConso::inPlaceModel, nullptr, nullptr}); {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create); REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create);
using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>; using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>;
REGISTRAR(Transpose_Op_Impl_dummy, REGISTRAR(Transpose_Op_Impl_dummy,
{{DataType::Any, DataFormat::Any}}, {{DataType::Any, DataFormat::Any}},
{ProdConso::inPlaceModel, nullptr, nullptr}); {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create); REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create);
REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32})); REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void applyConstFold(std::shared_ptr<GraphView> &graphView) void applyConstFold(std::shared_ptr<GraphView> &graphView)
{ {
for (const std::shared_ptr<Node> node : graphView->getNodes()) for (const std::shared_ptr<Node> node : graphView->getNodes())
{ {
...@@ -115,45 +115,45 @@ ...@@ -115,45 +115,45 @@
constantFolding(graphView); constantFolding(graphView);
} }
TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
    // Build a small Conv/ReLU chain; everything starts in the default data format.
    auto g1 = Sequential({
        Producer({1, 3, 22, 22}, "dataProvider"),
        Conv(3, 4, {3, 3}, "conv1"),
        ReLU("relu1"),
        Conv(4, 8, {3, 3}, "conv2"),
        ReLU("relu2"),
        Conv(8, 10, {1, 1}, "conv3")
    });
    REQUIRE(g1->forwardDims());
    g1->setBackend("dummy");

    // Before adaptation: conv1 tensors are still in the default format.
    auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
    REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default);
    REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default);
    REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default);

    g1->save("adapttobackend_before", true);
    adaptToBackend(g1);
    g1->save("adapttobackend_after", true);

    // After adaptation, the backend's NHWC requirement must have inserted
    // Transpose nodes around the Conv operators.
    auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU");
    REQUIRE(matches.size() == 1);
    convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator());
    auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator());
    REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC);
    REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC);
    REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
    REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default);

    // TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims()
    REQUIRE(g1->forwardDims());
    g1->save("adapttobackend_after_forwarddims", true);

    SequentialScheduler sched(g1);
    sched.forward();
}
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") { TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({ auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"), Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"), Conv(3, 4, {3, 3}, "conv1"),
...@@ -161,7 +161,7 @@ ...@@ -161,7 +161,7 @@
}); });
g1->forwardDims(); g1->forwardDims();
g1->setBackend("dummy"); g1->setBackend("dummy");
fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU"); fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
g1->save("fuse_meta_op"); g1->save("fuse_meta_op");
for( auto n : g1->getNodes()){ for( auto n : g1->getNodes()){
...@@ -184,10 +184,10 @@ ...@@ -184,10 +184,10 @@
REQUIRE_NOTHROW(sched.generateMemory()); REQUIRE_NOTHROW(sched.generateMemory());
REQUIRE_NOTHROW(sched.forward()); REQUIRE_NOTHROW(sched.forward());
FAIL_CHECK("This test is expected to fail due to known issues."); FAIL_CHECK("This test is expected to fail due to known issues.");
} }
// Interesting test because used a lot for export // Interesting test because used a lot for export
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") { TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({ auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"), Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"), Conv(3, 4, {3, 3}, "conv1"),
...@@ -195,7 +195,7 @@ ...@@ -195,7 +195,7 @@
}); });
g1->forwardDims(); g1->forwardDims();
g1->setBackend("dummy"); g1->setBackend("dummy");
fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU"); fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
g1->save("fuse_meta_op"); g1->save("fuse_meta_op");
for( auto n : g1->getNodes()){ for( auto n : g1->getNodes()){
...@@ -242,7 +242,7 @@ ...@@ -242,7 +242,7 @@
} }
REQUIRE(cpt == 2); REQUIRE(cpt == 2);
FAIL_CHECK("This test is expected to fail due to known issues."); FAIL_CHECK("This test is expected to fail due to known issues.");
} }
} // namespace Aidge } // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment