Skip to content
Snippets Groups Projects
Commit e0774723 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added ExplicitConvert recipe

parent a7d650ed
No related branches found
No related tags found
1 merge request!57Add Convert operator (a.k.a. Transmitter)
Pipeline #35482 failed
......@@ -174,6 +174,18 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel
void setBackend(const std::string &name, int device = 0) override {
mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// By default, automatically set backend for weight and bias inputs
getInput(1)->setBackend(name, device);
getInput(2)->setBackend(name, device);
}
void setDataType(const DataType& dt) const override {
mOutputs[0]->setDataType(dt);
// By default, automatically set data type for weight and bias inputs
getInput(1)->setDataType(dt);
getInput(2)->setDataType(dt);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -51,7 +51,7 @@ public:
}
void setBackend(const std::string& name, int device = 0) override {
if (Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
if (mInputs[0]->getImpl() && Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
mImpl = Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this);
}
mOutputs[0]->setBackend(name, device);
......
......@@ -89,6 +89,14 @@ std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<No
// std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
// void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
/**
* Add Convert operators where needed to ensure no conversion needs to be done
* at the Operator level.
*/
void explicitConvert(std::shared_ptr<GraphView> graphView);
} // namespace Aidge
#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
......@@ -149,6 +149,7 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
getOutput(i)->setDataType(dataType);
}
/*
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set");
......@@ -157,4 +158,5 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
getInput(i)->setDataType(dataType);
}
}
*/
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/recipies/Recipies.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Convert.hpp"
void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) {
const auto nodes = graph->getNodes();
for (auto node : nodes) {
// TODO: currently, Operator data type is only reflected in its output tensor data type.
// But an Operator might have multiple outputs of different data type(?)
const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
if (output->getImpl() == nullptr) {
continue;
}
const auto& device = output->getImpl()->device();
if (node->type() == Convert_Op::Type) {
// Remove existing Convert operator, if not needed anymore
const auto& parent = node->inputs()[0];
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if (input->dataType() == output->dataType()
&& (input->getImpl()->device() == device))
{
// Add direct connection bypassing Convert node
for (auto child : node->outputs()[0]) {
parent.first->addChild(child.first, parent.second, child.second);
}
// Remove all Convert node connections
node->resetConnections();
}
}
else {
// Insert Convert operator between node inputs and parent output, if needed
IOIndex_t inputIdx = 0;
for (auto parent : node->inputs()) {
if (parent.first != nullptr) {
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if (input->dataType() != output->dataType()
|| (input->getImpl()->device() != device))
{
// A conversion Operator is needed
auto convert = Convert();
convert->addChild(node, 0, inputIdx);
parent.first->addChild(convert, parent.second, 0);
// Set backend AFTER connection in case a specific implementation
// of the operator exists for the input type.
convert->getOperator()->setBackend(device.first, device.second);
// Add/update nodes in the GraphView
graph->add(convert);
graph->add(parent.first);
graph->add(node);
}
}
++inputIdx;
}
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment