Commit 0bb298f6 authored by Wissam Boussella, committed by Maxence Naud

new tests adapt to backend, with metaOP and constant folding (not working yet)

parent 03de25a0
1 merge request: !336 Tests adapt to backend
@@ -9,109 +9,240 @@
 *
 ********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <set>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"

namespace Aidge {

////////////////////////////////////////////////////////////////////////////////
// Create a dummy implementation
template <class Op>
class OperatorImpl_dummy : public OperatorImpl,
    public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>>
{
public:
    OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {}

    static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) {
        return std::make_unique<OperatorImpl_dummy<Op>>(op);
    }

    virtual std::shared_ptr<ProdConso> getProdConso() const override {
        const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec()));
        if (impl.prodConso(mOp) == nullptr) {
            fmt::println("no prod conso created");
        }
        return impl.prodConso(mOp);
    }

    virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
        std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys();
        return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end());
    }

    void forward() override {
        fmt::println("forward: {}", mOp.type());
    }
};

// Register the dummy implementations
using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>;
REGISTRAR(Conv2D_Op_Impl_dummy,
    {{ // Inputs
        {DataType::Any, DataFormat::NHWC},
        {DataType::Any, DataFormat::NHWC},
        {DataType::Any, DataFormat::Default}},
    { // Outputs
        {DataType::Float32, DataFormat::NHWC}}},
    {ProdConso::inPlaceModel, nullptr, nullptr});

using Conv2D_Op = Conv_Op<2>;
REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create);

using ConvRelu = MetaOperator_Op;
using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>;
REGISTRAR(ConvRelu_Op_Impl_dummy,
    {{ // Inputs
        {DataType::Any, DataFormat::NHWC},
        {DataType::Any, DataFormat::NHWC},
        {DataType::Any, DataFormat::Default}},
    { // Outputs
        {DataType::Float32, DataFormat::NHWC}}},
    {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create);

using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>;
REGISTRAR(ReLU_Op_Impl_dummy,
    {{DataType::Any, DataFormat::Any}},
    {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create);

using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>;
REGISTRAR(Transpose_Op_Impl_dummy,
    {{DataType::Any, DataFormat::Any}},
    {ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create);

REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));

////////////////////////////////////////////////////////////////////////////////

// Mark every Producer except the graph input as constant, then fold them.
void applyConstFold(std::shared_ptr<GraphView> &graphView)
{
    for (const std::shared_ptr<Node> node : graphView->getNodes())
    {
        if (node->type() == "Producer" && node->name() != "dataProvider")
        {
            const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
            producer->constant() = true;
        }
    }
    constantFolding(graphView);
}

TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
    auto g1 = Sequential({
        Producer({1, 3, 22, 22}, "dataProvider"),
        Conv(3, 4, {3, 3}, "conv1"),
        ReLU("relu1"),
        Conv(4, 8, {3, 3}, "conv2"),
        ReLU("relu2"),
        Conv(8, 10, {1, 1}, "conv3")
    });
    REQUIRE(g1->forwardDims());
    g1->setBackend("dummy");

    auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
    REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default);
    REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default);
    REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default);

    g1->save("adapttobackend_before", true);
    adaptToBackend(g1);
    g1->save("adapttobackend_after", true);

    auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU");
    REQUIRE(matches.size() == 1);
    convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator());
    auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator());
    REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC);
    REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC);
    REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
    REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default);

    // TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims()
    REQUIRE(g1->forwardDims());
    g1->save("adapttobackend_after_forwarddims", true);

    SequentialScheduler sched(g1);
    sched.forward();
}

TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") {
    auto g1 = Sequential({
        Producer({1, 3, 22, 22}, "dataProvider"),
        Conv(3, 4, {3, 3}, "conv1"),
        ReLU("relu1")
    });
    g1->forwardDims();
    g1->setBackend("dummy");

    fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
    g1->save("fuse_meta_op");
    for (auto n : g1->getNodes()) {
        n->setName(n->createUniqueName("n"));
    }
    adaptToBackend(g1);
    for (auto n : g1->getNodes()) {
        n->setName(n->createUniqueName("n"));
        if (n->type() == "ConvReLU") {
            auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
            fmt::println("Backends available for ConvRelu: {}", convReluOp->getAvailableBackends());
            REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC);
            REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC);
            REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
        }
    }
    g1->save("adapt_to_backend");
    SequentialScheduler sched(g1);
    REQUIRE_NOTHROW(sched.generateScheduling());
    REQUIRE_NOTHROW(sched.generateMemory());
    REQUIRE_NOTHROW(sched.forward());
}

// Interesting test because this pattern is used a lot for export
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") {
    auto g1 = Sequential({
        Producer({1, 3, 22, 22}, "dataProvider"),
        Conv(3, 4, {3, 3}, "conv1"),
        ReLU("relu1")
    });
    g1->forwardDims();
    g1->setBackend("dummy");

    fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
    g1->save("fuse_meta_op");
    for (auto n : g1->getNodes()) {
        n->setName(n->createUniqueName("n"));
        if (n->type() == "ConvReLU") {
            auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
        }
    }
    adaptToBackend(g1);
    for (auto n : g1->getNodes()) {
        n->setName(n->createUniqueName("n"));
        if (n->type() == "ConvReLU") {
            auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
            fmt::println("Backends available for ConvRelu: {}", convReluOp->getAvailableBackends());
            REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC);
            REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC);
            REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
        }
    }
    g1->forwardDims({{1, 3, 3, 3}});

    // Mark weight/bias Producers as constant (also done inside applyConstFold)
    for (const std::shared_ptr<Node> node : g1->getNodes())
    {
        if (node->type() == "Producer" && node->name() != "dataProvider")
        {
            const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
            producer->constant() = true;
        }
    }
    applyConstFold(g1);
    g1->save("constant_folding_2");
    SequentialScheduler sched(g1);
    REQUIRE_NOTHROW(sched.generateScheduling());
    REQUIRE_NOTHROW(sched.generateMemory());
    REQUIRE_NOTHROW(sched.forward());

    // Only the layout Transposes on the non-constant data path should survive folding
    unsigned cpt = 0;
    for (auto n : g1->getNodes()) {
        if (n->type() == "Transpose") {
            cpt++;
        }
    }
    REQUIRE(cpt == 2);
}

} // namespace Aidge