Skip to content
Snippets Groups Projects
Commit 2bba4f96 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Multiple fixes related to adaptToBackend()

parent dbddbbe7
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!290[Add] support for auto-concatenation and Fix multiple adaptToBackend() issues
...@@ -124,6 +124,12 @@ void explicitCastMove(std::shared_ptr<GraphView> graphView); ...@@ -124,6 +124,12 @@ void explicitCastMove(std::shared_ptr<GraphView> graphView);
*/ */
void explicitTranspose(std::shared_ptr<GraphView> graphView); void explicitTranspose(std::shared_ptr<GraphView> graphView);
/**
* Replace a single meta operator by its micro graph.
* @return true if node is indeed a meta operator and could be expanded.
*/
bool expandMetaOp(std::shared_ptr<Node> node);
/** /**
* Flatten the graph by replacing the meta operators by their micro graph. * Flatten the graph by replacing the meta operators by their micro graph.
* @param recursive If true, recursively replace meta operators until there is * @param recursive If true, recursively replace meta operators until there is
......
...@@ -81,6 +81,7 @@ void init_OperatorImpl(py::module& m){ ...@@ -81,6 +81,7 @@ void init_OperatorImpl(py::module& m){
.def(py::init<const DynamicAttributes&>(), py::arg("attr") = DynamicAttributes()) .def(py::init<const DynamicAttributes&>(), py::arg("attr") = DynamicAttributes())
.def(py::init<const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("io"), py::arg("attr") = DynamicAttributes()) .def(py::init<const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("io"), py::arg("attr") = DynamicAttributes())
.def(py::init<const ImplSpec::IOSpec&, const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes()) .def(py::init<const ImplSpec::IOSpec&, const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes())
.def(py::init<const std::vector<ImplSpec::IOSpec>&, const std::vector<ImplSpec::IOSpec>&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes())
.def("__eq__", static_cast<bool(*)(const ImplSpec&, const ImplSpec&)>(&operator==)) .def("__eq__", static_cast<bool(*)(const ImplSpec&, const ImplSpec&)>(&operator==))
.def("__repr__", [](ImplSpec self){ .def("__repr__", [](ImplSpec self){
return fmt::format("{}\n", self); return fmt::format("{}\n", self);
......
...@@ -250,9 +250,10 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -250,9 +250,10 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
&& requiredIOSpec.type != IOSpec.type) && requiredIOSpec.type != IOSpec.type)
{ {
const auto cast = Cast(IOSpec.type); const auto cast = Cast(IOSpec.type);
cast->getOperator()->setBackend(node->getOperator()->backend());
cast->addChild(parent, 0, i); cast->addChild(parent, 0, i);
op->getInput(i)->setDataType(IOSpec.type); op->getInput(i)->setDataType(requiredIOSpec.type);
} }
// Input format // Input format
...@@ -263,10 +264,11 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -263,10 +264,11 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format); const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format);
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(IOSpec.format); transposeOp->getOperator()->setDataFormat(IOSpec.format);
transposeOp->getOperator()->setDataType(IOSpec.type); transposeOp->getOperator()->setDataType(requiredIOSpec.type);
transposeOp->getOperator()->setBackend(node->getOperator()->backend());
transposeOp->addChild(parent, 0, i); transposeOp->addChild(parent, 0, i);
op->getInput(i)->setDataFormat(IOSpec.format); op->getInput(i)->setDataFormat(requiredIOSpec.format);
} }
// Input dims // Input dims
...@@ -301,6 +303,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -301,6 +303,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
&& requiredIOSpec.type != IOSpec.type) && requiredIOSpec.type != IOSpec.type)
{ {
const auto cast = Cast(requiredIOSpec.type); const auto cast = Cast(requiredIOSpec.type);
cast->getOperator()->setBackend(node->getOperator()->backend());
parent->addChild(cast, i, 0); parent->addChild(cast, i, 0);
op->getOutput(i)->setDataType(IOSpec.type); op->getOutput(i)->setDataType(IOSpec.type);
...@@ -315,6 +318,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -315,6 +318,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(requiredIOSpec.format); transposeOp->getOperator()->setDataFormat(requiredIOSpec.format);
transposeOp->getOperator()->setDataType(requiredIOSpec.type); transposeOp->getOperator()->setDataType(requiredIOSpec.type);
transposeOp->getOperator()->setBackend(node->getOperator()->backend());
parent->addChild(transposeOp, i, 0); parent->addChild(transposeOp, i, 0);
op->getOutput(i)->setDataFormat(IOSpec.format); op->getOutput(i)->setDataFormat(IOSpec.format);
...@@ -340,7 +344,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -340,7 +344,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
} }
} }
return MetaOperator(std::string("Adapted_" + op->type()).c_str(), getConnectedGraphView(node)); auto adaptedGraph = getConnectedGraphView(node);
if (adaptedGraph->getNodes().size() > 1) {
return MetaOperator(std::string("Adapted_" + op->type()).c_str(), adaptedGraph);
}
else {
return node;
}
} }
std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const { std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const {
...@@ -354,8 +364,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSp ...@@ -354,8 +364,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSp
auto adaptation = getAdaptation(availableSpec, requiredSpecs); auto adaptation = getAdaptation(availableSpec, requiredSpecs);
if (adaptation) { if (adaptation) {
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph(); if (adaptation->getOperator()->isAtomic()) {
adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size())); adaptations.insert(std::make_pair(adaptation, 1));
}
else {
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph();
adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size()));
}
} }
} }
......
...@@ -538,6 +538,7 @@ void Tensor::copyTranspose(const Tensor& src, const std::vector<DimSize_t>& tran ...@@ -538,6 +538,7 @@ void Tensor::copyTranspose(const Tensor& src, const std::vector<DimSize_t>& tran
} }
} }
AIDGE_ASSERT(mImpl, "Tensor::copyTranspose(): an implementation is required, use setBackend() first!");
std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, newDims); std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, newDims);
std::vector<size_t> indices(newDims.size(), 0); std::vector<size_t> indices(newDims.size(), 0);
......
...@@ -66,12 +66,17 @@ bool Aidge::Transpose_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -66,12 +66,17 @@ bool Aidge::Transpose_Op::forwardDims(bool /*allowDataDependency*/) {
std::iota(this->outputDimsOrder().rbegin(), this->outputDimsOrder().rend(), 0); std::iota(this->outputDimsOrder().rbegin(), this->outputDimsOrder().rend(), 0);
} }
AIDGE_ASSERT(outputDimsOrder().size() == getInput(0)->nbDims(), AIDGE_ASSERT(outputDimsOrder().size() >= getInput(0)->nbDims(),
"Permutation vector must have the same rank as input tensor."); "Permutation vector ({}) must have at least the same rank as input tensor ({}).", outputDimsOrder(), getInput(0)->dims());
std::vector<DimSize_t> outputDims; std::vector<DimSize_t> outputDims;
for (std::size_t i = 0; i < outputDimsOrder().size(); ++i) { std::size_t i = 0;
for (; i < getInput(0)->nbDims(); ++i) {
outputDims.push_back(getInput(0)->dims()[outputDimsOrder()[i]]); outputDims.push_back(getInput(0)->dims()[outputDimsOrder()[i]]);
} }
for (; i < outputDimsOrder().size(); ++i) {
AIDGE_ASSERT(i == outputDimsOrder()[i],
"Permutation vector ({}) must be the identity above the input tensor rank ({}).", outputDimsOrder(), getInput(0)->dims());
}
mOutputs[0]->resize(outputDims); mOutputs[0]->resize(outputDims);
return true; return true;
} }
......
...@@ -33,6 +33,7 @@ void Aidge::adaptToBackend(std::shared_ptr<GraphView> graphView) { ...@@ -33,6 +33,7 @@ void Aidge::adaptToBackend(std::shared_ptr<GraphView> graphView) {
Log::info("Adapted node {} (of type {}) to backend {}", Log::info("Adapted node {} (of type {}) to backend {}",
node->name(), node->type(), impl->backend()); node->name(), node->type(), impl->backend());
AIDGE_ASSERT(GraphView::replace({node}, {adaptedNode}), "Unable to replace adapted node!"); AIDGE_ASSERT(GraphView::replace({node}, {adaptedNode}), "Unable to replace adapted node!");
expandMetaOp(adaptedNode);
} }
} }
} }
...@@ -14,6 +14,21 @@ ...@@ -14,6 +14,21 @@
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
#include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/MetaOperator.hpp"
bool Aidge::expandMetaOp(std::shared_ptr<Node> node) {
auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator());
if (metaOp != nullptr) {
// Replace meta op by its micro-graph
// graph will be updated accordingly in GraphView::replace()
auto g = std::make_shared<GraphView>();
g->add(node, false);
GraphView::replace(g, metaOp->getMicroGraph());
return true;
}
return false;
}
void Aidge::expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive) { void Aidge::expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive) {
bool found = false; bool found = false;
const auto nodes = graph->getNodes(); const auto nodes = graph->getNodes();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment