Skip to content
Snippets Groups Projects

Draft: Tests adapt to backend

Open Wissam Boussella requested to merge wboussella/aidge_core:tests_adapt_to_backend into dev
1 unresolved thread
@@ -21,6 +21,8 @@
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
namespace Aidge {
@@ -39,6 +41,9 @@ public:
virtual std::shared_ptr<ProdConso> getProdConso() const override {
const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec()));
if (impl.prodConso(mOp)==nullptr){
fmt::println("no prod conso created ");
}
return impl.prodConso(mOp);
}
@@ -63,19 +68,52 @@ REGISTRAR(Conv2D_Op_Impl_dummy,
{DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr});
using Conv2D_Op = Conv_Op<2>;
REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create);
using ConvRelu = MetaOperator_Op;
using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>;
REGISTRAR(ConvRelu_Op_Impl_dummy,
{{ // Inputs
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::Default}},
{ // Outputs
{DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create);
using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>;
REGISTRAR(ReLU_Op_Impl_dummy,
{{DataType::Any, DataFormat::Default}},
{{DataType::Any, DataFormat::Any}},
{ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create);
REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));
////////////////////////////////////////////////////////////////////////////////
using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>;
REGISTRAR(Transpose_Op_Impl_dummy,
{{DataType::Any, DataFormat::Any}},
{ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create);
REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));
////////////////////////////////////////////////////////////////////////////////
void applyConstFold(std::shared_ptr<GraphView> &graphView)
{
for (const std::shared_ptr<Node> node : graphView->getNodes())
{
if (node->type() == "Producer" && node->name() != "dataProvider")
{
const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
producer->constant() = true;
}
}
constantFolding(graphView);
}
TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({
@@ -86,6 +124,7 @@ TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
ReLU("relu2"),
Conv(8, 10, {1, 1}, "conv3")
});
REQUIRE(g1->forwardDims());
g1->setBackend("dummy");
auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
@@ -107,11 +146,103 @@ TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default);
// TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims()
//REQUIRE(g1->forwardDims());
//g1->save("adapttobackend_after_forwarddims", true);
REQUIRE(g1->forwardDims());
g1->save("adapttobackend_after_forwarddims", true);
SequentialScheduler sched(g1);
sched.forward();
}
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"),
ReLU("relu1")
});
g1->forwardDims();
g1->setBackend("dummy");
fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
g1->save("fuse_meta_op");
for( auto n : g1->getNodes()){
n->setName(n->createUniqueName("n"));
}
adaptToBackend(g1);
for( auto n : g1->getNodes()){
n->setName(n->createUniqueName("n"));
if (n->type() == "ConvReLU"){
auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
fmt::println("Backends avalaile for ConvRelu is {}",convReluOp->getAvailableBackends());
REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC);
REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
}
}
g1->save("adapt_to_backend");
SequentialScheduler sched(g1);
REQUIRE_NOTHROW(sched.generateScheduling());
REQUIRE_NOTHROW(sched.generateMemory());
REQUIRE_NOTHROW(sched.forward());
FAIL_CHECK("This test is expected to fail due to known issues.");
}
// Interesting test because used a lot for export
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"),
ReLU("relu1")
});
g1->forwardDims();
g1->setBackend("dummy");
fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
g1->save("fuse_meta_op");
for( auto n : g1->getNodes()){
n->setName(n->createUniqueName("n"));
//SequentialScheduler sched(g1);
//sched.forward();
if (n->type() == "ConvReLU"){
auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
}
}
adaptToBackend(g1);
for( auto n : g1->getNodes()){
n->setName(n->createUniqueName("n"));
if (n->type() == "ConvReLU"){
auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
fmt::println("Backends avalaile for ConvRelu is {}",convReluOp->getAvailableBackends());
REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC);
REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
}
}
g1->forwardDims({{1, 3, 3, 3}});
for (const std::shared_ptr<Node> node : g1->getNodes())
{
if (node->type() == "Producer" && node->name() != "dataProvider")
{
const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
producer->constant() = true;
}
}
applyConstFold(g1);
g1->save("constant_folding_2");
SequentialScheduler sched(g1);
REQUIRE_NOTHROW(sched.generateScheduling());
REQUIRE_NOTHROW(sched.generateMemory());
REQUIRE_NOTHROW(sched.forward());
unsigned cpt = 0;
for( auto n : g1->getNodes()){
if (n->type() == "Transpose"){
cpt++;
}
}
REQUIRE(cpt == 2);
FAIL_CHECK("This test is expected to fail due to known issues.");
}
} // namespace Aidge
\ No newline at end of file
Loading