diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp index 7e08457bc88c47f756d25e62701cf196b0bde355..b94364f6199bc56291f729ca40d7486f48dd3c19 100644 --- a/src/recipes/HorizontalTiling.cpp +++ b/src/recipes/HorizontalTiling.cpp @@ -23,6 +23,7 @@ #include "aidge/operator/OperatorTensor.hpp" #include "aidge/data/Data.hpp" #include "aidge/utils/Types.h" +#include "aidge/operator/Producer.hpp" #include "aidge/operator/Add.hpp" #include "aidge/operator/Concat.hpp" @@ -82,21 +83,47 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis])); clonedInputs[1] -> addChild(newNode, 0, 1); clonedInputs[2] -> addChild(newNode, 0, 2); - // Slice for input and each parameter - std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size()); - for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) { - inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1; - } + + auto backend = outTensor->getImpl()->backend(); + auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis])); + auto sliceInputsNames = slice->getOperator()->getInputsName(); + // Create Slice's Starts producer node std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size()); for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) { inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]); } + const std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(); + starts -> setDataType(DataType::Int64); + starts -> setBackend(backend); + starts -> resize(std::vector<std::size_t>({inputDimsStart.size()})); + starts -> getImpl() -> setRawPtr(inputDimsStart.data(), inputDimsStart.size()); + auto startsNode = Producer(starts, sliceInputsNames[1]); + startsNode -> addChild(slice, 0, 1); + + std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size()); + for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) { + inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]); + } + const std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(); + ends -> setDataType(DataType::Int64); + ends -> setBackend(backend); + ends -> resize(std::vector<std::size_t>({inputDimsEnd.size()})); + ends -> getImpl() -> setRawPtr(inputDimsEnd.data(), inputDimsEnd.size()); + auto endsNode = Producer(ends, sliceInputsNames[2]); + endsNode -> addChild(slice, 0, 2); + + // Create Slice's Axes producer node std::vector<std::int64_t> usedDims(inputDimsEnd.size()); std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0)); - Tensor(std::vector<std::size_t>({inputDimsStart.size()})); - // TODO create producer nodes for the attributes - // auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); - auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis])); + Tensor(std::vector<std::size_t>({inputDimsStart.size()})); + const std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(); + axes -> setDataType(DataType::Int64); + axes -> setBackend(backend); + axes -> resize(std::vector<std::size_t>({usedDims.size()})); + axes -> getImpl() -> setRawPtr(usedDims.data(), usedDims.size()); + auto axesNode = Producer(axes, sliceInputsNames[3]); + axesNode -> addChild(slice, 0, 3); + slice -> addChild(newNode, 0, 0); newNode -> addChild(concat, 0, i);