Skip to content
Snippets Groups Projects
Commit b282d8a0 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Robustified explicitCastMove following split in two operators

parent 067796b1
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include "aidge/operator/Move.hpp" #include "aidge/operator/Move.hpp"
void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
const auto nodes = graph->getNodes(); // First, remove existing Cast and Move operators, if not needed anymore
auto nodes = graph->getNodes();
for (auto node : nodes) { for (auto node : nodes) {
// TODO: currently, Operator data type is only reflected in its output tensor data type. // TODO: currently, Operator data type is only reflected in its output tensor data type.
// But an Operator might have multiple outputs of different data type(?) // But an Operator might have multiple outputs of different data type(?)
...@@ -27,71 +28,96 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { ...@@ -27,71 +28,96 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
if (node->type() == Cast_Op::Type || node->type() == Move_Op::Type) { if (node->type() == Cast_Op::Type || node->type() == Move_Op::Type) {
// Remove existing Cast and Move operators, if not needed anymore // Remove existing Cast and Move operators, if not needed anymore
AIDGE_INTERNAL_ASSERT(node->inputs().size() == 1);
const auto parent = node->inputs()[0]; const auto parent = node->inputs()[0];
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); // Check parent is not nullptr, as this Operator may be an entry point of the graph without parent
if (parent.first != nullptr) {
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if (input->dataType() == output->dataType() if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType())
&& (input->getImpl()->device() == device)) || (node->type() == Move_Op::Type && input->getImpl() != nullptr && input->getImpl()->device() == device))
{ {
// Add direct connection bypassing Cast/Move node // Add direct connection bypassing Cast/Move node
const auto childs = node->outputs()[0]; const auto childs = node->outputs()[0];
for (const auto& child : childs) { for (const auto& child : childs) {
parent.first->addChild(child.first, parent.second, child.second); parent.first->addChild(child.first, parent.second, child.second);
} }
// Remove all node connections // Remove all node connections
node->resetConnections(); node->resetConnections();
// Remove node from view
graph->remove(node);
}
} }
} }
else { }
// Insert Cast and/or Move operator between node inputs and parent output, if needed
IOIndex_t inputIdx = 0;
for (auto parent : node->inputs()) {
if (parent.first != nullptr) {
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
NodePtr moveOp = nullptr; // Note: why two steps and not merge the two node loops?
NodePtr castOp = nullptr; // User may have changed some data type/backends on top of existing Cast/Move operators
// This may lead to situation where a Cast should be removed but a Move should
// be inserted at the same place. In this case, some conversion may be missed
// depending on the order of iteration over the nodes (which are non ordered!).
if (input->getImpl()->device() != device) { // Second, insert Cast and/or Move operator between node inputs and parent output, if needed
// Change of backend => a Move operator is required nodes = graph->getNodes();
moveOp = Move(); for (auto node : nodes) {
moveOp->getOperator()->setDataType(input->dataType()); // TODO: currently, Operator data type is only reflected in its output tensor data type.
castOp = moveOp; // But an Operator might have multiple outputs of different data type(?)
} const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
if (output->getImpl() == nullptr) {
continue;
}
const auto& device = output->getImpl()->device();
if (input->dataType() != output->dataType()) { IOIndex_t inputIdx = 0;
// Change of date type => a Cast operator is required for (auto parent : node->inputs()) {
castOp = Cast(); // TODO: possible optimization: currently, a Cast/Move Operator may
castOp->getOperator()->setDataType(output->dataType()); // be added several time to the same output, if it has multiple childs,
castOp->getOperator()->setBackend(device.first, device.second); // even if it is the same conversion each time.
if (parent.first != nullptr) {
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if (moveOp == nullptr) { NodePtr moveOp = nullptr;
moveOp = castOp; NodePtr castOp = nullptr;
}
else { if (node->type() != Move_Op::Type && input->getImpl()->device() != device) {
moveOp->addChild(castOp, 0, 0); // Change of backend => a Move operator is required
} moveOp = Move();
} moveOp->getOperator()->setDataType(input->dataType());
castOp = moveOp;
}
if (moveOp != nullptr && castOp != nullptr) { if (node->type() != Cast_Op::Type && input->dataType() != output->dataType()) {
// Move and/or Cast Operator(s) are needed // Change of date type => a Cast operator is required
castOp->addChild(node, 0, inputIdx); castOp = Cast();
parent.first->addChild(moveOp, parent.second, 0); castOp->getOperator()->setDataType(output->dataType());
// Set backend AFTER connection in case a specific implementation castOp->getOperator()->setBackend(device.first, device.second);
// of the operator exists for the input type.
moveOp->getOperator()->setBackend(device.first, device.second);
// Add/update nodes in the GraphView if (moveOp == nullptr) {
graph->add(moveOp); moveOp = castOp;
graph->add(castOp); }
graph->add(parent.first); else {
graph->add(node); moveOp->addChild(castOp, 0, 0);
} }
} }
++inputIdx; if (moveOp != nullptr && castOp != nullptr) {
// Move and/or Cast Operator(s) are needed
castOp->addChild(node, 0, inputIdx);
parent.first->addChild(moveOp, parent.second, 0);
// Set backend AFTER connection in case a specific implementation
// of the operator exists for the input type.
moveOp->getOperator()->setBackend(device.first, device.second);
// Add/update nodes in the GraphView
graph->add(moveOp);
graph->add(castOp);
graph->add(parent.first);
graph->add(node);
}
} }
++inputIdx;
} }
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment