Skip to content
Snippets Groups Projects
Commit c3a2fb53 authored by Wissam Boussella's avatar Wissam Boussella Committed by Maxence Naud
Browse files

Update Test_AdaptToBackend.cpp

parent 67d57c1f
No related branches found
No related tags found
1 merge request!336Draft: Tests adapt to backend
Pipeline #65945 failed
This commit is part of merge request !336. Comments created here will be created in the context of that merge request.
...@@ -9,100 +9,100 @@ ...@@ -9,100 +9,100 @@
* *
********************************************************************************/ ********************************************************************************/
#include <catch2/catch_test_macros.hpp> #include <catch2/catch_test_macros.hpp>
#include <set> #include <set>
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Transpose.hpp" #include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/operator/MetaOperatorDefs.hpp"
namespace Aidge { namespace Aidge {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Create a dummy implementation // Create a dummy implementation
template <class Op> template <class Op>
class OperatorImpl_dummy : public OperatorImpl, class OperatorImpl_dummy : public OperatorImpl,
public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>> public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>>
{ {
public: public:
OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {} OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {}
static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) { static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) {
return std::make_unique<OperatorImpl_dummy<Op>>(op); return std::make_unique<OperatorImpl_dummy<Op>>(op);
} }
virtual std::shared_ptr<ProdConso> getProdConso() const override { virtual std::shared_ptr<ProdConso> getProdConso() const override {
const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec())); const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec()));
if (impl.prodConso(mOp)==nullptr){ if (impl.prodConso(mOp)==nullptr){
fmt::println("no prod conso created "); fmt::println("no prod conso created ");
} }
return impl.prodConso(mOp); return impl.prodConso(mOp);
} }
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys(); std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys();
return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end()); return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end());
} }
void forward() override { void forward() override {
fmt::println("forward: {}", mOp.type()); fmt::println("forward: {}", mOp.type());
} }
}; };
// Register it // Register it
using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>; using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>;
REGISTRAR(Conv2D_Op_Impl_dummy, REGISTRAR(Conv2D_Op_Impl_dummy,
{{ // Inputs {{ // Inputs
{DataType::Any, DataFormat::NHWC}, {DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::NHWC}, {DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::Default}}, {DataType::Any, DataFormat::Default}},
{ // Outputs { // Outputs
{DataType::Float32, DataFormat::NHWC}}}, {DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr}); {ProdConso::inPlaceModel, nullptr, nullptr});
using Conv2D_Op = Conv_Op<2>; using Conv2D_Op = Conv_Op<2>;
REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create); REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create);
using ConvRelu = MetaOperator_Op; using ConvRelu = MetaOperator_Op;
using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>; using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>;
REGISTRAR(ConvRelu_Op_Impl_dummy, REGISTRAR(ConvRelu_Op_Impl_dummy,
{{ // Inputs {{ // Inputs
{DataType::Any, DataFormat::NHWC}, {DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::NHWC}, {DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::Default}}, {DataType::Any, DataFormat::Default}},
{ // Outputs { // Outputs
{DataType::Float32, DataFormat::NHWC}}}, {DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr}); {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create); REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create);
using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>; using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>;
REGISTRAR(ReLU_Op_Impl_dummy, REGISTRAR(ReLU_Op_Impl_dummy,
{{DataType::Any, DataFormat::Any}}, {{DataType::Any, DataFormat::Any}},
{ProdConso::inPlaceModel, nullptr, nullptr}); {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create); REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create);
using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>; using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>;
REGISTRAR(Transpose_Op_Impl_dummy, REGISTRAR(Transpose_Op_Impl_dummy,
{{DataType::Any, DataFormat::Any}}, {{DataType::Any, DataFormat::Any}},
{ProdConso::inPlaceModel, nullptr, nullptr}); {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create); REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create);
REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32})); REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void applyConstFold(std::shared_ptr<GraphView> &graphView) void applyConstFold(std::shared_ptr<GraphView> &graphView)
{ {
for (const std::shared_ptr<Node> node : graphView->getNodes()) for (const std::shared_ptr<Node> node : graphView->getNodes())
{ {
...@@ -115,45 +115,45 @@ ...@@ -115,45 +115,45 @@
constantFolding(graphView); constantFolding(graphView);
} }
TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") { TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({ auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"), Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"), Conv(3, 4, {3, 3}, "conv1"),
ReLU("relu1"), ReLU("relu1"),
Conv(4, 8, {3, 3}, "conv2"), Conv(4, 8, {3, 3}, "conv2"),
ReLU("relu2"), ReLU("relu2"),
Conv(8, 10, {1, 1}, "conv3") Conv(8, 10, {1, 1}, "conv3")
}); });
REQUIRE(g1->forwardDims()); REQUIRE(g1->forwardDims());
g1->setBackend("dummy"); g1->setBackend("dummy");
auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator()); auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default); REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default); REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default); REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default);
g1->save("adapttobackend_before", true); g1->save("adapttobackend_before", true);
adaptToBackend(g1); adaptToBackend(g1);
g1->save("adapttobackend_after", true); g1->save("adapttobackend_after", true);
auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU"); auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU");
REQUIRE(matches.size() == 1); REQUIRE(matches.size() == 1);
convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator()); convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator());
auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator()); auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC); REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC); REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC); REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default); REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default);
// TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims() // TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims()
REQUIRE(g1->forwardDims()); REQUIRE(g1->forwardDims());
g1->save("adapttobackend_after_forwarddims", true); g1->save("adapttobackend_after_forwarddims", true);
SequentialScheduler sched(g1); SequentialScheduler sched(g1);
sched.forward(); sched.forward();
} }
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") { TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({ auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"), Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"), Conv(3, 4, {3, 3}, "conv1"),
...@@ -161,7 +161,7 @@ ...@@ -161,7 +161,7 @@
}); });
g1->forwardDims(); g1->forwardDims();
g1->setBackend("dummy"); g1->setBackend("dummy");
fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU"); fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
g1->save("fuse_meta_op"); g1->save("fuse_meta_op");
for( auto n : g1->getNodes()){ for( auto n : g1->getNodes()){
...@@ -184,10 +184,10 @@ ...@@ -184,10 +184,10 @@
REQUIRE_NOTHROW(sched.generateMemory()); REQUIRE_NOTHROW(sched.generateMemory());
REQUIRE_NOTHROW(sched.forward()); REQUIRE_NOTHROW(sched.forward());
FAIL_CHECK("This test is expected to fail due to known issues."); FAIL_CHECK("This test is expected to fail due to known issues.");
} }
// Interesting test because used a lot for export // Interesting test because used a lot for export
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") { TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({ auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"), Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"), Conv(3, 4, {3, 3}, "conv1"),
...@@ -195,7 +195,7 @@ ...@@ -195,7 +195,7 @@
}); });
g1->forwardDims(); g1->forwardDims();
g1->setBackend("dummy"); g1->setBackend("dummy");
fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU"); fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
g1->save("fuse_meta_op"); g1->save("fuse_meta_op");
for( auto n : g1->getNodes()){ for( auto n : g1->getNodes()){
...@@ -242,7 +242,7 @@ ...@@ -242,7 +242,7 @@
} }
REQUIRE(cpt == 2); REQUIRE(cpt == 2);
FAIL_CHECK("This test is expected to fail due to known issues."); FAIL_CHECK("This test is expected to fail due to known issues.");
} }
} // namespace Aidge } // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment