/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/operator/Slice.hpp"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"

namespace {

/// Window kept along one input axis once negative indices have been resolved.
struct SliceWindow {
    std::size_t start;   ///< First index kept along the axis.
    std::size_t length;  ///< Number of indices kept along the axis (may be 0).
};

/**
 * @brief Resolve and validate the Starts / Ends / Axes attributes of a Slice_Op.
 *
 * Negative axes are offset by the input rank; negative starts/ends are offset
 * by the size of their axis. Ends are inclusive, so an axis keeps
 * `end - start + 1` elements, and `end == start - 1` yields an empty window.
 *
 * @param op Slice operator whose input #0 is already associated.
 * @return One window per input dimension. Dimensions not listed in Axes keep
 *         the full window {0, dim}.
 * @throws std::runtime_error (through AIDGE_THROW_OR_ABORT) if the attribute
 *         vectors differ in size, an axis is out of range or listed twice, or
 *         the requested ROI does not fit inside the input.
 */
std::vector<SliceWindow> computeSliceWindows(const Aidge::Slice_Op& op) {
    const auto& dims = op.getInput(0)->dims();
    const auto& axes = op.getAttr<Aidge::SliceAttr::Axes>();
    const auto& starts = op.getAttr<Aidge::SliceAttr::Starts>();
    const auto& ends = op.getAttr<Aidge::SliceAttr::Ends>();

    // Every sliced axis needs exactly one start and one end.
    if (starts.size() != axes.size() || ends.size() != axes.size()) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
                             "{}: Starts, Ends and Axes attributes must have the same size (got {}, {} and {})",
                             Aidge::Slice_Op::Type, starts.size(), ends.size(), axes.size());
    }

    const std::size_t nbDims = dims.size();
    const auto rank = static_cast<std::int64_t>(nbDims);

    std::vector<SliceWindow> windows(nbDims);
    for (std::size_t d = 0; d < nbDims; ++d) {
        windows[d] = SliceWindow{0, static_cast<std::size_t>(dims[d])};
    }

    std::vector<char> alreadySliced(nbDims, 0);
    for (std::size_t i = 0; i < axes.size(); ++i) {
        // Resolve in signed arithmetic: an index more negative than -size must
        // be rejected, not silently wrapped around in std::size_t.
        const auto requestedAxis = static_cast<std::int64_t>(axes[i]);
        const std::int64_t axis = requestedAxis < 0 ? requestedAxis + rank : requestedAxis;
        if (axis < 0 || axis >= rank) {
            AIDGE_THROW_OR_ABORT(std::runtime_error,
                                 "{}: axis {} is out of range for a {}-D input",
                                 Aidge::Slice_Op::Type, requestedAxis, nbDims);
        }
        const auto a = static_cast<std::size_t>(axis);
        // Slicing the same axis twice would make the output shape and the
        // read offsets disagree.
        if (alreadySliced[a] != 0) {
            AIDGE_THROW_OR_ABORT(std::runtime_error,
                                 "{}: axis {} is sliced more than once",
                                 Aidge::Slice_Op::Type, axis);
        }
        alreadySliced[a] = 1;

        const auto dim = static_cast<std::int64_t>(dims[a]);
        std::int64_t start = static_cast<std::int64_t>(starts[i]);
        std::int64_t end = static_cast<std::int64_t>(ends[i]);
        if (start < 0) { start += dim; }
        if (end < 0) { end += dim; }

        // Ends are inclusive: require 0 <= start <= end + 1 and end < dim.
        // (start >= 0 is tested first, so `start - 1` cannot overflow.)
        if (start < 0 || end >= dim || end < start - 1) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
        }
        windows[a] = SliceWindow{static_cast<std::size_t>(start),
                                 static_cast<std::size_t>(end - start + 1)};
    }
    return windows;
}

} // namespace

/**
 * @brief Default Slice kernel, working through TensorImpl::rawPtr()/copy().
 *
 * Copies the ROI of input #0 into output #0, one contiguous run along the
 * last axis at a time, writing the output in row-major order. Output #0 is
 * expected to already have the sliced shape (see Slice_Op::forwardDims()).
 */
void Aidge::Slice_OpImpl::forward() {
    const Slice_Op& op = dynamic_cast<const Slice_Op&>(mOp);
    const auto& inputDims = op.getInput(0)->dims();
    const std::vector<SliceWindow> windows = computeSliceWindows(op);
    const std::size_t nbDims = windows.size();

    const auto inputImpl = op.getInput(0)->getImpl();
    const auto outputImpl = op.getOutput(0)->getImpl();

    if (nbDims == 0) {
        // Rank-0 input: no axis can be sliced, the output is the single input element.
        outputImpl->copy(inputImpl->rawPtr(0), 1, 0);
        return;
    }

    // Row-major strides of the input: inStride[d] = product of inputDims[d+1..].
    std::vector<std::size_t> inStride(nbDims, 1);
    for (std::size_t d = nbDims - 1; d > 0; --d) {
        inStride[d - 1] = inStride[d] * static_cast<std::size_t>(inputDims[d]);
    }

    // Input offset of the first ROI element, and number of output "rows"
    // (a row is one contiguous run along the last axis).
    std::size_t inOffset = 0;
    std::size_t nbRows = 1;
    for (std::size_t d = 0; d < nbDims; ++d) {
        inOffset += windows[d].start * inStride[d];
        if (d + 1 < nbDims) {
            nbRows *= windows[d].length;
        }
    }
    const std::size_t rowLength = windows.back().length;
    if (rowLength == 0 || nbRows == 0) {
        return;  // Empty ROI: nothing to copy.
    }

    // Odometer over the outer axes [0, nbDims - 1) tracking the current row.
    std::vector<std::size_t> rowIndex(nbDims, 0);
    std::size_t outOffset = 0;
    for (std::size_t row = 0; row < nbRows; ++row) {
        outputImpl->copy(inputImpl->rawPtr(inOffset), rowLength, outOffset);
        outOffset += rowLength;

        // Advance to the next row: bump the innermost outer axis, carrying
        // into the enclosing axes when a window is exhausted.
        for (std::size_t d = nbDims - 1; d-- > 0;) {
            inOffset += inStride[d];
            if (++rowIndex[d] < windows[d].length) {
                break;
            }
            inOffset -= rowIndex[d] * inStride[d];
            rowIndex[d] = 0;
        }
    }
}

const std::string Aidge::Slice_Op::Type = "Slice";

/**
 * @brief Compute the output dimensions from input #0 and the slice attributes.
 *
 * Dimensions not listed in Axes are kept as is; sliced dimensions become
 * `end - start + 1` (Ends are inclusive).
 *
 * @return Always true once the output has been resized.
 * @throws std::runtime_error (through AIDGE_THROW_OR_ABORT) if input #0 is not
 *         associated, or if the attributes describe an invalid ROI.
 */
bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
    // check input have been associated
    if (!getInput(0) || (getInput(0)->empty())) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
    }

    const std::vector<SliceWindow> windows = computeSliceWindows(*this);

    std::vector<DimSize_t> outDims;
    outDims.reserve(windows.size());
    for (const SliceWindow& window : windows) {
        outDims.push_back(static_cast<DimSize_t>(window.length));
    }

    mOutputs[0]->resize(outDims);
    return true;
}

/**
 * @brief Select the implementation for backend @p name.
 *
 * Uses the registered implementation when one exists for @p name, otherwise
 * falls back to the generic Slice_OpImpl defined in this file. The output
 * tensor is moved to the same backend/device in both cases.
 */
void Aidge::Slice_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    if (Registrar<Slice_Op>::exists({name})){
        SET_IMPL_MACRO(Slice_Op, *this, name);
    }
    else {
        mImpl = std::make_shared<Slice_OpImpl>(*this);
    }
    mOutputs[0]->setBackend(name, device);
}