diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 5c25ad6056c054b8e757fd942d8f77ccebbf0741..c63554259a7b0a474ff83e97f27885edfec4eef6 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -258,7 +258,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& cast->getOperator()->setBackend(node->getOperator()->backend()); cast->addChild(parent, 0, i); - op->getInput(i)->setDataType(requiredIOSpec.type); + op->getInput(i)->setDataType(IOSpec.type); } // Input format @@ -273,7 +273,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& transposeOp->getOperator()->setBackend(node->getOperator()->backend()); transposeOp->addChild(parent, 0, i); - op->getInput(i)->setDataFormat(requiredIOSpec.format); + op->getInput(i)->setDataFormat(IOSpec.format); } // Input dims @@ -311,7 +311,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& cast->getOperator()->setBackend(node->getOperator()->backend()); parent->addChild(cast, i, 0); - op->getInput(i)->setDataType(IOSpec.type); + op->getOutput(i)->setDataType(IOSpec.type); } // Output format @@ -326,7 +326,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& transposeOp->getOperator()->setBackend(node->getOperator()->backend()); parent->addChild(transposeOp, i, 0); - op->getInput(i)->setDataFormat(IOSpec.format); + op->getOutput(i)->setDataFormat(IOSpec.format); } // Output dims diff --git a/unit_tests/data/Test_Tensor.cpp b/unit_tests/data/Test_Tensor.cpp index 6c4b14602aed98ff5736d2cf30ba642f9e7ec57b..bfdc1a6b9c058b348942e9c29a77ac4d6db5086f 100644 --- a/unit_tests/data/Test_Tensor.cpp +++ b/unit_tests/data/Test_Tensor.cpp @@ -267,7 +267,7 @@ TEST_CASE("[core/data] Tensor(getter/setter)", "[Tensor][Getter][Setter]") { ////////////// // backend // getAvailableBackends() - REQUIRE(Tensor::getAvailableBackends() == std::set<std::string>({"cpu"})); + REQUIRE(Tensor::getAvailableBackends() == std::set<std::string>({"cpu", "dummy"})); // setBackend() REQUIRE_NOTHROW(T.setBackend("cpu", 0)); diff --git a/unit_tests/recipes/Test_AdaptToBackend.cpp b/unit_tests/recipes/Test_AdaptToBackend.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b7cccaf724abd6360668dc0bf8bce543ff9dde7c --- /dev/null +++ b/unit_tests/recipes/Test_AdaptToBackend.cpp @@ -0,0 +1,102 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <set> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/recipes/Recipes.hpp" + +namespace Aidge { + +//////////////////////////////////////////////////////////////////////////////// +// Create a dummy implementation +template <class Op> +class OperatorImpl_dummy : public OperatorImpl, + public Registrable<OperatorImpl_dummy<Op>, ImplSpec, Impl<void(), void()>> +{ +public: + OperatorImpl_dummy(const Op& op) : OperatorImpl(op, "dummy") {} + + static std::unique_ptr<OperatorImpl_dummy<Op>> create(const Op& op) { + return std::make_unique<OperatorImpl_dummy<Op>>(op); + } + + virtual std::shared_ptr<ProdConso> getProdConso() const override { + const auto impl = Registrar<OperatorImpl_dummy>::create(getBestMatch(getRequiredSpec())); + return impl.prodConso(mOp); + } + + virtual std::vector<ImplSpec> getAvailableImplSpecs() const override { + std::set<ImplSpec> implSpecsSet = Registrar<OperatorImpl_dummy>::getKeys(); + return std::vector<ImplSpec>(implSpecsSet.begin(), implSpecsSet.end()); + } +}; + +// Register it +using Conv2D_Op_Impl_dummy = OperatorImpl_dummy<Conv_Op<2>>; +REGISTRAR(Conv2D_Op_Impl_dummy, + {{ // Inputs + {DataType::Any, DataFormat::NHWC}, + {DataType::Any, DataFormat::NHWC}, + {DataType::Any, DataFormat::Default}}, + { // Outputs + {DataType::Float32, DataFormat::NHWC}}}, + {ProdConso::inPlaceModel, nullptr, nullptr}); + +using Conv2D_Op = Conv_Op<2>; +REGISTRAR(Conv2D_Op, "dummy", OperatorImpl_dummy<Conv2D_Op>::create); + +using ReLU_Op_Impl_dummy = OperatorImpl_dummy<ReLU_Op>; +REGISTRAR(ReLU_Op_Impl_dummy, + {{DataType::Any, DataFormat::Default}}, + {ProdConso::inPlaceModel, nullptr, nullptr}); + +REGISTRAR(ReLU_Op, "dummy", OperatorImpl_dummy<ReLU_Op>::create); + +REGISTRAR(Tensor, {"dummy", DataType::Float32}, Registrar<Tensor>::create({"cpu", DataType::Float32})); +//////////////////////////////////////////////////////////////////////////////// + + +TEST_CASE("[cpu/recipes] AdaptToBackend", "[AdaptToBackend][recipes]") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + ReLU("relu1"), + Conv(32, 64, {3, 3}, "conv2"), + ReLU("relu2"), + Conv(64, 10, {1, 1}, "conv3") + }); + g1->setBackend("dummy"); + auto convOp = std::static_pointer_cast<Conv2D_Op>(g1->getNode("conv1")->getOperator()); + REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::Default); + REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::Default); + REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::Default); + + g1->save("adapttobackend_before", true); + adaptToBackend(g1); + g1->save("adapttobackend_after", true); + + // FIXME: the last ~> should be ->, but no match in this case! + auto matches = SinglePassGraphMatching(g1).match("Conv2D#<-Transpose<-Producer;Conv2D#<1-Transpose<-Producer;Conv2D#<2-Producer;Conv2D#~>Transpose->ReLU"); + REQUIRE(matches.size() == 1); + convOp = std::static_pointer_cast<Conv2D_Op>(matches.begin()->graph->rootNode()->getOperator()); + REQUIRE(convOp->getInput(0)->dataFormat() == DataFormat::NHWC); + REQUIRE(convOp->getInput(1)->dataFormat() == DataFormat::NHWC); + REQUIRE(convOp->getOutput(0)->dataFormat() == DataFormat::NHWC); +} + +} // namespace Aidge