From 067796b1fb3b8704af63f431ad278f4ee1a30117 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 14 Dec 2023 14:11:14 +0100 Subject: [PATCH] Split Convert operator into Cast and Move operators --- include/aidge/data/Tensor.hpp | 16 +++++ include/aidge/operator/Cast.hpp | 72 +++++++++++++++++++ .../aidge/operator/{Convert.hpp => Move.hpp} | 37 ++++------ include/aidge/recipies/Recipies.hpp | 2 +- src/data/Tensor.cpp | 34 +++++++++ src/operator/{Convert.cpp => Cast.cpp} | 6 +- src/operator/Move.cpp | 24 +++++++ ...plicitConvert.cpp => ExplicitCastMove.cpp} | 57 ++++++++++----- 8 files changed, 205 insertions(+), 43 deletions(-) create mode 100644 include/aidge/operator/Cast.hpp rename include/aidge/operator/{Convert.hpp => Move.hpp} (52%) rename src/operator/{Convert.cpp => Cast.cpp} (80%) create mode 100644 src/operator/Move.cpp rename src/recipies/{ExplicitConvert.cpp => ExplicitCastMove.cpp} (53%) diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 3d40ca644..2e35238cd 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -659,6 +659,22 @@ class Tensor : public Data, return flatIdx + coordIdx[i]; } + /** + * Copy-cast data from a Tensor on the same device. + * If current tensor backend/device is set and is different from src, an + * assertion is raised. + * @param src Source tensor to copy-cast from. + */ + void copyCast(const Tensor& src); + + /** + * Copy data from a Tensor from another backend/device. + * If current tensor data type is set and is different from src, an + * assertion is raised. + * @param src Source tensor to copy from. + */ + void copyFrom(const Tensor& src); + /** * Copy-cast data from a Tensor. * @param src Source tensor to copy-cast from. diff --git a/include/aidge/operator/Cast.hpp b/include/aidge/operator/Cast.hpp new file mode 100644 index 000000000..38fb0b1ff --- /dev/null +++ b/include/aidge/operator/Cast.hpp @@ -0,0 +1,72 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_CAST_H_ +#define AIDGE_CORE_OPERATOR_CAST_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Cast_Op : public OperatorTensor, + public Registrable<Cast_Op, std::string, std::unique_ptr<OperatorImpl>(const Cast_Op&)> { +public: + static constexpr const char* Type = "Cast"; + + Cast_Op() : OperatorTensor(Type, 1, 0, 1) {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Cast_Op(const Cast_Op& op) + : OperatorTensor(op) + { + mImpl = op.mImpl ? Registrar<Cast_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Cast_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Cast_Op>(*this); + } + + void setBackend(const std::string& name, int device = 0) override { + mOutputs[0]->setBackend(name, device); + } + + void forward() override; + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Cast(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Cast_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_CAST_H_ */ \ No newline at end of file diff --git a/include/aidge/operator/Convert.hpp b/include/aidge/operator/Move.hpp similarity index 52% rename from include/aidge/operator/Convert.hpp rename to include/aidge/operator/Move.hpp index aa35d6812..be7f8922e 100644 --- a/include/aidge/operator/Convert.hpp +++ b/include/aidge/operator/Move.hpp @@ -9,8 +9,8 @@ * ********************************************************************************/ -#ifndef AIDGE_CORE_OPERATOR_CONVERT_H_ -#define AIDGE_CORE_OPERATOR_CONVERT_H_ +#ifndef AIDGE_CORE_OPERATOR_MOVE_H_ +#define AIDGE_CORE_OPERATOR_MOVE_H_ #include <cassert> #include <memory> @@ -25,34 +25,34 @@ namespace Aidge { -class Convert_Op : public OperatorTensor, - public Registrable<Convert_Op, std::tuple<std::string, std::string>, std::unique_ptr<OperatorImpl>(const Convert_Op&)> { +class Move_Op : public OperatorTensor, + public Registrable<Move_Op, std::tuple<std::string, std::string>, std::unique_ptr<OperatorImpl>(const Move_Op&)> { public: - static constexpr const char* Type = "Convert"; + static constexpr const char* Type = "Move"; - Convert_Op() : OperatorTensor(Type, 1, 0, 1) {} + Move_Op() : OperatorTensor(Type, 1, 0, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - Convert_Op(const Convert_Op& op) + Move_Op(const Move_Op& op) : OperatorTensor(op) { - mImpl = op.mImpl ? Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), mOutputs[0]->getImpl()->backend()})(*this) : nullptr; + mImpl = op.mImpl ? Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), mOutputs[0]->getImpl()->backend()})(*this) : nullptr; } /** * @brief Clone the operator using its copy-constructor. - * @see Operator::Convert_Op + * @see Operator::Move_Op */ std::shared_ptr<Operator> clone() const override { - return std::make_shared<Convert_Op>(*this); + return std::make_shared<Move_Op>(*this); } void setBackend(const std::string& name, int device = 0) override { - if (mInputs[0]->getImpl() && Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) { - mImpl = Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this); + if (mInputs[0]->getImpl() && Registrar<Move_Op>::exists({mInputs[0]->getImpl()->backend(), name})) { + mImpl = Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this); } mOutputs[0]->setBackend(name, device); } @@ -65,18 +65,11 @@ public: static const std::vector<std::string> getOutputsName(){ return {"data_output"}; } - -private: - /// @brief Store the input data to the output device, before type conversion. - /// Used only when there is both a change of device AND of data type. - /// Otherwise, data is either directly copied from the other device or - /// casted on the same device (requiring a single copy). - std::shared_ptr<Tensor> mMovedInput; }; -inline std::shared_ptr<Node> Convert(const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Convert_Op>(), name); +inline std::shared_ptr<Node> Move(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Move_Op>(), name); } } -#endif /* AIDGE_CORE_OPERATOR_CONVERT_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_MOVE_H_ */ \ No newline at end of file diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp index f1d9a9327..a433f1d0c 100644 --- a/include/aidge/recipies/Recipies.hpp +++ b/include/aidge/recipies/Recipies.hpp @@ -94,7 +94,7 @@ std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<No * Add Convert operators where needed to ensure no conversion needs to be done * at the Operator level. */ -void explicitConvert(std::shared_ptr<GraphView> graphView); +void explicitCastMove(std::shared_ptr<GraphView> graphView); } // namespace Aidge diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 15e6782a0..8a950ba2f 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -13,6 +13,40 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +void Aidge::Tensor::copyCast(const Tensor& src) { + if (&src == this) { + return; + } + + // Current Tensor has necessarily a data type, but may not have backend + if (!getImpl()) { + // If no backend was set for the current tensor, use the same as src + const auto deviceSrc = src.getImpl()->device(); + setBackend(deviceSrc.first, deviceSrc.second); + } + resize(src.dims()); + + AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(), "cannot copy-cast from a different backend/device"); + getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType()); +} + +void Aidge::Tensor::copyFrom(const Tensor& src) { + if (&src == this) { + return; + } + + // Current Tensor has necessarily a data type, but may not have backend + if (!getImpl()) { + // If no backend was set for the current tensor, use the same as src + const auto deviceSrc = src.getImpl()->device(); + setBackend(deviceSrc.first, deviceSrc.second); + } + resize(src.dims()); + + AIDGE_ASSERT(src.dataType() == dataType(), "cannot copy from a different data type"); + getImpl()->copyFrom(*(src.getImpl()), src.size()); +} + void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) { if (&src == this) { return; diff --git a/src/operator/Convert.cpp b/src/operator/Cast.cpp similarity index 80% rename from src/operator/Convert.cpp rename to src/operator/Cast.cpp index fcdcf4013..0ac6d5f53 100644 --- a/src/operator/Convert.cpp +++ b/src/operator/Cast.cpp @@ -10,14 +10,14 @@ ********************************************************************************/ #include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/Convert.hpp" +#include "aidge/operator/Cast.hpp" -void Aidge::Convert_Op::forward() { +void Aidge::Cast_Op::forward() { if (mImpl) { mImpl->forward(); } else { - mOutputs[0]->copyCastFrom(*(mInputs[0]), mMovedInput); + mOutputs[0]->copyCast(*(mInputs[0])); } runHooks(); diff --git a/src/operator/Move.cpp b/src/operator/Move.cpp new file mode 100644 index 000000000..d828e994d --- /dev/null +++ b/src/operator/Move.cpp @@ -0,0 +1,24 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Move.hpp" + +void Aidge::Move_Op::forward() { + if (mImpl) { + mImpl->forward(); + } + else { + mOutputs[0]->copyFrom(*(mInputs[0])); + } + + runHooks(); +} diff --git a/src/recipies/ExplicitConvert.cpp b/src/recipies/ExplicitCastMove.cpp similarity index 53% rename from src/recipies/ExplicitConvert.cpp rename to src/recipies/ExplicitCastMove.cpp index d25b133c3..f0cd20502 100644 --- a/src/recipies/ExplicitConvert.cpp +++ b/src/recipies/ExplicitCastMove.cpp @@ -11,9 +11,10 @@ #include "aidge/recipies/Recipies.hpp" #include "aidge/operator/OperatorTensor.hpp" -#include "aidge/operator/Convert.hpp" +#include "aidge/operator/Cast.hpp" +#include "aidge/operator/Move.hpp" -void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) { +void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { const auto nodes = graph->getNodes(); for (auto node : nodes) { // TODO: currently, Operator data type is only reflected in its output tensor data type. @@ -24,44 +25,66 @@ void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) { } const auto& device = output->getImpl()->device(); - if (node->type() == Convert_Op::Type) { - // Remove existing Convert operator, if not needed anymore + if (node->type() == Cast_Op::Type || node->type() == Move_Op::Type) { + // Remove existing Cast and Move operators, if not needed anymore const auto parent = node->inputs()[0]; const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); if (input->dataType() == output->dataType() && (input->getImpl()->device() == device)) { - // Add direct connection bypassing Convert node + // Add direct connection bypassing Cast/Move node const auto childs = node->outputs()[0]; for (const auto& child : childs) { parent.first->addChild(child.first, parent.second, child.second); } - // Remove all Convert node connections + // Remove all node connections node->resetConnections(); } } else { - // Insert Convert operator between node inputs and parent output, if needed + // Insert Cast and/or Move operator between node inputs and parent output, if needed IOIndex_t inputIdx = 0; for (auto parent : node->inputs()) { if (parent.first != nullptr) { const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); - - if (input->dataType() != output->dataType() - || (input->getImpl()->device() != device)) - { - // A conversion Operator is needed - auto convert = Convert(); - convert->addChild(node, 0, inputIdx); - parent.first->addChild(convert, parent.second, 0); + + NodePtr moveOp = nullptr; + NodePtr castOp = nullptr; + + if (input->getImpl()->device() != device) { + // Change of backend => a Move operator is required + moveOp = Move(); + moveOp->getOperator()->setDataType(input->dataType()); + castOp = moveOp; + } + + if (input->dataType() != output->dataType()) { + // Change of date type => a Cast operator is required + castOp = Cast(); + castOp->getOperator()->setDataType(output->dataType()); + castOp->getOperator()->setBackend(device.first, device.second); + + if (moveOp == nullptr) { + moveOp = castOp; + } + else { + moveOp->addChild(castOp, 0, 0); + } + } + + if (moveOp != nullptr && castOp != nullptr) { + // Move and/or Cast Operator(s) are needed + castOp->addChild(node, 0, inputIdx); + parent.first->addChild(moveOp, parent.second, 0); // Set backend AFTER connection in case a specific implementation // of the operator exists for the input type. - convert->getOperator()->setBackend(device.first, device.second); + moveOp->getOperator()->setBackend(device.first, device.second); // Add/update nodes in the GraphView - graph->add(convert); + graph->add(moveOp); + graph->add(castOp); graph->add(parent.first); graph->add(node); } -- GitLab