Skip to content
Snippets Groups Projects
Commit b282d8a0 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Robustified explicitCastMove following its split into two operators (Cast and Move)

parent 067796b1
No related branches found
No related tags found
No related merge requests found
#include "aidge/operator/Move.hpp"

/// @brief Make data-type casts and backend transfers explicit in a GraphView.
///
/// Two passes over the graph:
///  1. Remove existing Cast/Move nodes that are no longer needed (same data
///     type on both sides of a Cast, or same device on both sides of a Move).
///  2. Insert Cast and/or Move nodes between each node and its parents when
///     their output tensors differ in data type and/or device.
///
/// @param graph the GraphView to transform in place.
void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
    // First, remove existing Cast and Move operators, if not needed anymore
    auto nodes = graph->getNodes();
    for (auto node : nodes) {
        // TODO: currently, Operator data type is only reflected in its output tensor data type.
        // But an Operator might have multiple outputs of different data type(?)
        const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
        if (output->getImpl() == nullptr) {
            // No implementation yet => no device to compare against; skip.
            continue;
        }
        const auto& device = output->getImpl()->device();

        if (node->type() == Cast_Op::Type || node->type() == Move_Op::Type) {
            // Remove existing Cast and Move operators, if not needed anymore
            AIDGE_INTERNAL_ASSERT(node->inputs().size() == 1);
            const auto parent = node->inputs()[0];
            // Check parent is not nullptr, as this Operator may be an entry point of the graph without parent
            if (parent.first != nullptr) {
                const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);

                // A Cast is redundant when input/output data types already match;
                // a Move is redundant when input/output devices already match
                // (guard input->getImpl() against nullptr: the parent may not
                // have an implementation yet).
                if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType())
                    || (node->type() == Move_Op::Type && input->getImpl() != nullptr && input->getImpl()->device() == device))
                {
                    // Add direct connection bypassing Cast/Move node
                    const auto childs = node->outputs()[0];
                    for (const auto& child : childs) {
                        parent.first->addChild(child.first, parent.second, child.second);
                    }

                    // Remove all node connections
                    node->resetConnections();
                    // Remove node from view
                    graph->remove(node);
                }
            }
        }
    }

    // Note: why two steps and not merge the two node loops?
    // User may have changed some data type/backends on top of existing Cast/Move operators
    // This may lead to situation where a Cast should be removed but a Move should
    // be inserted at the same place. In this case, some conversion may be missed
    // depending on the order of iteration over the nodes (which are non ordered!).

    // Second, insert Cast and/or Move operator between node inputs and parent output, if needed
    nodes = graph->getNodes();
    for (auto node : nodes) {
        // TODO: currently, Operator data type is only reflected in its output tensor data type.
        // But an Operator might have multiple outputs of different data type(?)
        const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
        if (output->getImpl() == nullptr) {
            continue;
        }
        const auto& device = output->getImpl()->device();

        IOIndex_t inputIdx = 0;
        for (auto parent : node->inputs()) {
            // TODO: possible optimization: currently, a Cast/Move Operator may
            // be added several time to the same output, if it has multiple childs,
            // even if it is the same conversion each time.
            if (parent.first != nullptr) {
                const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);

                NodePtr moveOp = nullptr;
                NodePtr castOp = nullptr;

                // Never insert a Move in front of a Move operator (it performs
                // the device transfer itself).
                if (node->type() != Move_Op::Type && input->getImpl()->device() != device) {
                    // Change of backend => a Move operator is required
                    moveOp = Move();
                    moveOp->getOperator()->setDataType(input->dataType());
                    castOp = moveOp;
                }

                // Never insert a Cast in front of a Cast operator (it performs
                // the type conversion itself).
                if (node->type() != Cast_Op::Type && input->dataType() != output->dataType()) {
                    // Change of date type => a Cast operator is required
                    castOp = Cast();
                    castOp->getOperator()->setDataType(output->dataType());
                    castOp->getOperator()->setBackend(device.first, device.second);

                    if (moveOp == nullptr) {
                        moveOp = castOp;
                    }
                    else {
                        // Chain: parent -> Move -> Cast -> node
                        moveOp->addChild(castOp, 0, 0);
                    }
                }

                if (moveOp != nullptr && castOp != nullptr) {
                    // Move and/or Cast Operator(s) are needed
                    castOp->addChild(node, 0, inputIdx);
                    parent.first->addChild(moveOp, parent.second, 0);
                    // Set backend AFTER connection in case a specific implementation
                    // of the operator exists for the input type.
                    moveOp->getOperator()->setBackend(device.first, device.second);

                    // Add/update nodes in the GraphView
                    graph->add(moveOp);
                    graph->add(castOp);
                    graph->add(parent.first);
                    graph->add(node);
                }
            }

            ++inputIdx;
        }
    }
}
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment