From 2bba4f964f28f45b45db5cea9c09e880940d5b54 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 8 Jan 2025 18:13:34 +0100 Subject: [PATCH] Multiple fixes related to adaptToBackend() --- include/aidge/recipes/Recipes.hpp | 6 +++++ .../backend/pybind_OperatorImpl.cpp | 1 + src/backend/OperatorImpl.cpp | 27 ++++++++++++++----- src/data/Tensor.cpp | 1 + src/operator/Transpose.cpp | 11 +++++--- src/recipes/AdaptToBackend.cpp | 1 + src/recipes/ExpandMetaOps.cpp | 15 +++++++++++ 7 files changed, 53 insertions(+), 9 deletions(-) diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index 0fb405bfe..aa4d3ae1b 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -124,6 +124,12 @@ void explicitCastMove(std::shared_ptr<GraphView> graphView); */ void explicitTranspose(std::shared_ptr<GraphView> graphView); +/** + * Replace a single meta operator by its micro graph. + * @return true if node is indeed a meta operator and could be expanded. +*/ +bool expandMetaOp(std::shared_ptr<Node> node); + /** * Flatten the graph by replacing the meta operators by their micro graph. * @param recursive If true, recursively replace meta operators until there is diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 49e45ed7e..cd94997cf 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -81,6 +81,7 @@ void init_OperatorImpl(py::module& m){ .def(py::init<const DynamicAttributes&>(), py::arg("attr") = DynamicAttributes()) .def(py::init<const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("io"), py::arg("attr") = DynamicAttributes()) .def(py::init<const ImplSpec::IOSpec&, const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes()) + .def(py::init<const std::vector<ImplSpec::IOSpec>&, const std::vector<ImplSpec::IOSpec>&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes()) .def("__eq__", static_cast<bool(*)(const ImplSpec&, const ImplSpec&)>(&operator==)) .def("__repr__", [](ImplSpec self){ return fmt::format("{}\n", self); diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index c74b538a4..8a4924c0e 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -250,9 +250,10 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& && requiredIOSpec.type != IOSpec.type) { const auto cast = Cast(IOSpec.type); + cast->getOperator()->setBackend(node->getOperator()->backend()); cast->addChild(parent, 0, i); - op->getInput(i)->setDataType(IOSpec.type); + op->getInput(i)->setDataType(requiredIOSpec.type); } // Input format @@ -263,10 +264,11 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); transposeOp->getOperator()->setDataFormat(IOSpec.format); - transposeOp->getOperator()->setDataType(IOSpec.type); + transposeOp->getOperator()->setDataType(requiredIOSpec.type); + transposeOp->getOperator()->setBackend(node->getOperator()->backend()); transposeOp->addChild(parent, 0, i); - op->getInput(i)->setDataFormat(IOSpec.format); + op->getInput(i)->setDataFormat(requiredIOSpec.format); } // Input dims @@ -301,6 +303,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& && requiredIOSpec.type != IOSpec.type) { const auto cast = Cast(requiredIOSpec.type); + cast->getOperator()->setBackend(node->getOperator()->backend()); parent->addChild(cast, i, 0); op->getOutput(i)->setDataType(IOSpec.type); @@ -315,6 +318,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); transposeOp->getOperator()->setDataFormat(requiredIOSpec.format); transposeOp->getOperator()->setDataType(requiredIOSpec.type); + transposeOp->getOperator()->setBackend(node->getOperator()->backend()); parent->addChild(transposeOp, i, 0); op->getOutput(i)->setDataFormat(IOSpec.format); @@ -340,7 +344,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& } } - return MetaOperator(std::string("Adapted_" + op->type()).c_str(), getConnectedGraphView(node)); + auto adaptedGraph = getConnectedGraphView(node); + if (adaptedGraph->getNodes().size() > 1) { + return MetaOperator(std::string("Adapted_" + op->type()).c_str(), adaptedGraph); + } + else { + return node; + } } std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const { @@ -354,8 +364,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSp auto adaptation = getAdaptation(availableSpec, requiredSpecs); if (adaptation) { - auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph(); - adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size())); + if (adaptation->getOperator()->isAtomic()) { + adaptations.insert(std::make_pair(adaptation, 1)); + } + else { + auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph(); + adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size())); + } } } diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index c834167ab..e8a0e9ede 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -538,6 +538,7 @@ void Tensor::copyTranspose(const Tensor& src, const std::vector<DimSize_t>& tran } } + AIDGE_ASSERT(mImpl, "Tensor::copyTranspose(): an implementation is required, use setBackend() first!"); std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, newDims); std::vector<size_t> indices(newDims.size(), 0); diff --git a/src/operator/Transpose.cpp b/src/operator/Transpose.cpp index d24b9c909..b550db16d 100644 --- a/src/operator/Transpose.cpp +++ b/src/operator/Transpose.cpp @@ -66,12 +66,17 @@ bool Aidge::Transpose_Op::forwardDims(bool /*allowDataDependency*/) { std::iota(this->outputDimsOrder().rbegin(), this->outputDimsOrder().rend(), 0); } - AIDGE_ASSERT(outputDimsOrder().size() == getInput(0)->nbDims(), - "Permutation vector must have the same rank as input tensor."); + AIDGE_ASSERT(outputDimsOrder().size() >= getInput(0)->nbDims(), + "Permutation vector ({}) must have at least the same rank as input tensor ({}).", outputDimsOrder(), getInput(0)->dims()); std::vector<DimSize_t> outputDims; - for (std::size_t i = 0; i < outputDimsOrder().size(); ++i) { + std::size_t i = 0; + for (; i < getInput(0)->nbDims(); ++i) { outputDims.push_back(getInput(0)->dims()[outputDimsOrder()[i]]); } + for (; i < outputDimsOrder().size(); ++i) { + AIDGE_ASSERT(i == outputDimsOrder()[i], + "Permutation vector ({}) must be the identity above the input tensor rank ({}).", outputDimsOrder(), getInput(0)->dims()); + } mOutputs[0]->resize(outputDims); return true; } diff --git a/src/recipes/AdaptToBackend.cpp b/src/recipes/AdaptToBackend.cpp index e625a52f6..bb4222c49 100644 --- a/src/recipes/AdaptToBackend.cpp +++ b/src/recipes/AdaptToBackend.cpp @@ -33,6 +33,7 @@ void Aidge::adaptToBackend(std::shared_ptr<GraphView> graphView) { Log::info("Adapted node {} (of type {}) to backend {}", node->name(), node->type(), impl->backend()); AIDGE_ASSERT(GraphView::replace({node}, {adaptedNode}), "Unable to replace adapted node!"); + expandMetaOp(adaptedNode); } } } diff --git a/src/recipes/ExpandMetaOps.cpp b/src/recipes/ExpandMetaOps.cpp index 16f0b4c52..459a1ca85 100644 --- a/src/recipes/ExpandMetaOps.cpp +++ b/src/recipes/ExpandMetaOps.cpp @@ -14,6 +14,21 @@ #include "aidge/recipes/Recipes.hpp" #include "aidge/operator/MetaOperator.hpp" +bool Aidge::expandMetaOp(std::shared_ptr<Node> node) { + auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator()); + + if (metaOp != nullptr) { + // Replace meta op by its micro-graph + // graph will be updated accordingly in GraphView::replace() + auto g = std::make_shared<GraphView>(); + g->add(node, false); + GraphView::replace(g, metaOp->getMicroGraph()); + return true; + } + + return false; +} + void Aidge::expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive) { bool found = false; const auto nodes = graph->getNodes(); -- GitLab