/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/operator/Slice.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>    // std::abs (integer overloads)
#include <iterator>   // std::back_inserter
#include <stdexcept>  // std::runtime_error
#include <string>
#include <utility>
#include <vector>

#include <fmt/format.h>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"

namespace {

/**
 * @brief Append every element of @p tensor, read as @p SrcT, to @p dst.
 *
 * The raw buffer returned by `tensor->getImpl()->rawPtr()` is reinterpreted as
 * an array of `SrcT`; each element is implicitly converted to the value type
 * of @p dst when pushed back.
 */
template <typename SrcT, typename TensorPtr, typename Container>
void appendTensorValuesAs(const TensorPtr& tensor, Container& dst) {
    const auto* src = static_cast<const SrcT*>(tensor->getImpl()->rawPtr());
    std::copy_n(src, tensor->size(), std::back_inserter(dst));
}

/**
 * @brief Replace the content of @p dst with the values stored in @p tensor.
 *
 * The element type is selected from the tensor's own `dataType()`.
 *
 * @return `true` on success, `false` if the tensor data type is not one of
 *         Float64, Float32, Int64 or Int32 (in which case @p dst is left empty).
 */
template <typename TensorPtr, typename Container>
bool loadTensorValues(const TensorPtr& tensor, Container& dst) {
    dst.clear();
    dst.reserve(tensor->size());
    switch (tensor->dataType()) {
        case Aidge::DataType::Float64:
            appendTensorValuesAs<double>(tensor, dst);
            return true;
        case Aidge::DataType::Float32:
            appendTensorValuesAs<float>(tensor, dst);
            return true;
        case Aidge::DataType::Int64:
            appendTensorValuesAs<std::int64_t>(tensor, dst);
            return true;
        case Aidge::DataType::Int32:
            appendTensorValuesAs<std::int32_t>(tensor, dst);
            return true;
        default:
            return false;
    }
}

}  // namespace

/// Type name of the operator.
const std::string Aidge::Slice_Op::Type = "Slice";

/**
 * @brief Compute the output dimensions of the Slice operator.
 *
 * Starts, Ends and Axes are taken from the attributes; if any of them is
 * empty, all three are (re)loaded from inputs #1, #2 and #3. Steps is taken
 * from its attribute, else from input #4 if connected and non-empty, else it
 * defaults to 1 for every sliced axis.
 *
 * Negative axes are offset by the number of input dimensions and negative
 * starts/ends by the size of the sliced dimension. The resulting length along
 * each sliced axis is ceil((end - start) / |step|).
 *
 * @return `true` if the output was resized, `false` if input #0 is empty.
 * @throws std::runtime_error (through AIDGE_THROW_OR_ABORT) if input #0 is not
 *         associated, if Starts/Ends/Axes are available neither as attributes
 *         nor as inputs, if an input data type is unsupported, if a step is
 *         zero, if an axis is out of range, or if the slice is out of bounds.
 */
bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
    // check inputs have been associated
    if (!getInput(0)) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
    }
    if (getInput(0)->empty()) {
        return false;
    }

    auto& starts = this->template getAttr<SliceAttr::Starts>();
    auto& ends   = this->template getAttr<SliceAttr::Ends>();
    auto& axes   = this->template getAttr<SliceAttr::Axes>();
    auto& steps  = this->template getAttr<SliceAttr::Steps>();

    // Starts / Ends / Axes: fall back on inputs #1..#3 when attributes are incomplete.
    if (starts.empty() || ends.empty() || axes.empty()) {
        // The optional inputs may not be connected at all: test for null
        // before dereferencing them.
        if (!getInput(1) || !getInput(2) || !getInput(3)
            || getInput(1)->empty() || getInput(2)->empty() || getInput(3)->empty()) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Starts, Ends and Axes must be provided either as input or attributes", type());
        }

        AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");

        const bool loaded = loadTensorValues(mInputs[1], starts)
                         && loadTensorValues(mInputs[2], ends)
                         && loadTensorValues(mInputs[3], axes);
        if (!loaded) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Input DataType is not supported.", type());
        }
    }

    // Fill Steps attr if empty
    if (steps.empty()) {
        if (getInput(4) && !getInput(4)->empty()) {
            // Decode input #4 with its own data type (not the one of input #1).
            if (!loadTensorValues(mInputs[4], steps)) {
                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
            }
        } else {
            // In case the input Steps is not provided, default value is 1.
            // Sized from Axes so that it also works when Starts/Ends/Axes
            // were given as attributes and input #1 is not connected.
            steps = std::vector<std::int64_t>(axes.size(), 1);
        }
    }

    const std::size_t nbAxes = axes.size();
    // The loop below reads index i of every attribute for i < nbAxes.
    if (starts.size() < nbAxes || ends.size() < nbAxes || steps.size() < nbAxes) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Starts, Ends and Steps must provide one value per sliced axis", type());
    }

    std::vector<DimSize_t> outDims = getInput(0)->dims();
    for (std::size_t i = 0; i < nbAxes; ++i) {
        // Negative axis counts from the last dimension.
        const DimIdx_t axis = axes[i] >= 0 ?
                        static_cast<DimIdx_t>(axes[i]) :
                        static_cast<DimIdx_t>(axes[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
        if (axis >= outDims.size()) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: axis {} is out of range for an input with {} dimensions", type(), static_cast<std::int64_t>(axes[i]), outDims.size());
        }

        const DimSize_t dimSize = getInput(0)->dims()[axis];

        // Negative start/end count from the end of the sliced dimension.
        const DimSize_t start = starts[i] >= 0 ?
                        static_cast<DimSize_t>(starts[i]) :
                        static_cast<DimSize_t>(starts[i] + static_cast<DimSize_t>(dimSize));
        const DimSize_t end = ends[i] >= 0 ?
                        static_cast<DimSize_t>(ends[i]) :
                        static_cast<DimSize_t>(ends[i] + static_cast<DimSize_t>(dimSize));

        const std::int64_t step = steps[i];
        if (step == 0) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type());
        }

        // Explicit guard: (end - start) is unsigned and would wrap around.
        if (end < start) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
        }

        // Number of selected elements: ceil(span / stride), written so that
        // it cannot overflow.
        const DimSize_t span = end - start;
        const DimSize_t stride = static_cast<DimSize_t>(std::abs(step));
        const std::size_t sliceLength = span / stride + ((span % stride != 0) ? 1 : 0);

        // Check if slice length is valid
        if (sliceLength > dimSize) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
        }
        outDims[axis] = sliceLength;
    }

    mOutputs[0]->resize(outDims);
    return true;
}

/**
 * @brief Select the implementation registered for backend @p name and
 *        forward the backend/device choice to output #0.
 *
 * @param name   Backend name.
 * @param device Device index passed to the output tensor.
 */
void Aidge::Slice_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    SET_IMPL_MACRO(Slice_Op, *this, name);
    mOutputs[0]->setBackend(name, device);
}