Skip to content
Snippets Groups Projects
Commit e0774723 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added ExplicitConvert recipe

parent a7d650ed
No related branches found
No related tags found
No related merge requests found
......@@ -174,6 +174,18 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel
void setBackend(const std::string &name, int device = 0) override {
mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// By default, automatically set backend for weight and bias inputs
getInput(1)->setBackend(name, device);
getInput(2)->setBackend(name, device);
}
void setDataType(const DataType& dt) const override {
mOutputs[0]->setDataType(dt);
// By default, automatically set data type for weight and bias inputs
getInput(1)->setDataType(dt);
getInput(2)->setDataType(dt);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -51,7 +51,7 @@ public:
}
void setBackend(const std::string& name, int device = 0) override {
if (Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
if (mInputs[0]->getImpl() && Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
mImpl = Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this);
}
mOutputs[0]->setBackend(name, device);
......
......@@ -89,6 +89,14 @@ std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<No
// std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
// void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
/**
* Add Convert operators where needed to ensure no conversion needs to be done
* at the Operator level.
*/
void explicitConvert(std::shared_ptr<GraphView> graphView);
} // namespace Aidge
#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
......@@ -149,6 +149,7 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
getOutput(i)->setDataType(dataType);
}
/*
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set");
......@@ -157,4 +158,5 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
getInput(i)->setDataType(dataType);
}
}
*/
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/recipies/Recipies.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Convert.hpp"
void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) {
const auto nodes = graph->getNodes();
for (auto node : nodes) {
// TODO: currently, Operator data type is only reflected in its output tensor data type.
// But an Operator might have multiple outputs of different data type(?)
const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
if (output->getImpl() == nullptr) {
continue;
}
const auto& device = output->getImpl()->device();
if (node->type() == Convert_Op::Type) {
// Remove existing Convert operator, if not needed anymore
const auto& parent = node->inputs()[0];
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if (input->dataType() == output->dataType()
&& (input->getImpl()->device() == device))
{
// Add direct connection bypassing Convert node
for (auto child : node->outputs()[0]) {
parent.first->addChild(child.first, parent.second, child.second);
}
// Remove all Convert node connections
node->resetConnections();
}
}
else {
// Insert Convert operator between node inputs and parent output, if needed
IOIndex_t inputIdx = 0;
for (auto parent : node->inputs()) {
if (parent.first != nullptr) {
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if (input->dataType() != output->dataType()
|| (input->getImpl()->device() != device))
{
// A conversion Operator is needed
auto convert = Convert();
convert->addChild(node, 0, inputIdx);
parent.first->addChild(convert, parent.second, 0);
// Set backend AFTER connection in case a specific implementation
// of the operator exists for the input type.
convert->getOperator()->setBackend(device.first, device.second);
// Add/update nodes in the GraphView
graph->add(convert);
graph->add(parent.first);
graph->add(node);
}
}
++inputIdx;
}
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment