diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 3d40ca644403ff64e0285fbaafc6ea8c100eb704..2e35238cd14e98cb86633c2d8fad29bec97c274a 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -659,6 +659,22 @@ class Tensor : public Data, return flatIdx + coordIdx[i]; } + /** + * Copy-cast data from a Tensor on the same device. + * If current tensor backend/device is set and is different from src, an + * assertion is raised. + * @param src Source tensor to copy-cast from. + */ + void copyCast(const Tensor& src); + + /** + * Copy data from a Tensor from another backend/device. + * If current tensor data type is set and is different from src, an + * assertion is raised. + * @param src Source tensor to copy from. + */ + void copyFrom(const Tensor& src); + /** * Copy-cast data from a Tensor. * @param src Source tensor to copy-cast from. diff --git a/include/aidge/operator/Cast.hpp b/include/aidge/operator/Cast.hpp new file mode 100644 index 0000000000000000000000000000000000000000..38fb0b1ffcbb243658cb8d57af04e97a7179f5f2 --- /dev/null +++ b/include/aidge/operator/Cast.hpp @@ -0,0 +1,72 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_CAST_H_ +#define AIDGE_CORE_OPERATOR_CAST_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Cast_Op : public OperatorTensor, + public Registrable<Cast_Op, std::string, std::unique_ptr<OperatorImpl>(const Cast_Op&)> { +public: + static constexpr const char* Type = "Cast"; + + Cast_Op() : OperatorTensor(Type, 1, 0, 1) {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Cast_Op(const Cast_Op& op) + : OperatorTensor(op) + { + mImpl = op.mImpl ? Registrar<Cast_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Cast_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Cast_Op>(*this); + } + + void setBackend(const std::string& name, int device = 0) override { + mOutputs[0]->setBackend(name, device); + } + + void forward() override; + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Cast(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Cast_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_CAST_H_ */ \ No newline at end of file diff --git a/include/aidge/operator/Convert.hpp b/include/aidge/operator/Move.hpp similarity index 52% rename from include/aidge/operator/Convert.hpp rename to include/aidge/operator/Move.hpp index aa35d6812d95ce67af0a639e195846f4baf0b54a..be7f8922ecbea3356e55216f2381da360e4e8f39 100644 --- a/include/aidge/operator/Convert.hpp +++ b/include/aidge/operator/Move.hpp @@ -9,8 +9,8 @@ * ********************************************************************************/ -#ifndef AIDGE_CORE_OPERATOR_CONVERT_H_ -#define AIDGE_CORE_OPERATOR_CONVERT_H_ +#ifndef AIDGE_CORE_OPERATOR_MOVE_H_ +#define AIDGE_CORE_OPERATOR_MOVE_H_ #include <cassert> #include <memory> @@ -25,34 +25,34 @@ namespace Aidge { -class Convert_Op : public OperatorTensor, - public Registrable<Convert_Op, std::tuple<std::string, std::string>, std::unique_ptr<OperatorImpl>(const Convert_Op&)> { +class Move_Op : public OperatorTensor, + public Registrable<Move_Op, std::tuple<std::string, std::string>, std::unique_ptr<OperatorImpl>(const Move_Op&)> { public: - static constexpr const char* Type = "Convert"; + static constexpr const char* Type = "Move"; - Convert_Op() : OperatorTensor(Type, 1, 0, 1) {} + Move_Op() : OperatorTensor(Type, 1, 0, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - Convert_Op(const Convert_Op& op) + Move_Op(const Move_Op& op) : OperatorTensor(op) { - mImpl = op.mImpl ? Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), mOutputs[0]->getImpl()->backend()})(*this) : nullptr; + mImpl = op.mImpl ? Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), mOutputs[0]->getImpl()->backend()})(*this) : nullptr; } /** * @brief Clone the operator using its copy-constructor. - * @see Operator::Convert_Op + * @see Operator::Move_Op */ std::shared_ptr<Operator> clone() const override { - return std::make_shared<Convert_Op>(*this); + return std::make_shared<Move_Op>(*this); } void setBackend(const std::string& name, int device = 0) override { - if (mInputs[0]->getImpl() && Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) { - mImpl = Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this); + if (mInputs[0]->getImpl() && Registrar<Move_Op>::exists({mInputs[0]->getImpl()->backend(), name})) { + mImpl = Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this); } mOutputs[0]->setBackend(name, device); } @@ -65,18 +65,11 @@ public: static const std::vector<std::string> getOutputsName(){ return {"data_output"}; } - -private: - /// @brief Store the input data to the output device, before type conversion. - /// Used only when there is both a change of device AND of data type. - /// Otherwise, data is either directly copied from the other device or - /// casted on the same device (requiring a single copy). - std::shared_ptr<Tensor> mMovedInput; }; -inline std::shared_ptr<Node> Convert(const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Convert_Op>(), name); +inline std::shared_ptr<Node> Move(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Move_Op>(), name); } } -#endif /* AIDGE_CORE_OPERATOR_CONVERT_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_MOVE_H_ */ \ No newline at end of file diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp index f1d9a93273ef6ffef3f9eed98a22cebd33599353..a433f1d0c053a888212406e5684fa1b2a48f28fe 100644 --- a/include/aidge/recipies/Recipies.hpp +++ b/include/aidge/recipies/Recipies.hpp @@ -94,7 +94,7 @@ std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<No * Add Convert operators where needed to ensure no conversion needs to be done * at the Operator level. */ -void explicitConvert(std::shared_ptr<GraphView> graphView); +void explicitCastMove(std::shared_ptr<GraphView> graphView); } // namespace Aidge diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 15e6782a08a9d080688d95be9f16f0d66aec3fbe..8a950ba2febcfb5b0f22420af0feae40204c5646 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -13,6 +13,40 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +void Aidge::Tensor::copyCast(const Tensor& src) { + if (&src == this) { + return; + } + + // Current Tensor has necessarily a data type, but may not have backend + if (!getImpl()) { + // If no backend was set for the current tensor, use the same as src + const auto deviceSrc = src.getImpl()->device(); + setBackend(deviceSrc.first, deviceSrc.second); + } + resize(src.dims()); + + AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(), "cannot copy-cast from a different backend/device"); + getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType()); +} + +void Aidge::Tensor::copyFrom(const Tensor& src) { + if (&src == this) { + return; + } + + // Current Tensor has necessarily a data type, but may not have backend + if (!getImpl()) { + // If no backend was set for the current tensor, use the same as src + const auto deviceSrc = src.getImpl()->device(); + setBackend(deviceSrc.first, deviceSrc.second); + } + resize(src.dims()); + + AIDGE_ASSERT(src.dataType() == dataType(), "cannot copy from a different data type"); + getImpl()->copyFrom(*(src.getImpl()), src.size()); +} + void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) { if (&src == this) { return; diff --git a/src/operator/Convert.cpp b/src/operator/Cast.cpp similarity index 80% rename from src/operator/Convert.cpp rename to src/operator/Cast.cpp index fcdcf4013f14f0673c0e9a53dbdb6d30543540f7..0ac6d5f5304b62a832f84eaadcb7b6dbfe0a34e0 100644 --- a/src/operator/Convert.cpp +++ b/src/operator/Cast.cpp @@ -10,14 +10,14 @@ ********************************************************************************/ #include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/Convert.hpp" +#include "aidge/operator/Cast.hpp" -void Aidge::Convert_Op::forward() { +void Aidge::Cast_Op::forward() { if (mImpl) { mImpl->forward(); } else { - mOutputs[0]->copyCastFrom(*(mInputs[0]), mMovedInput); + mOutputs[0]->copyCast(*(mInputs[0])); } runHooks(); diff --git a/src/operator/Move.cpp b/src/operator/Move.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d828e994d61bb7307a55c120685ac4a2808fc086 --- /dev/null +++ b/src/operator/Move.cpp @@ -0,0 +1,24 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Move.hpp" + +void Aidge::Move_Op::forward() { + if (mImpl) { + mImpl->forward(); + } + else { + mOutputs[0]->copyFrom(*(mInputs[0])); + } + + runHooks(); +} diff --git a/src/recipies/ExplicitConvert.cpp b/src/recipies/ExplicitCastMove.cpp similarity index 53% rename from src/recipies/ExplicitConvert.cpp rename to src/recipies/ExplicitCastMove.cpp index d25b133c3374029fe763dd9a4bbc98df58c45a0d..f0cd205020beacd998007da55fa5a9474707a7f6 100644 --- a/src/recipies/ExplicitConvert.cpp +++ b/src/recipies/ExplicitCastMove.cpp @@ -11,9 +11,10 @@ #include "aidge/recipies/Recipies.hpp" #include "aidge/operator/OperatorTensor.hpp" -#include "aidge/operator/Convert.hpp" +#include "aidge/operator/Cast.hpp" +#include "aidge/operator/Move.hpp" -void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) { +void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { const auto nodes = graph->getNodes(); for (auto node : nodes) { // TODO: currently, Operator data type is only reflected in its output tensor data type. @@ -24,44 +25,66 @@ void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) { } const auto& device = output->getImpl()->device(); - if (node->type() == Convert_Op::Type) { - // Remove existing Convert operator, if not needed anymore + if (node->type() == Cast_Op::Type || node->type() == Move_Op::Type) { + // Remove existing Cast and Move operators, if not needed anymore const auto parent = node->inputs()[0]; const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); if (input->dataType() == output->dataType() && (input->getImpl()->device() == device)) { - // Add direct connection bypassing Convert node + // Add direct connection bypassing Cast/Move node const auto childs = node->outputs()[0]; for (const auto& child : childs) { parent.first->addChild(child.first, parent.second, child.second); } - // Remove all Convert node connections + // Remove all node connections node->resetConnections(); } } else { - // Insert Convert operator between node inputs and parent output, if needed + // Insert Cast and/or Move operator between node inputs and parent output, if needed IOIndex_t inputIdx = 0; for (auto parent : node->inputs()) { if (parent.first != nullptr) { const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); - - if (input->dataType() != output->dataType() - || (input->getImpl()->device() != device)) - { - // A conversion Operator is needed - auto convert = Convert(); - convert->addChild(node, 0, inputIdx); - parent.first->addChild(convert, parent.second, 0); + + NodePtr moveOp = nullptr; + NodePtr castOp = nullptr; + + if (input->getImpl()->device() != device) { + // Change of backend => a Move operator is required + moveOp = Move(); + moveOp->getOperator()->setDataType(input->dataType()); + castOp = moveOp; + } + + if (input->dataType() != output->dataType()) { + // Change of date type => a Cast operator is required + castOp = Cast(); + castOp->getOperator()->setDataType(output->dataType()); + castOp->getOperator()->setBackend(device.first, device.second); + + if (moveOp == nullptr) { + moveOp = castOp; + } + else { + moveOp->addChild(castOp, 0, 0); + } + } + + if (moveOp != nullptr && castOp != nullptr) { + // Move and/or Cast Operator(s) are needed + castOp->addChild(node, 0, inputIdx); + parent.first->addChild(moveOp, parent.second, 0); // Set backend AFTER connection in case a specific implementation // of the operator exists for the input type. - convert->getOperator()->setBackend(device.first, device.second); + moveOp->getOperator()->setBackend(device.first, device.second); // Add/update nodes in the GraphView - graph->add(convert); + graph->add(moveOp); + graph->add(castOp); graph->add(parent.first); graph->add(node); }