Commit 067796b1 authored by Olivier BICHLER

Split Convert operator into Cast and Move operators

parent b41f518d
@@ -659,6 +659,22 @@ class Tensor : public Data,
return flatIdx + coordIdx[i];
}
/**
* Copy-cast data from a Tensor on the same device.
* If the current tensor's backend/device is set and differs from src's, an
* assertion is raised.
* @param src Source tensor to copy-cast from.
*/
void copyCast(const Tensor& src);
/**
* Copy data from a Tensor that may be on another backend/device.
* If the current tensor's data type is set and differs from src's, an
* assertion is raised.
* @param src Source tensor to copy from.
*/
void copyFrom(const Tensor& src);
/**
* Copy-cast data from a Tensor.
* @param src Source tensor to copy-cast from.
......
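For context, here is a minimal usage sketch of the two new copy primitives. It is not part of the commit; it assumes a registered "cpu" backend and that Tensor exposes setDataType alongside the setBackend/resize calls used in the implementation further down, with illustrative shapes and data types.

    #include "aidge/data/Tensor.hpp"

    // Copy-cast between two tensors on the same device: values are converted.
    Aidge::Tensor src;
    src.setDataType(Aidge::DataType::Float32);
    src.setBackend("cpu");                        // assumed registered backend
    src.resize({2, 3});

    Aidge::Tensor dst;
    dst.setDataType(Aidge::DataType::Int32);      // different data type, same device
    dst.copyCast(src);                            // OK: per-element type conversion

    // Plain copy, possibly across devices: data types must match.
    Aidge::Tensor mirror;
    mirror.setDataType(Aidge::DataType::Float32);
    mirror.setBackend("cpu", 0);                  // could be another device of the same backend
    mirror.copyFrom(src);                         // OK: same data type, raw copy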
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_CAST_H_
#define AIDGE_CORE_OPERATOR_CAST_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Cast_Op : public OperatorTensor,
public Registrable<Cast_Op, std::string, std::unique_ptr<OperatorImpl>(const Cast_Op&)> {
public:
static constexpr const char* Type = "Cast";
Cast_Op() : OperatorTensor(Type, 1, 0, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no inputs associated).
* @param op Operator to copy.
*/
Cast_Op(const Cast_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Cast_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Cast_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Cast_Op>(*this);
}
void setBackend(const std::string& name, int device = 0) override {
mOutputs[0]->setBackend(name, device);
}
void forward() override;
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Cast(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Cast_Op>(), name);
}
}
#endif /* AIDGE_CORE_OPERATOR_CAST_H_ */
\ No newline at end of file
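As a usage sketch (not part of the commit), a Cast node can be wired by hand the same way the explicitCastMove recipe does further down in this diff; producer and consumer are placeholder nodes and Aidge::DataType::Float32 is an illustrative target type.

    #include "aidge/operator/Cast.hpp"

    auto cast = Aidge::Cast();                                   // factory defined above
    cast->getOperator()->setDataType(Aidge::DataType::Float32);  // data type to cast the input to
    producer->addChild(cast, 0, 0);                              // producer output 0 -> cast input 0
    cast->addChild(consumer, 0, 0);                              // cast output 0 -> consumer input 0
    cast->getOperator()->setBackend("cpu");                      // set AFTER connection, as in the recipe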
@@ -9,8 +9,8 @@
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_CONVERT_H_
#define AIDGE_CORE_OPERATOR_CONVERT_H_
#ifndef AIDGE_CORE_OPERATOR_MOVE_H_
#define AIDGE_CORE_OPERATOR_MOVE_H_
#include <cassert>
#include <memory>
@@ -25,34 +25,34 @@
namespace Aidge {
class Convert_Op : public OperatorTensor,
public Registrable<Convert_Op, std::tuple<std::string, std::string>, std::unique_ptr<OperatorImpl>(const Convert_Op&)> {
class Move_Op : public OperatorTensor,
public Registrable<Move_Op, std::tuple<std::string, std::string>, std::unique_ptr<OperatorImpl>(const Move_Op&)> {
public:
static constexpr const char* Type = "Convert";
static constexpr const char* Type = "Move";
Convert_Op() : OperatorTensor(Type, 1, 0, 1) {}
Move_Op() : OperatorTensor(Type, 1, 0, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no inputs associated).
* @param op Operator to copy.
*/
Convert_Op(const Convert_Op& op)
Move_Op(const Move_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), mOutputs[0]->getImpl()->backend()})(*this) : nullptr;
mImpl = op.mImpl ? Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), mOutputs[0]->getImpl()->backend()})(*this) : nullptr;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Convert_Op
* @see Operator::Move_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Convert_Op>(*this);
return std::make_shared<Move_Op>(*this);
}
void setBackend(const std::string& name, int device = 0) override {
if (mInputs[0]->getImpl() && Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
mImpl = Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this);
if (mInputs[0]->getImpl() && Registrar<Move_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
mImpl = Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this);
}
mOutputs[0]->setBackend(name, device);
}
@@ -65,18 +65,11 @@ public:
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
private:
/// @brief Store the input data to the output device, before type conversion.
/// Used only when there is both a change of device AND of data type.
/// Otherwise, data is either directly copied from the other device or
/// cast on the same device (requiring a single copy).
std::shared_ptr<Tensor> mMovedInput;
};
inline std::shared_ptr<Node> Convert(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Convert_Op>(), name);
inline std::shared_ptr<Node> Move(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Move_Op>(), name);
}
}
#endif /* AIDGE_CORE_OPERATOR_CONVERT_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_MOVE_H_ */
\ No newline at end of file
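A Move node, by contrast, only transfers data between backends/devices and leaves the data type untouched. Its implementation is looked up with a (source backend, destination backend) key, and forward() falls back to Tensor::copyFrom() when no such pair is registered. A sketch with placeholder node and backend names:

    #include "aidge/operator/Move.hpp"

    auto move = Aidge::Move();
    gpuProducer->addChild(move, 0, 0);        // placeholder node running on a "gpu" backend
    move->addChild(cpuConsumer, 0, 0);        // placeholder node running on the "cpu" backend
    move->getOperator()->setBackend("cpu");   // uses a {"gpu", "cpu"} implementation if registered,
                                              // otherwise forward() falls back to Tensor::copyFrom()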
@@ -94,7 +94,7 @@ std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<No
* Add Cast and Move operators where needed to ensure no conversion needs to be done
* at the Operator level.
*/
void explicitConvert(std::shared_ptr<GraphView> graphView);
void explicitCastMove(std::shared_ptr<GraphView> graphView);
} // namespace Aidge
......
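A typical call site for the renamed recipe (sketch only; GraphView::setDataType/setBackend are assumed helpers used here for illustration):

    #include <memory>
    #include "aidge/recipies/Recipies.hpp"

    auto graph = std::make_shared<Aidge::GraphView>();   // assume nodes were added / a model was imported
    graph->setDataType(Aidge::DataType::Float32);        // assumed helpers, illustrative only
    graph->setBackend("cpu");
    Aidge::explicitCastMove(graph);                      // inserts Cast/Move nodes wherever data types or devices differ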
@@ -13,6 +13,40 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Tensor::copyCast(const Tensor& src) {
if (&src == this) {
return;
}
// Current Tensor has necessarily a data type, but may not have backend
if (!getImpl()) {
// If no backend was set for the current tensor, use the same as src
const auto deviceSrc = src.getImpl()->device();
setBackend(deviceSrc.first, deviceSrc.second);
}
resize(src.dims());
AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(), "cannot copy-cast from a different backend/device");
getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType());
}
void Aidge::Tensor::copyFrom(const Tensor& src) {
if (&src == this) {
return;
}
// Current Tensor has necessarily a data type, but may not have backend
if (!getImpl()) {
// If no backend was set for the current tensor, use the same as src
const auto deviceSrc = src.getImpl()->device();
setBackend(deviceSrc.first, deviceSrc.second);
}
resize(src.dims());
AIDGE_ASSERT(src.dataType() == dataType(), "cannot copy from a different data type");
getImpl()->copyFrom(*(src.getImpl()), src.size());
}
void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) {
if (&src == this) {
return;
......
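Note that both implementations adopt src's backend/device when the destination tensor has no backend yet, and resize the destination to src.dims() before copying. A sketch (srcFloat is a placeholder Float32 tensor with a backend already set):

    Aidge::Tensor dst;
    dst.setDataType(Aidge::DataType::Int32);   // data type set, but no backend yet
    dst.copyCast(srcFloat);                    // dst adopts srcFloat's backend/device and dims,
                                               // then converts the values to Int32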
@@ -10,14 +10,14 @@
********************************************************************************/
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Convert.hpp"
#include "aidge/operator/Cast.hpp"
void Aidge::Convert_Op::forward() {
void Aidge::Cast_Op::forward() {
if (mImpl) {
mImpl->forward();
}
else {
mOutputs[0]->copyCastFrom(*(mInputs[0]), mMovedInput);
mOutputs[0]->copyCast(*(mInputs[0]));
}
runHooks();
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Move.hpp"
void Aidge::Move_Op::forward() {
if (mImpl) {
mImpl->forward();
}
else {
mOutputs[0]->copyFrom(*(mInputs[0]));
}
runHooks();
}
@@ -11,9 +11,10 @@
#include "aidge/recipies/Recipies.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Convert.hpp"
#include "aidge/operator/Cast.hpp"
#include "aidge/operator/Move.hpp"
void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) {
void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
const auto nodes = graph->getNodes();
for (auto node : nodes) {
// TODO: currently, Operator data type is only reflected in its output tensor data type.
@@ -24,44 +25,66 @@ void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) {
}
const auto& device = output->getImpl()->device();
if (node->type() == Convert_Op::Type) {
// Remove existing Convert operator, if not needed anymore
if (node->type() == Cast_Op::Type || node->type() == Move_Op::Type) {
// Remove existing Cast and Move operators, if not needed anymore
const auto parent = node->inputs()[0];
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if (input->dataType() == output->dataType()
&& (input->getImpl()->device() == device))
{
// Add direct connection bypassing Convert node
// Add direct connection bypassing Cast/Move node
const auto childs = node->outputs()[0];
for (const auto& child : childs) {
parent.first->addChild(child.first, parent.second, child.second);
}
// Remove all Convert node connections
// Remove all node connections
node->resetConnections();
}
}
else {
// Insert Convert operator between node inputs and parent output, if needed
// Insert Cast and/or Move operator between node inputs and parent output, if needed
IOIndex_t inputIdx = 0;
for (auto parent : node->inputs()) {
if (parent.first != nullptr) {
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if (input->dataType() != output->dataType()
|| (input->getImpl()->device() != device))
{
// A conversion Operator is needed
auto convert = Convert();
convert->addChild(node, 0, inputIdx);
parent.first->addChild(convert, parent.second, 0);
NodePtr moveOp = nullptr;
NodePtr castOp = nullptr;
if (input->getImpl()->device() != device) {
// Change of backend => a Move operator is required
moveOp = Move();
moveOp->getOperator()->setDataType(input->dataType());
castOp = moveOp;
}
if (input->dataType() != output->dataType()) {
// Change of data type => a Cast operator is required
castOp = Cast();
castOp->getOperator()->setDataType(output->dataType());
castOp->getOperator()->setBackend(device.first, device.second);
if (moveOp == nullptr) {
moveOp = castOp;
}
else {
moveOp->addChild(castOp, 0, 0);
}
}
if (moveOp != nullptr && castOp != nullptr) {
// Move and/or Cast Operator(s) are needed
castOp->addChild(node, 0, inputIdx);
parent.first->addChild(moveOp, parent.second, 0);
// Set backend AFTER connection in case a specific implementation
// of the operator exists for the input type.
convert->getOperator()->setBackend(device.first, device.second);
moveOp->getOperator()->setBackend(device.first, device.second);
// Add/update nodes in the GraphView
graph->add(convert);
graph->add(moveOp);
graph->add(castOp);
graph->add(parent.first);
graph->add(node);
}
......
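In summary, for an input whose producer differs from the consumer in both backend and data type, the loop above builds the chain

    parent --> Move --> Cast --> node

where Move keeps the producer's data type but targets the consumer's backend, and Cast then converts to the consumer's data type on that backend. When only the backend differs, a single Move is inserted; when only the data type differs, a single Cast.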