Skip to content
Snippets Groups Projects
Commit 7317fc2d authored by Wissam Boussella's avatar Wissam Boussella Committed by Olivier BICHLER
Browse files

new tests adapt to backend, with metaOP and constant folding (not working yet)

parent cb7ee7d5
No related branches found
No related tags found
1 merge request!341Error
This commit is part of merge request !336. Comments created here will be created in the context of that merge request.
......@@ -9,109 +9,240 @@
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <set>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
namespace Aidge {
////////////////////////////////////////////////////////////////////////////////
// Create a dummy implementation
template <class Op>
class OperatorImpl_dummy : public OperatorImpl,
public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>>
#include <catch2/catch_test_macros.hpp>
#include <set>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
namespace Aidge {
////////////////////////////////////////////////////////////////////////////////
// Create a dummy implementation
template <class Op>
class OperatorImpl_dummy : public OperatorImpl,
public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>>
{
public:
OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {}
static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) {
return std::make_unique<OperatorImpl_dummy<Op>>(op);
}
virtual std::shared_ptr<ProdConso> getProdConso() const override {
const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec()));
if (impl.prodConso(mOp)==nullptr){
fmt::println("no prod conso created ");
}
return impl.prodConso(mOp);
}
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys();
return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end());
}
void forward() override {
fmt::println("forward: {}", mOp.type());
}
};
// Register it
using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>;
REGISTRAR(Conv2D_Op_Impl_dummy,
{{ // Inputs
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::Default}},
{ // Outputs
{DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr});
using Conv2D_Op = Conv_Op<2>;
REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create);
using ConvRelu = MetaOperator_Op;
using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>;
REGISTRAR(ConvRelu_Op_Impl_dummy,
{{ // Inputs
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::Default}},
{ // Outputs
{DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create);
using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>;
REGISTRAR(ReLU_Op_Impl_dummy,
{{DataType::Any, DataFormat::Any}},
{ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create);
using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>;
REGISTRAR(Transpose_Op_Impl_dummy,
{{DataType::Any, DataFormat::Any}},
{ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create);
REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));
////////////////////////////////////////////////////////////////////////////////
void applyConstFold(std::shared_ptr<GraphView> &graphView)
{
public:
OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {}
static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) {
return std::make_unique<OperatorImpl_dummy<Op>>(op);
}
virtual std::shared_ptr<ProdConso> getProdConso() const override {
const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec()));
return impl.prodConso(mOp);
for (const std::shared_ptr<Node> node : graphView->getNodes())
{
if (node->type() == "Producer" && node->name() != "dataProvider")
{
const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
producer->constant() = true;
}
}
constantFolding(graphView);
}
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys();
return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end());
TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"),
ReLU("relu1"),
Conv(4, 8, {3, 3}, "conv2"),
ReLU("relu2"),
Conv(8, 10, {1, 1}, "conv3")
});
REQUIRE(g1->forwardDims());
g1->setBackend("dummy");
auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default);
g1->save("adapttobackend_before", true);
adaptToBackend(g1);
g1->save("adapttobackend_after", true);
auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU");
REQUIRE(matches.size() == 1);
convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator());
auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default);
// TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims()
REQUIRE(g1->forwardDims());
g1->save("adapttobackend_after_forwarddims", true);
SequentialScheduler sched(g1);
sched.forward();
}
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"),
ReLU("relu1")
});
g1->forwardDims();
g1->setBackend("dummy");
fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
g1->save("fuse_meta_op");
for( auto n : g1->getNodes()){
n->setName(n->createUniqueName("n"));
}
void forward() override {
fmt::println("forward: {}", mOp.type());
adaptToBackend(g1);
for( auto n : g1->getNodes()){
n->setName(n->createUniqueName("n"));
if (n->type() == "ConvReLU"){
auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
fmt::println("Backends avalaile for ConvRelu is {}",convReluOp->getAvailableBackends());
REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC);
REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
}
}
};
// Register it
using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>;
REGISTRAR(Conv2D_Op_Impl_dummy,
{{ // Inputs
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::Default}},
{ // Outputs
{DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr});
using Conv2D_Op = Conv_Op<2>;
REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create);
using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>;
REGISTRAR(ReLU_Op_Impl_dummy,
{{DataType::Any, DataFormat::Default}},
{ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create);
REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));
////////////////////////////////////////////////////////////////////////////////
TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
g1->save("adapt_to_backend");
SequentialScheduler sched(g1);
REQUIRE_NOTHROW(sched.generateScheduling());
REQUIRE_NOTHROW(sched.generateMemory());
REQUIRE_NOTHROW(sched.forward());
}
// Interesting test because used a lot for export
TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp and constantFolding", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({
Producer({1, 3, 22, 22}, "dataProvider"),
Conv(3, 4, {3, 3}, "conv1"),
ReLU("relu1"),
Conv(4, 8, {3, 3}, "conv2"),
ReLU("relu2"),
Conv(8, 10, {1, 1}, "conv3")
ReLU("relu1")
});
g1->forwardDims();
g1->setBackend("dummy");
auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default);
g1->save("adapttobackend_before", true);
fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
g1->save("fuse_meta_op");
for( auto n : g1->getNodes()){
n->setName(n->createUniqueName("n"));
if (n->type() == "ConvReLU"){
auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
}
}
adaptToBackend(g1);
g1->save("adapttobackend_after", true);
auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#->Transpose#->ReLU");
REQUIRE(matches.size() == 1);
convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator());
auto outTransOp = std::static_pointer_cast<Transpose_Op>(matches.begin()->anchors.at("Transpose").at("#")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default);
// TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims()
//REQUIRE(g1->forwardDims());
//g1->save("adapttobackend_after_forwarddims", true);
//SequentialScheduler sched(g1);
//sched.forward();
}
} // namespace Aidge
for( auto n : g1->getNodes()){
n->setName(n->createUniqueName("n"));
if (n->type() == "ConvReLU"){
auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
fmt::println("Backends avalaile for ConvRelu is {}",convReluOp->getAvailableBackends());
REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC);
REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
}
}
g1->forwardDims({{1, 3, 3, 3}});
for (const std::shared_ptr<Node> node : g1->getNodes())
{
if (node->type() == "Producer" && node->name() != "dataProvider")
{
const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
producer->constant() = true;
}
}
applyConstFold(g1);
g1->save("constant_folding_2");
SequentialScheduler sched(g1);
REQUIRE_NOTHROW(sched.generateScheduling());
REQUIRE_NOTHROW(sched.generateMemory());
REQUIRE_NOTHROW(sched.forward());
unsigned cpt = 0;
for( auto n : g1->getNodes()){
if (n->type() == "Transpose"){
cpt++;
}
}
REQUIRE(cpt == 2);
}
} // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment