From e07747238016b6107d83adcc954718c8f8bf0474 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 7 Dec 2023 23:54:39 +0100 Subject: [PATCH] Added ExplicitConvert recipe --- include/aidge/operator/Conv.hpp | 12 +++++ include/aidge/operator/Convert.hpp | 2 +- include/aidge/recipies/Recipies.hpp | 8 ++++ src/operator/OperatorTensor.cpp | 2 + src/recipies/ExplicitConvert.cpp | 73 +++++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 src/recipies/ExplicitConvert.cpp diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 8ba1dfbf7..fc16359aa 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -174,6 +174,18 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel void setBackend(const std::string &name, int device = 0) override { mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); + + // By default, automatically set backend for weight and bias inputs + getInput(1)->setBackend(name, device); + getInput(2)->setBackend(name, device); + } + + void setDataType(const DataType& dt) const override { + mOutputs[0]->setDataType(dt); + + // By default, automatically set data type for weight and bias inputs + getInput(1)->setDataType(dt); + getInput(2)->setDataType(dt); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Convert.hpp b/include/aidge/operator/Convert.hpp index f115a243d..6a08fbf0d 100644 --- a/include/aidge/operator/Convert.hpp +++ b/include/aidge/operator/Convert.hpp @@ -51,7 +51,7 @@ public: } void setBackend(const std::string& name, int device = 0) override { - if (Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) { + if (mInputs[0]->getImpl() && Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) { mImpl = Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this); } mOutputs[0]->setBackend(name, device); diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp index 5ad08a658..f1d9a9327 100644 --- a/include/aidge/recipies/Recipies.hpp +++ b/include/aidge/recipies/Recipies.hpp @@ -89,6 +89,14 @@ std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<No // std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices); // void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices); + +/** + * Add Convert operators where needed to ensure no conversion needs to be done + * at the Operator level. +*/ +void explicitConvert(std::shared_ptr<GraphView> graphView); + + } // namespace Aidge #endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index 1237fdc0b..ccee3865e 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -149,6 +149,7 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { for (IOIndex_t i = 0; i < nbOutputs(); ++i) { getOutput(i)->setDataType(dataType); } + /* for (IOIndex_t i = 0; i < nbInputs(); ++i) { if (!getInput(i)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set"); @@ -157,4 +158,5 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { getInput(i)->setDataType(dataType); } } + */ } \ No newline at end of file diff --git a/src/recipies/ExplicitConvert.cpp b/src/recipies/ExplicitConvert.cpp new file mode 100644 index 000000000..59d30d16a --- /dev/null +++ b/src/recipies/ExplicitConvert.cpp @@ -0,0 +1,73 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/recipies/Recipies.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Convert.hpp" + +void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) { + const auto nodes = graph->getNodes(); + for (auto node : nodes) { + // TODO: currently, Operator data type is only reflected in its output tensor data type. + // But an Operator might have multiple outputs of different data type(?) + const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0); + if (output->getImpl() == nullptr) { + continue; + } + const auto& device = output->getImpl()->device(); + + if (node->type() == Convert_Op::Type) { + // Remove existing Convert operator, if not needed anymore + const auto& parent = node->inputs()[0]; + const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); + + if (input->dataType() == output->dataType() + && (input->getImpl()->device() == device)) + { + // Add direct connection bypassing Convert node + for (auto child : node->outputs()[0]) { + parent.first->addChild(child.first, parent.second, child.second); + } + + // Remove all Convert node connections + node->resetConnections(); + } + } + else { + // Insert Convert operator between node inputs and parent output, if needed + IOIndex_t inputIdx = 0; + for (auto parent : node->inputs()) { + if (parent.first != nullptr) { + const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); + + if (input->dataType() != output->dataType() + || (input->getImpl()->device() != device)) + { + // A conversion Operator is needed + auto convert = Convert(); + convert->addChild(node, 0, inputIdx); + parent.first->addChild(convert, parent.second, 0); + // Set backend AFTER connection in case a specific implementation + // of the operator exists for the input type. + convert->getOperator()->setBackend(device.first, device.second); + + // Add/update nodes in the GraphView + graph->add(convert); + graph->add(parent.first); + graph->add(node); + } + } + + ++inputIdx; + } + } + } +} -- GitLab