From b2298d78e15ffbe670f75de06716e4058ebd4eec Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Tue, 28 May 2024 14:14:25 +0200 Subject: [PATCH] fix Slice node creation in HorizontalTiling --- src/recipes/HorizontalTiling.cpp | 47 ++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp index 342a4afe6..989754930 100644 --- a/src/recipes/HorizontalTiling.cpp +++ b/src/recipes/HorizontalTiling.cpp @@ -90,30 +90,59 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: clonedInputs[1] -> addChild(newNode, 0, 1); clonedInputs[2] -> addChild(newNode, 0, 2); + auto slice = Slice(); auto backend = outTensor->getImpl()->backend(); - // Create Slice's Starts attribute + // Create Slice's Starts producer node std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size()); for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) { inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]); } - // Create Slice's Ends attribute + const std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(); + starts -> setDataType(DataType::Int64); + starts -> setBackend(backend); + starts -> resize(std::vector<std::size_t>({inputDimsStart.size()})); + starts -> getImpl() -> copyFromHost(inputDimsStart.data(), inputDimsStart.size()); + auto startsNode = Producer(starts, slice->name() + sliceInputsNames[1]); + startsNode -> addChild(slice, 0, 1); + + // Create Slice's Ends producer node std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size()); for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) { inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]); } - - // Create Slice's Axes attribute + const std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(); + ends -> setDataType(DataType::Int64); + ends -> setBackend(backend); + ends -> resize(std::vector<std::size_t>({inputDimsEnd.size()})); + ends -> getImpl() -> copyFromHost(inputDimsEnd.data(), inputDimsEnd.size()); + auto endsNode = Producer(ends, slice->name() + sliceInputsNames[2]); + endsNode -> addChild(slice, 0, 2); + + // Create Slice's Axes producer node std::vector<std::int8_t> usedDims(inputDimsEnd.size()); std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int8_t>(0)); + const std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(); + axes -> setDataType(DataType::Int8); + axes -> setBackend(backend); + axes -> resize(std::vector<std::size_t>({usedDims.size()})); + axes -> getImpl() -> copyFromHost(usedDims.data(), usedDims.size()); + auto axesNode = Producer(axes, slice->name() + sliceInputsNames[3]); + axesNode -> addChild(slice, 0, 3); + + // Create Slice's Steps producer node + std::vector<std::int64_t> inputDimsSteps(inputDimsEnd.size(), static_cast<std::int64_t>(1)); + const std::shared_ptr<Tensor> steps = std::make_shared<Tensor>(); + steps -> setDataType(DataType::Int64); + steps -> setBackend(backend); + steps -> resize(std::vector<std::size_t>({inputDimsSteps.size()})); + steps -> getImpl() -> copyFromHost(inputDimsSteps.data(), inputDimsSteps.size()); + auto stepsNode = Producer(steps, slice->name() + sliceInputsNames[4]); + stepsNode -> addChild(slice, 0, 4); - // Create Slice's Steps attribute - std::vector<std::int64_t> steps(inputDimsEnd.size(), static_cast<std::int64_t>(1)); - - auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, steps, "Slice_" + std::to_string(currentFirstDims[axis])); slice -> addChild(newNode, 0, 0); newNode -> addChild(concat, 0, i); - tiledOperator.insert({slice, newNode}); + tiledOperator.insert({slice, newNode, startsNode, endsNode, axesNode, stepsNode}); } return tiledOperator; -- GitLab