Skip to content
Snippets Groups Projects
Commit 8059dcbf authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Maxence Naud
Browse files

Fixed adaptToBackend and added unit test

parent 57e57901
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!312Fixed adaptToBackend and added unit test
...@@ -258,7 +258,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -258,7 +258,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
cast->getOperator()->setBackend(node->getOperator()->backend()); cast->getOperator()->setBackend(node->getOperator()->backend());
cast->addChild(parent, 0, i); cast->addChild(parent, 0, i);
op->getInput(i)->setDataType(requiredIOSpec.type); op->getInput(i)->setDataType(IOSpec.type);
} }
// Input format // Input format
...@@ -273,7 +273,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -273,7 +273,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
transposeOp->getOperator()->setBackend(node->getOperator()->backend()); transposeOp->getOperator()->setBackend(node->getOperator()->backend());
transposeOp->addChild(parent, 0, i); transposeOp->addChild(parent, 0, i);
op->getInput(i)->setDataFormat(requiredIOSpec.format); op->getInput(i)->setDataFormat(IOSpec.format);
} }
// Input dims // Input dims
...@@ -311,7 +311,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -311,7 +311,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
cast->getOperator()->setBackend(node->getOperator()->backend()); cast->getOperator()->setBackend(node->getOperator()->backend());
parent->addChild(cast, i, 0); parent->addChild(cast, i, 0);
op->getInput(i)->setDataType(IOSpec.type); op->getOutput(i)->setDataType(IOSpec.type);
} }
// Output format // Output format
...@@ -326,7 +326,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -326,7 +326,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
transposeOp->getOperator()->setBackend(node->getOperator()->backend()); transposeOp->getOperator()->setBackend(node->getOperator()->backend());
parent->addChild(transposeOp, i, 0); parent->addChild(transposeOp, i, 0);
op->getInput(i)->setDataFormat(IOSpec.format); op->getOutput(i)->setDataFormat(IOSpec.format);
} }
// Output dims // Output dims
......
...@@ -267,7 +267,7 @@ TEST_CASE("[core/data] Tensor(getter/setter)", "[Tensor][Getter][Setter]") { ...@@ -267,7 +267,7 @@ TEST_CASE("[core/data] Tensor(getter/setter)", "[Tensor][Getter][Setter]") {
////////////// //////////////
// backend // backend
// getAvailableBackends() // getAvailableBackends()
REQUIRE(Tensor::getAvailableBackends() == std::set<std::string>({"cpu"})); REQUIRE(Tensor::getAvailableBackends() == std::set<std::string>({"cpu", "dummy"}));
// setBackend() // setBackend()
REQUIRE_NOTHROW(T.setBackend("cpu", 0)); REQUIRE_NOTHROW(T.setBackend("cpu", 0));
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <set>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
namespace Aidge {
////////////////////////////////////////////////////////////////////////////////
// Create a dummy implementation
template <class Op>
class OperatorImpl_dummy : public OperatorImpl,
public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>>
{
public:
OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {}
static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) {
return std::make_unique<OperatorImpl_dummy<Op>>(op);
}
virtual std::shared_ptr<ProdConso> getProdConso() const override {
const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec()));
return impl.prodConso(mOp);
}
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys();
return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end());
}
};
// Register it
using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>;
REGISTRAR(Conv2D_Op_Impl_dummy,
{{ // Inputs
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::NHWC},
{DataType::Any, DataFormat::Default}},
{ // Outputs
{DataType::Float32, DataFormat::NHWC}}},
{ProdConso::inPlaceModel, nullptr, nullptr});
using Conv2D_Op = Conv_Op<2>;
REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create);
using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>;
REGISTRAR(ReLU_Op_Impl_dummy,
{{DataType::Any, DataFormat::Default}},
{ProdConso::inPlaceModel, nullptr, nullptr});
REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create);
REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32}));
////////////////////////////////////////////////////////////////////////////////
TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
ReLU("relu1"),
Conv(32, 64, {3, 3}, "conv2"),
ReLU("relu2"),
Conv(64, 10, {1, 1}, "conv3")
});
g1->setBackend("dummy");
auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default);
g1->save("adapttobackend_before", true);
adaptToBackend(g1);
g1->save("adapttobackend_after", true);
// FIXME: the last ~> should be ->, but no match in this case!
auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#~>Transpose->ReLU");
REQUIRE(matches.size() == 1);
convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator());
REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC);
REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC);
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment