Commit 5b513d7d authored by Wissam Boussella
Revert "New tests adapt_to_backend"

This reverts commit f23dbcae.
parent eb84af98
Pipeline #65784 failed
@@ -76,17 +76,13 @@ struct Registrar {
     }
     static auto create(const registrar_key& key) {
         AIDGE_ASSERT(exists(key), "missing or invalid registrar key: {} for registrable object {}\nDid you include/import the corresponding module?\nIf so, it is possible that the object is not yet supported.", key, typeid(C).name());
         return C::registry().at(key);
     }
     static std::set<registrar_key> getKeys(){
         std::set<registrar_key> keys;
-        for(const auto& keyValue : C::registry()){
+        for(const auto& keyValue : C::registry())
             keys.insert(keyValue.first);
-        }
         return keys;
     }
 };
...
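Note: the getKeys() change above only drops redundant loop braces. For readers unfamiliar with this registrar pattern, here is a minimal self-contained sketch of the key-to-factory flow that Registrar<C>::create and getKeys implement; MyOp and the shape of its registry() are hypothetical stand-ins, not the actual Aidge template.

#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <string>

// Hypothetical registrable class; Aidge's version is a template over C.
struct MyOp {
    using registrar_key = std::string;
    static std::map<registrar_key, std::function<int()>>& registry() {
        static std::map<registrar_key, std::function<int()>> r;
        return r;
    }
};

int main() {
    MyOp::registry()["dummy"] = [] { return 42; };

    // Mirrors Registrar<C>::getKeys(): collect every registered key.
    std::set<MyOp::registrar_key> keys;
    for (const auto& keyValue : MyOp::registry())
        keys.insert(keyValue.first);

    // Mirrors Registrar<C>::create(key): look up the factory for a key.
    std::cout << MyOp::registry().at("dummy")() << '\n';  // prints 42
    return 0;
}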
@@ -24,11 +24,6 @@ namespace py = pybind11;
 namespace Aidge {
 void init_Recipes(py::module &m)
 {
-    m.def("constant_folding", static_cast<void(*)(std::shared_ptr<GraphView>)>(constantFolding), py::arg("graph_view"), R"mydelimiter(
-    Recipe to optimize a GraphView by repeatedly identifying nodes with constant inputs, executing them immediately, and replacing them with their pre-computed constant outputs.
-    :param graph_view: Graph view on which we want to apply the recipe
-    :type graph_view: :py:class:`aidge_core.GraphView`
-    )mydelimiter");
     m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter(
...
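The removed binding follows the standard pybind11 pattern: m.def exposes a free C++ function under a Python name, py::arg names its parameter, and the raw string literal becomes the Python docstring. A minimal sketch of that pattern, with a hypothetical module "example" and function myRecipe standing in for constantFolding:

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical recipe standing in for constantFolding.
void myRecipe(int x) { (void)x; }

PYBIND11_MODULE(example, m) {
    // py::arg names the Python-side keyword argument; the raw string
    // (delimited like the "mydelimiter" strings above) becomes the docstring.
    m.def("my_recipe", &myRecipe, py::arg("x"), R"doc(
        Example docstring, analogous to the one removed in the hunk above.
    )doc");
}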
@@ -74,6 +74,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const {
         requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims});
     }
+
     const auto& inhAttrs = mOp.inheritedAttributes();
     if (inhAttrs) {
         if (inhAttrs->hasAttr("impl")) {
@@ -88,7 +89,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs)
     const auto availableSpecsSet = getAvailableImplSpecs();
     AIDGE_ASSERT(availableSpecsSet.size() > 0,
-        "OperatorImpl::getBestMatch():o available specs found by"
+        "OperatorImpl::getBestMatch(): No available specs found by"
         "getAvailableSpecs(). "
         "Cannot find best implementation for required specs, aborting.");
     const std::vector<ImplSpec> availableSpecs(availableSpecsSet.begin(), availableSpecsSet.end());
@@ -234,7 +235,6 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
     auto op = std::static_pointer_cast<OperatorTensor>(mOp.clone());
     auto node = std::make_shared<Node>(op);
-    // node->getOperator()->setImpl(getOperator().getImpl());
     auto adaptedGraph = std::make_shared<GraphView>();
     adaptedGraph->add(node);
...
@@ -1678,7 +1678,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone
         }
     }
     newGraph->setOrderedOutputs(newOutputNodes);
     return newGraph;
 }
...
@@ -54,7 +54,6 @@ Aidge::Elts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
 Aidge::Elts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
     AIDGE_ASSERT(mImpl != nullptr, "getNbProducedData(): an implementation is required for {}!", type());
-    AIDGE_ASSERT(mImpl->prodConso() != nullptr, "getNbProducedData(): no prod consumer for {}!", mImpl->getOperator().type());
     return mImpl->prodConso()->getNbProducedData(outputIdx);
 }
 void Aidge::Operator::updateConsummerProducer(){
...
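The deleted assert guarded the prodConso() result before it is dereferenced on the next line; after the revert, a missing producer-consumer model surfaces as a plain null-pointer dereference instead of a diagnostic. A minimal sketch of the guarded-access pattern, using simplified stand-in types rather than the actual Aidge signatures:

#include <cassert>
#include <memory>

// Simplified stand-ins for Aidge's ProdConso and OperatorImpl.
struct ProdConso {
    int getNbProducedData(int /*outputIdx*/) const { return 0; }
};
struct Impl {
    std::shared_ptr<ProdConso> pc;
    std::shared_ptr<ProdConso> prodConso() const { return pc; }
};

int getNbProducedData(const Impl* mImpl, int outputIdx) {
    assert(mImpl != nullptr && "an implementation is required");
    // This is the kind of check the revert removes: without it, a null
    // ProdConso is dereferenced below with no error message.
    assert(mImpl->prodConso() != nullptr && "no prod/conso model");
    return mImpl->prodConso()->getNbProducedData(outputIdx);
}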
@@ -11,7 +11,7 @@ if(NOT Catch2_FOUND)
     FetchContent_Declare(
         Catch2
         GIT_REPOSITORY https://github.com/catchorg/Catch2.git
-        GIT_TAG v3.7.1 # or a later release
+        GIT_TAG devel # or a later release
     )
     FetchContent_MakeAvailable(Catch2)
     message(STATUS "Fetched Catch2 version ${Catch2_VERSION}")
...
@@ -21,8 +21,6 @@
 #include "aidge/operator/Producer.hpp"
 #include "aidge/recipes/Recipes.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
-#include "aidge/operator/MetaOperatorDefs.hpp"
-
 namespace Aidge {
@@ -41,9 +39,6 @@ public:
     virtual std::shared_ptr<ProdConso> getProdConso() const override {
         const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec()));
-        if (impl.prodConso(mOp) == nullptr) {
-            fmt::println("no prod conso created");
-        }
         return impl.prodConso(mOp);
     }
@@ -68,37 +63,16 @@ REGISTRAR(Conv2D_Op_Impl_dummy,
         {DataType::Float32, DataFormat::NHWC}}},
     {ProdConso::inPlaceModel, nullptr, nullptr});
 using Conv2D_Op = Conv_Op<2>;
 REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create);
-
-using ConvRelu = MetaOperator_Op;
-using ConvRelu_Op_Impl_dummy = OperatorImpl_dummy<ConvRelu>;
-REGISTRAR(ConvRelu_Op_Impl_dummy,
-    {{ // Inputs
-        {DataType::Any, DataFormat::NHWC},
-        {DataType::Any, DataFormat::NHWC},
-        {DataType::Any, DataFormat::Default}},
-    { // Outputs
-        {DataType::Float32, DataFormat::NHWC}}},
-    {ProdConso::inPlaceModel, nullptr, nullptr});
-REGISTRAR(ConvRelu, std::array<std::string, 2>({"dummy", "ConvReLU"}), ConvRelu_Op_Impl_dummy::create);
-
 using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>;
 REGISTRAR(ReLU_Op_Impl_dummy,
-    {{DataType::Any, DataFormat::Any}},
+    {{DataType::Any, DataFormat::Default}},
     {ProdConso::inPlaceModel, nullptr, nullptr});
 REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create);
-
-using Transpose_Op_Impl_dummy = OperatorImpl_dummy<Transpose_Op>;
-REGISTRAR(Transpose_Op_Impl_dummy,
-    {{DataType::Any, DataFormat::Any}},
-    {ProdConso::inPlaceModel, nullptr, nullptr});
-REGISTRAR(Transpose_Op, "dummy", OperatorImpl_dummy<Transpose_Op>::create);
-
 REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));
 ////////////////////////////////////////////////////////////////////////////////
@@ -112,7 +86,6 @@ TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
         ReLU("relu2"),
         Conv(8, 10, {1, 1}, "conv3")
     });
-    REQUIRE(g1->forwardDims());
     g1->setBackend("dummy");
     auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
@@ -134,52 +107,11 @@ TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
     REQUIRE(outTransOp->getOutput(0)->dataFormat() == DataFormat::Default);
     // TODO: uncomment when support of NHWC will be implemented in Conv_Op::forwardDims()
-    REQUIRE(g1->forwardDims());
-    g1->save("adapttobackend_after_forwarddims", true);
-    SequentialScheduler sched(g1);
-    sched.forward();
+    //REQUIRE(g1->forwardDims());
+    //g1->save("adapttobackend_after_forwarddims", true);
+    //SequentialScheduler sched(g1);
+    //sched.forward();
 }
-
-TEST_CASE("[cpu/recipes] AdaptToBackend with MetaOp", "[AdaptToBackend][recipes]") {
-    auto g1 = Sequential({
-        Producer({1, 3, 22, 22}, "dataProvider"),
-        Conv(3, 4, {3, 3}, "conv1"),
-        ReLU("relu1")
-    });
-    REQUIRE(g1->forwardDims());
-    g1->setBackend("dummy");
-    fuseToMetaOps(g1, "Conv2D->ReLU", "ConvReLU");
-    g1->save("fuse_meta_op");
-    for (auto n : g1->getNodes()) {
-        n->setName(n->createUniqueName("n"));
-        if (n->type() == "ConvReLU") {
-            auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
-        }
-    }
-    adaptToBackend(g1);
-    for (auto n : g1->getNodes()) {
-        n->setName(n->createUniqueName("n"));
-        if (n->type() == "ConvReLU") {
-            auto convReluOp = std::static_pointer_cast<ConvRelu>(n->getOperator());
-            fmt::println("Backends available for ConvRelu: {}", convReluOp->getAvailableBackends());
-            REQUIRE(convReluOp->getInput(0)->dataFormat() == DataFormat::NHWC);
-            REQUIRE(convReluOp->getInput(1)->dataFormat() == DataFormat::NHWC);
-            REQUIRE(convReluOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
-        }
-    }
-    g1->save("adapt_to_backend");
-    // constantFolding(g1);
-    SequentialScheduler sched(g1);
-    sched.generateScheduling();
-    sched.generateMemory();
-    sched.forward();
-}
 } // namespace Aidge
...
@@ -19,21 +19,20 @@
 using namespace Aidge;
-TEST_CASE("[ExplicitTranspose] ExplicitTranspose conv") {
+TEST_CASE("[ExplicitTranspose] conv") {
     auto conv1 = Conv(3, 32, {3, 3}, "conv1");
     auto conv2 = Conv(32, 64, {3, 3}, "conv2");
     auto conv3 = Conv(64, 10, {1, 1}, "conv3", {2, 2});
-    auto producer1 = Producer({16, 3, 224, 224}, "dataProvider");
-    conv2->getOperator()->setDataFormat(DataFormat::NHWC);
     auto g1 = Sequential({
-        producer1,
+        Producer({16, 3, 224, 224}, "dataProvider"),
         conv1,
         conv2,
         conv3
     });
     g1->setDataFormat(DataFormat::NCHW);
+    conv2->getOperator()->setDataFormat(DataFormat::NHWC);
     g1->save("explicitTranspose_before");
     REQUIRE(g1->getNodes().size() == 10);
@@ -42,13 +41,13 @@ TEST_CASE("[ExplicitTranspose] ExplicitTranspose conv") {
     g1->forwardDims();
     explicitTranspose(g1);
-    // // Check that Transpose nodes were inserted
-    // g1->save("explicitTranspose_after");
-    // REQUIRE(g1->getNodes().size() == 12);
-    // // Check that Transpose nodes are removed
-    // conv2->getOperator()->setDataFormat(DataFormat::NCHW);
-    // explicitTranspose(g1);
+    // Check that Transpose nodes were inserted
+    g1->save("explicitTranspose_after");
+    REQUIRE(g1->getNodes().size() == 12);
+    // Check that Transpose nodes are removed
+    conv2->getOperator()->setDataFormat(DataFormat::NCHW);
+    explicitTranspose(g1);
     REQUIRE(g1->getNodes().size() == 10);
     REQUIRE(g1->getNodes() == initialNodes);
...