
Draft: Tests adapt to backend

Open · Wissam Boussella requested to merge wboussella/aidge_core:tests_adapt_to_backend into dev
1 unresolved thread · 1 file changed: +138 −138
@@ -9,100 +9,100 @@
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>

#include <memory>
#include <set>

#include <fmt/core.h>  // fmt::println
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
namespace Aidge {
////////////////////////////////////////////////////////////////////////////////
// Create a dummy implementation
template <class Op>
class OperatorImpl_dummy : public OperatorImpl,
    public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>>
{
public:
    OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {}

    static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) {
        return std::make_unique<OperatorImpl_dummy<Op>>(op);
    }

    virtual std::shared_ptr<ProdConso> getProdConso() const override {
        // Resolve the registered implementation best matching the required
        // spec and delegate ProdConso creation to it.
        const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec()));
        auto prodConso = impl.prodConso(mOp);
        if (prodConso == nullptr) {
            fmt::println("no ProdConso created");
        }
        return prodConso;
    }

    virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
        std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys();
        return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end());
    }

    void forward() override {
        fmt::println("forward: {}", mOp.type());
    }
};
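// Note: this class mirrors how a real backend plugs into Aidge's registry
// mechanism. A minimal usage sketch (assuming the registrations below are in
// place): after g->setBackend("dummy"), each operator resolves its
// implementation through Registrar<OperatorImpl_dummy<Op>>, and
// getAvailableImplSpecs() exposes the registered ImplSpecs that
// adaptToBackend() matches against the graph's actual tensor formats.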
// Register it
using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>;
REGISTRAR(Conv2D_Op_Impl_dummy,
    {{ // Inputs
        {DataType::Any, DataFormat::NHWC},
        {DataType::Any, DataFormat::NHWC},
        {DataType::Any, DataFormat::Default}},
     { // Outputs
        {DataType::Float32, DataFormat::NHWC}}},
    {ProdConso::inPlaceModel, nullptr, nullptr});
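// The spec above declares that the dummy Conv2D consumes its data and weight
// inputs in NHWC, its bias in the default format, and produces a Float32 NHWC
// output. The test graphs below are built in the default format, so
// adaptToBackend() has to bridge the mismatch, which is where the Transpose
// insertions checked further down come from.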
using Conv2D_Op = Conv_Op<2>;
REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create);

using ConvRelu = MetaOperator_Op;
using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>;
REGISTRAR(ConvRelu_Op_Impl_dummy,
    {{ // Inputs
        {DataType::Any, DataFormat::NHWC},
        {DataType::Any, DataFormat::NHWC},
        {DataType::Any, DataFormat::Default}},
     { // Outputs
        {DataType::Float32, DataFormat::NHWC}}},
    {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create);

using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>;
REGISTRAR(ReLU_Op_Impl_dummy,
    {{DataType::Any, DataFormat::Any}},
    {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create);

using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>;
REGISTRAR(Transpose_Op_Impl_dummy,
    {{DataType::Any, DataFormat::Any}},
    {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create);

REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));
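// The "dummy" backend has no Tensor storage of its own: the line above aliases
// {"dummy", Float32} to the already-registered {"cpu", Float32} Tensor
// implementation, so allocation and data access go through the cpu backend
// while scheduling exercises the dummy operator implementations.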
////////////////////////////////////////////////////////////////////////////////
void applyConstFold(std::shared_ptr<GraphView>& graphView)
{
    for (const std::shared_ptr<Node>& node : graphView->getNodes())
    {
@@ -115,45 +115,45 @@
    constantFolding(graphView);
}
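// Helper used by the constant-folding test below. The per-node preparation is
// partially elided in this diff; the final step runs the constantFolding
// recipe, which precomputes subgraphs whose inputs are all constant Producers
// and replaces them with Producers holding the results.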
TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"),
ReLU("relu1"),
Conv(4, 8, {3, 3}, "conv2"),
ReLU("relu2"),
Conv(8, 10, {1, 1}, "conv3")
});
REQUIRE(g1->forwardDims());
g1->setBackend("dummy");
auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default);
g1->save("adapttobackend_before", true);
adaptToBackend(g1);
g1->save("adapttobackend_after", true);
auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU");
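    // The query above, roughly: "<-" follows input 0 upstream, "<1-"/"<2-"
    // follow inputs 1 and 2, "->" follows output 0 downstream, and "#" anchors
    // a node for retrieval from the match. It thus checks that data and
    // weights now reach Conv2D through Transpose nodes, the bias comes straight
    // from a Producer, and the output is transposed back before the ReLU.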
    REQUIRE(matches.size() == 1);
    convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator());
    auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator());
    REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC);
    REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC);
    REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
    REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default);

    // TODO: uncomment when NHWC support is implemented in Conv_Op::forwardDims()
    REQUIRE(g1->forwardDims());
    g1->save("adapttobackend_after_forwarddims", true);

    SequentialScheduler sched(g1);
    sched.forward();
}
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") {
TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"),
ReLU("relu1"),
Conv(4, 8, {3, 3}, "conv2"),
ReLU("relu2"),
Conv(8, 10, {1, 1}, "conv3")
});
REQUIRE(g1->forwardDims());
g1->setBackend("dummy");
auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default);
g1->save("adapttobackend_before", true);
adaptToBackend(g1);
g1->save("adapttobackend_after", true);
auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU");
REQUIRE(matches.size() == 1);
convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator());
auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default);
// TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims()
REQUIRE(g1->forwardDims());
g1->save("adapttobackend_after_forwarddims", true);
SequentialScheduler sched(g1);
sched.forward();
}
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") {
    auto g1 = Sequential({
        Producer({1, 3, 22, 22}, "dataProvider"),
        Conv(3, 4, {3, 3}, "conv1"),
@@ -161,7 +161,7 @@
    });
    g1->forwardDims();
    g1->setBackend("dummy");
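    // fuseToMetaOps() wraps every "Conv2D->ReLU" match into a single
    // MetaOperator node of type "ConvReLU", which the dummy backend serves via
    // the ConvRelu_Op_Impl_dummy registration above.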
    fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
    g1->save("fuse_meta_op");
    for (auto n : g1->getNodes()) {
@@ -184,10 +184,10 @@
        REQUIRE_NOTHROW(sched.generateMemory());
        REQUIRE_NOTHROW(sched.forward());
        FAIL_CHECK("This test is expected to fail due to known issues.");
    }
}
// Interesting test because this pattern is heavily used for exports
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") {
    auto g1 = Sequential({
        Producer({1, 3, 22, 22}, "dataProvider"),
        Conv(3, 4, {3, 3}, "conv1"),
@@ -195,7 +195,7 @@
    });
    g1->forwardDims();
    g1->setBackend("dummy");
    fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
    g1->save("fuse_meta_op");
    for (auto n : g1->getNodes()) {
@@ -242,7 +242,7 @@
    }
    REQUIRE(cpt == 2);
    FAIL_CHECK("This test is expected to fail due to known issues.");
}
} // namespace Aidge
\ No newline at end of file