Skip to content
Snippets Groups Projects
Commit 406a5635 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix HorizontalTiling recipe

parent ce6d5a01
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!93Change Gather and Slice's attributes into intputs
Pipeline #40956 passed
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Add.hpp" #include "aidge/operator/Add.hpp"
#include "aidge/operator/Concat.hpp" #include "aidge/operator/Concat.hpp"
...@@ -82,21 +83,47 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -82,21 +83,47 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis])); newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis]));
clonedInputs[1] -> addChild(newNode, 0, 1); clonedInputs[1] -> addChild(newNode, 0, 1);
clonedInputs[2] -> addChild(newNode, 0, 2); clonedInputs[2] -> addChild(newNode, 0, 2);
// Slice for input and each parameter
std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size()); auto backend = outTensor->getImpl()->backend();
for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) { auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis]));
inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1; auto sliceInputsNames = slice->getOperator()->getInputsName();
} // Create Slice's Starts producer node
std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size()); std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) { for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) {
inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]); inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]);
} }
const std::shared_ptr<Tensor> starts = std::make_shared<Tensor>();
starts -> setDataType(DataType::Int64);
starts -> setBackend(backend);
starts -> resize(std::vector<std::size_t>({inputDimsStart.size()}));
starts -> getImpl() -> setRawPtr(inputDimsStart.data(), inputDimsStart.size());
auto startsNode = Producer(starts, sliceInputsNames[1]);
startsNode -> addChild(slice, 0, 1);
std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) {
inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]);
}
const std::shared_ptr<Tensor> ends = std::make_shared<Tensor>();
ends -> setDataType(DataType::Int64);
ends -> setBackend(backend);
ends -> resize(std::vector<std::size_t>({inputDimsEnd.size()}));
ends -> getImpl() -> setRawPtr(inputDimsEnd.data(), inputDimsEnd.size());
auto endsNode = Producer(ends, sliceInputsNames[2]);
endsNode -> addChild(slice, 0, 2);
// Create Slice's Axes producer node
std::vector<std::int64_t> usedDims(inputDimsEnd.size()); std::vector<std::int64_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0)); std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0));
Tensor(std::vector<std::size_t>({inputDimsStart.size()})); Tensor(std::vector<std::size_t>({inputDimsStart.size()}));
// TODO create producer nodes for the attributes const std::shared_ptr<Tensor> axes = std::make_shared<Tensor>();
// auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); axes -> setDataType(DataType::Int64);
auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis])); axes -> setBackend(backend);
axes -> resize(std::vector<std::size_t>({usedDims.size()}));
axes -> getImpl() -> setRawPtr(usedDims.data(), usedDims.size());
auto axesNode = Producer(axes, sliceInputsNames[3]);
axesNode -> addChild(slice, 0, 3);
slice -> addChild(newNode, 0, 0); slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, i); newNode -> addChild(concat, 0, i);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment