Commit 2bba4f96 authored by Olivier BICHLER

Multiple fixes related to adaptToBackend()

parent dbddbbe7
2 merge requests: !318 [Upd] release version 0.5.0, !290 [Add] support for auto-concatenation and Fix multiple adaptToBackend() issues
@@ -124,6 +124,12 @@ void explicitCastMove(std::shared_ptr<GraphView> graphView);
  */
 void explicitTranspose(std::shared_ptr<GraphView> graphView);
+
+/**
+ * Replace a single meta operator by its micro graph.
+ * @return true if node is indeed a meta operator and could be expanded.
+ */
+bool expandMetaOp(std::shared_ptr<Node> node);
 /**
  * Flatten the graph by replacing the meta operators by their micro graph.
  * @param recursive If true, recursively replace meta operators until there is
......
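The new expandMetaOp(node) declared above complements the existing graph-wide expandMetaOps(). A minimal usage sketch, assuming the usual Aidge headers and a node already attached to a GraphView:

    #include "aidge/recipes/Recipes.hpp"

    // Expand a single node in place if it is a meta operator.
    if (Aidge::expandMetaOp(node)) {
        // node was a meta operator and has been replaced by its micro graph
    } else {
        // node was already atomic: nothing changed
    }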
@@ -81,6 +81,7 @@ void init_OperatorImpl(py::module& m){
         .def(py::init<const DynamicAttributes&>(), py::arg("attr") = DynamicAttributes())
         .def(py::init<const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("io"), py::arg("attr") = DynamicAttributes())
         .def(py::init<const ImplSpec::IOSpec&, const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes())
+        .def(py::init<const std::vector<ImplSpec::IOSpec>&, const std::vector<ImplSpec::IOSpec>&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes())
         .def("__eq__", static_cast<bool(*)(const ImplSpec&, const ImplSpec&)>(&operator==))
         .def("__repr__", [](ImplSpec self){
             return fmt::format("{}\n", self);
......
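This binding exposes the vector-based ImplSpec constructor (multiple inputs and outputs) to Python. For reference, a C++ sketch of the same constructor; the IOSpec initialization shown here is an assumption for illustration, not taken from this diff:

    using namespace Aidge;

    // A spec with two inputs and one output (IOSpec construction assumed).
    ImplSpec spec(
        std::vector<ImplSpec::IOSpec>{ImplSpec::IOSpec(DataType::Float32),
                                      ImplSpec::IOSpec(DataType::Int32)},
        std::vector<ImplSpec::IOSpec>{ImplSpec::IOSpec(DataType::Float32)});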
@@ -250,9 +250,10 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             && requiredIOSpec.type != IOSpec.type)
         {
             const auto cast = Cast(IOSpec.type);
+            cast->getOperator()->setBackend(node->getOperator()->backend());
             cast->addChild(parent, 0, i);
-            op->getInput(i)->setDataType(IOSpec.type);
+            op->getInput(i)->setDataType(requiredIOSpec.type);
         }
         // Input format
@@ -263,10 +264,11 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format);
             auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
             transposeOp->getOperator()->setDataFormat(IOSpec.format);
-            transposeOp->getOperator()->setDataType(IOSpec.type);
+            transposeOp->getOperator()->setDataType(requiredIOSpec.type);
+            transposeOp->getOperator()->setBackend(node->getOperator()->backend());
             transposeOp->addChild(parent, 0, i);
-            op->getInput(i)->setDataFormat(IOSpec.format);
+            op->getInput(i)->setDataFormat(requiredIOSpec.format);
         }
         // Input dims
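Both input-side fixes address the same issue: the inserted Cast/Transpose nodes had no backend set, and the boundary tensor was labeled with the available spec instead of the required one. The corrected data path, for a hypothetical type and format mismatch:

    // Input-side adaptation (illustrative): a Cast to the available type
    // and/or a Transpose to the available format is inserted upstream of
    // the node, each now placed on the node's own backend.
    //
    //   graph tensor (REQUIRED spec) --> Cast / Transpose --> node (AVAILABLE spec)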
@@ -301,6 +303,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             && requiredIOSpec.type != IOSpec.type)
         {
             const auto cast = Cast(requiredIOSpec.type);
+            cast->getOperator()->setBackend(node->getOperator()->backend());
             parent->addChild(cast, i, 0);
             op->getOutput(i)->setDataType(IOSpec.type);
@@ -315,6 +318,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
             transposeOp->getOperator()->setDataFormat(requiredIOSpec.format);
             transposeOp->getOperator()->setDataType(requiredIOSpec.type);
+            transposeOp->getOperator()->setBackend(node->getOperator()->backend());
             parent->addChild(transposeOp, i, 0);
             op->getOutput(i)->setDataFormat(IOSpec.format);
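The output side mirrors the input side, with the same two fixes applied:

    // Output-side adaptation (illustrative): a Cast back to the required
    // type and/or a Transpose back to the required format is appended
    // downstream of the node, each now placed on the node's own backend.
    //
    //   node (AVAILABLE spec) --> Cast / Transpose --> graph tensor (REQUIRED spec)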
@@ -340,7 +344,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
         }
     }

-    return MetaOperator(std::string("Adapted_" + op->type()).c_str(), getConnectedGraphView(node));
+    auto adaptedGraph = getConnectedGraphView(node);
+    if (adaptedGraph->getNodes().size() > 1) {
+        return MetaOperator(std::string("Adapted_" + op->type()).c_str(), adaptedGraph);
+    }
+    else {
+        return node;
+    }
 }
std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const {
@@ -354,8 +364,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const {
         auto adaptation = getAdaptation(availableSpec, requiredSpecs);

         if (adaptation) {
-            auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph();
-            adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size()));
+            if (adaptation->getOperator()->isAtomic()) {
+                adaptations.insert(std::make_pair(adaptation, 1));
+            }
+            else {
+                auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph();
+                adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size()));
+            }
         }
     }
......
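This hunk is coupled to the previous one: since getAdaptation() may now return the original node unwrapped when nothing had to be inserted, getBestAdaptation() can no longer assume every candidate is a MetaOperator_Op with a micro graph. A candidate's cost is its node count (1 for an atomic node). A hypothetical sketch of the selection step that presumably follows, not shown in this diff:

    #include <algorithm>

    // Pick the adaptation with the fewest nodes, i.e. the least overhead
    // (names mirror the snippet above; the actual tail of the function may differ).
    auto best = std::min_element(adaptations.begin(), adaptations.end(),
        [](const auto& a, const auto& b) { return a.second < b.second; });
    return (best != adaptations.end()) ? best->first : nullptr;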
@@ -538,6 +538,7 @@ void Tensor::copyTranspose(const Tensor& src, const std::vector<DimSize_t>& transpose)
         }
     }
+    AIDGE_ASSERT(mImpl, "Tensor::copyTranspose(): an implementation is required, use setBackend() first!");
     std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, newDims);

     std::vector<size_t> indices(newDims.size(), 0);
......
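Without the new assert, a Tensor that was never given a backend would dereference a null mImpl on the very next line. The enforced contract, sketched under the assumption of a registered "cpu" backend:

    // 'src' is an existing tensor (construction elided).
    Aidge::Tensor dst;
    // dst.copyTranspose(src, {0, 2, 3, 1});   // before setBackend(): assert fires
    dst.setBackend("cpu");                     // attach an implementation first
    dst.copyTranspose(src, {0, 2, 3, 1});      // now valid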
@@ -66,12 +66,17 @@ bool Aidge::Transpose_Op::forwardDims(bool /*allowDataDependency*/) {
             std::iota(this->outputDimsOrder().rbegin(), this->outputDimsOrder().rend(), 0);
         }

-        AIDGE_ASSERT(outputDimsOrder().size() == getInput(0)->nbDims(),
-            "Permutation vector must have the same rank as input tensor.");
+        AIDGE_ASSERT(outputDimsOrder().size() >= getInput(0)->nbDims(),
+            "Permutation vector ({}) must have at least the same rank as input tensor ({}).", outputDimsOrder(), getInput(0)->dims());

         std::vector<DimSize_t> outputDims;
-        for (std::size_t i = 0; i < outputDimsOrder().size(); ++i) {
+        std::size_t i = 0;
+        for (; i < getInput(0)->nbDims(); ++i) {
             outputDims.push_back(getInput(0)->dims()[outputDimsOrder()[i]]);
         }
+        for (; i < outputDimsOrder().size(); ++i) {
+            AIDGE_ASSERT(i == outputDimsOrder()[i],
+                "Permutation vector ({}) must be the identity above the input tensor rank ({}).", outputDimsOrder(), getInput(0)->dims());
+        }
         mOutputs[0]->resize(outputDims);
         return true;
     }
......
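The relaxed check allows the permutation vector to be longer than the input rank, provided the extra entries are the identity. A worked example of the new behavior:

    // Input dims = {2, 3, 4} (rank 3), outputDimsOrder = {0, 2, 1, 3}:
    //   - entries 0..2 permute the input: outputDims = {2, 4, 3}
    //   - entry 3 == 3, identity above the input rank: accepted
    // outputDimsOrder = {0, 2, 1, 4, 3} would instead fail the new assert,
    // since entry 3 maps to axis 4 rather than to itself.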
@@ -33,6 +33,7 @@ void Aidge::adaptToBackend(std::shared_ptr<GraphView> graphView) {
             Log::info("Adapted node {} (of type {}) to backend {}",
                 node->name(), node->type(), impl->backend());
             AIDGE_ASSERT(GraphView::replace({node}, {adaptedNode}), "Unable to replace adapted node!");
+            expandMetaOp(adaptedNode);
         }
     }
 }
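With this call added, an adapted node is expanded in place right after replacement, so the graph ends up containing the individual Cast/Transpose nodes rather than an opaque "Adapted_*" meta operator. Typical usage of the recipe, assuming a registered "cpu" backend:

    #include "aidge/recipes/Recipes.hpp"

    // Configure the backend, then let the recipe insert (and inline)
    // whatever adaptations the selected implementations require.
    graphView->setBackend("cpu");
    Aidge::adaptToBackend(graphView);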
@@ -14,6 +14,21 @@
 #include "aidge/recipes/Recipes.hpp"
 #include "aidge/operator/MetaOperator.hpp"

+bool Aidge::expandMetaOp(std::shared_ptr<Node> node) {
+    auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator());
+
+    if (metaOp != nullptr) {
+        // Replace meta op by its micro-graph
+        // graph will be updated accordingly in GraphView::replace()
+        auto g = std::make_shared<GraphView>();
+        g->add(node, false);
+        GraphView::replace(g, metaOp->getMicroGraph());
+        return true;
+    }
+
+    return false;
+}
+
 void Aidge::expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive) {
     bool found = false;
     const auto nodes = graph->getNodes();
......
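expandMetaOp(node) above is the single-node building block; expandMetaOps() applies it across a whole graph. A short usage sketch:

    // Flatten every meta operator in the graph; with recursive = true,
    // nested meta operators are expanded until only atomic nodes remain.
    Aidge::expandMetaOps(graph, /*recursive=*/true);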