diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 7cd8c67262221fbf9c1b2415ebf98db56274cce5..9e1f0335edd47ae1405fc9f9949340c13a67221f 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -125,9 +125,9 @@ static Registrar<Tensor> registrarTensorImpl_cpu_Float32( static Registrar<Tensor> registrarTensorImpl_cpu_Float16( {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create); static Registrar<Tensor> registrarTensorImpl_cpu_Int64( - {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<long>::create); + {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create); static Registrar<Tensor> registrarTensorImpl_cpu_Int32( - {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int>::create); + {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create); static Registrar<Tensor> registrarTensorImpl_cpu_Int16( {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create); static Registrar<Tensor> registrarTensorImpl_cpu_UInt16( diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index 6680f2e1d6de5157024f9e7ca65b14256e53eae2..a04e4be69c9fd1a6ed7753ed512c7f5e45b925d9 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -79,6 +79,7 @@ public: return std::make_shared<Gather_Op>(*this); } + bool dimsForwarded() const override final; bool forwardDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override; diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 01c32004ec20babcef74471c527e3831a0dcc32c..12fbda88b0044f836b298e0cf818724f53f821a7 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -76,6 +76,7 @@ public: return std::make_shared<Reshape_Op>(*this); } + bool dimsForwarded() const override final; bool forwardDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override final; diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index 57a6aa2eafede5c5d0e64819b16f6a186de38306..c8f16bb1ad769299a89d3f8a05e46960fe824711 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -72,6 +72,7 @@ public: */ std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); } + bool dimsForwarded() const override final; bool forwardDims(bool allowDataDependency = false) override final; void setBackend(const std::string &name, DeviceIdx_t device = 0) override; diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 4e5bd2573a0e1b0cc78256a68dad88332877067b..b0b9a0e84882cae55a9a3c336684d43e208cb503 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -51,64 +51,61 @@ void Aidge::Gather_OpImpl::forward() { const std::string Aidge::Gather_Op::Type = "Gather"; -bool Aidge::Gather_Op::forwardDims(bool /*allowDataDependency*/) { +bool Aidge::Gather_Op::dimsForwarded() const { + if (getInput(1) && !getInput(1)->empty()) { + // output dims are data dependent + return false; + } + + return OperatorTensor::dimsForwarded(); +} + +bool Aidge::Gather_Op::forwardDims(bool allowDataDependency) { // check data input has been associated if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); } - if (!getInput(0)->empty()) { - if (this->template getAttr<GatherAttr::Indices>().empty()) - { - if(getInput(1)->empty()) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Either indices input or attribute must be provided", type()); - } - this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims(); - this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs - this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size()); - switch (mInputs[1]->dataType()) { - case DataType::Float64: - std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<GatherAttr::Indices>())); - break; - case DataType::Float32: - std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<GatherAttr::Indices>())); - break; - case DataType::Int64: - std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<GatherAttr::Indices>())); - break; - case DataType::Int32: - std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<GatherAttr::Indices>())); - break; - default: - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type()); - break; - } + if (getInput(0)->empty()) { + return false; + } + + if (getInput(1) && !getInput(1)->empty()) { + if (!this->template getAttr<GatherAttr::Indices>().empty()) { + Log::notice("Gather_Op: ignoring non-empty Indices attribute because input#1 takes precedence"); } - std::vector<DimSize_t> outDims = getInput(0)->dims(); - std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0? - this->template getAttr<GatherAttr::Axis>(): - this->template getAttr<GatherAttr::Axis>()+outDims.size(); - outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); - if( !this->template getAttr<GatherAttr::GatheredShape>().empty()) - { - outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), - this->template getAttr<GatherAttr::GatheredShape>().begin(), - this->template getAttr<GatherAttr::GatheredShape>().end()); + if (!allowDataDependency) { + Log::warn("Gather_Op: unable to forwardDims() because output dims are data dependent on input#1"); + return false; } - mOutputs[0]->resize(outDims); - return true; + + std::shared_ptr<Tensor> fallback; + this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims(); + this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs + this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size()); + const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(indices.getImpl()->hostPtr()), + indices.size(), + std::back_inserter(this->template getAttr<GatherAttr::Indices>())); } - return false; + AIDGE_ASSERT(!this->template getAttr<GatherAttr::Indices>().empty(), "Missing input#1 or Indices attribute"); + + std::vector<DimSize_t> outDims = getInput(0)->dims(); + + std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0? + this->template getAttr<GatherAttr::Axis>(): + this->template getAttr<GatherAttr::Axis>()+outDims.size(); + outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); + if( !this->template getAttr<GatherAttr::GatheredShape>().empty()) + { + outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), + this->template getAttr<GatherAttr::GatheredShape>().begin(), + this->template getAttr<GatherAttr::GatheredShape>().end()); + } + mOutputs[0]->resize(outDims); + return true; } void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index 2a60f580f3279170a0f1ff417cea96ae7cfa981f..25c9deb2adaca65748d7f6981de574d0a674af5d 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -182,7 +182,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { void Aidge::OperatorTensor::forward() { if (!dimsForwarded()) { - forwardDims(); + // Allow data dependent forwardDims at this point (data is available) + forwardDims(true); } Operator::forward(); diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index da379fbc9fa364dde475900bc40a82a5bdec19fb..adbd5fae8a11bfc5009ed4b920d28624db71bb0d 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -30,87 +30,78 @@ void Aidge::Reshape_OpImpl::forward() { const std::string Aidge::Reshape_Op::Type = "Reshape"; -bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { +bool Aidge::Reshape_Op::dimsForwarded() const { + if (getInput(1) && !getInput(1)->empty()) { + // output dims are data dependent + return false; + } + + return OperatorTensor::dimsForwarded(); +} + +bool Aidge::Reshape_Op::forwardDims(bool allowDataDependency) { // check input has been associated if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); } - if (!getInput(0)->empty()) { - std::vector<DimSize_t> outDims; - // variables to handle a negative dimension - bool foundNegativeDimension = false; - std::size_t outSize = 1; - DimIdx_t negativeIndex = 0; - - // Fill shape attr if empty - if (this->template getAttr<ReshapeAttr::Shape>().empty()) { - if (!getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type()); - } - if(!getInput(1)->empty()) { - this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs - this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size()); - switch (mInputs[1]->dataType()) { - case DataType::Float64: - std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); - break; - case DataType::Float32: - std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); - break; - case DataType::Int64: - std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); - break; - case DataType::Int32: - std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); - break; - default: - AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported."); - break; - } - } - else { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed"); - } + + if (getInput(0)->empty()) { + return false; + } + + if (getInput(1) && !getInput(1)->empty()) { + if (!this->template getAttr<ReshapeAttr::Shape>().empty()) { + Log::notice("Reshape_Op: ignoring non-empty Shape attribute because input#1 takes precedence"); } - for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) - { - std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; - if (dimSize < 0) { - if (foundNegativeDimension) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator."); - } - foundNegativeDimension = true; - dimSize = 1; - negativeIndex = static_cast<DimIdx_t>(i); - } - else if (dimSize == 0 && !this->template getAttr<ReshapeAttr::AllowZero>()) - { - dimSize = getInput(0) -> dims()[i]; - } - outDims.push_back(static_cast<DimSize_t>(dimSize)); - if (dimSize != 0) { - outSize *= static_cast<DimSize_t>(dimSize); - } + if (!allowDataDependency) { + Log::warn("Reshape_Op: unable to forwardDims() because output dims are data dependent on input#1"); + return false; } - if (foundNegativeDimension) { - outDims[negativeIndex] = (getInput(0) -> size()) / outSize; + std::shared_ptr<Tensor> fallback; + this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs + this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size()); + const auto& shape = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(shape.getImpl()->hostPtr()), + shape.size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); + } + + AIDGE_ASSERT(!this->template getAttr<ReshapeAttr::Shape>().empty(), "Missing input#1 or Shape attribute"); + + std::vector<DimSize_t> outDims; + // variables to handle a negative dimension + bool foundNegativeDimension = false; + std::size_t outSize = 1; + DimIdx_t negativeIndex = 0; + for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) + { + int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; + if (dimSize < 0) { + if (foundNegativeDimension) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator."); + } + foundNegativeDimension = true; + dimSize = 1; + negativeIndex = static_cast<DimIdx_t>(i); + } + else if (dimSize == 0 && !this->template getAttr<ReshapeAttr::AllowZero>()) + { + dimSize = getInput(0) -> dims()[i]; } + outDims.push_back(static_cast<DimSize_t>(dimSize)); + if (dimSize != 0) { + outSize *= static_cast<DimSize_t>(dimSize); + } + } - mOutputs[0]->resize(outDims); - return true; + if (foundNegativeDimension) { + outDims[negativeIndex] = (getInput(0) -> size()) / outSize; } - return false; + mOutputs[0]->resize(outDims); + return true; } void Aidge::Reshape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 7ae8cf872899180fc80c991d49922d2ceca744c9..bc888d419987e5d75c9ceb60e7baf8817bca3d2d 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -28,150 +28,147 @@ const std::string Aidge::Slice_Op::Type = "Slice"; -bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) { +bool Aidge::Slice_Op::dimsForwarded() const { + if ((getInput(1) && !getInput(1)->empty()) + || (getInput(2) && !getInput(2)->empty()) + || (getInput(3) && !getInput(3)->empty())) + { + // output dims are data dependent + return false; + } + + return OperatorTensor::dimsForwarded(); +} + +bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) { // check inputs have been associated if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); } - if(!getInput(0)->empty()) - { - if(this->template getAttr<SliceAttr::Starts>().empty() || this->template getAttr<SliceAttr::Ends>().empty() || this->template getAttr<SliceAttr::Axes>().empty()) - { - if(getInput(1)->empty() || getInput(2)->empty() || getInput(3)->empty()) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Starts, Ends and Axes must be provided either as input or attributes", type()); - } + if (getInput(0)->empty()) { + return false; + } - AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType."); - - this->template getAttr<SliceAttr::Starts>().clear(); - this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size()); - this->template getAttr<SliceAttr::Ends>().clear(); - this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size()); - this->template getAttr<SliceAttr::Axes>().clear(); - this->template getAttr<SliceAttr::Axes>().reserve(getInput(1)->size()); - switch (mInputs[1]->dataType()) { - case DataType::Float64: - std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Starts>())); - std::copy_n(static_cast<double*>(mInputs[2]->getImpl()->rawPtr()), - getInput(2)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Ends>())); - std::copy_n(static_cast<double*>(mInputs[3]->getImpl()->rawPtr()), - getInput(3)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Axes>())); - break; - case DataType::Float32: - std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Starts>())); - std::copy_n(static_cast<float*>(mInputs[2]->getImpl()->rawPtr()), - getInput(2)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Ends>())); - std::copy_n(static_cast<float*>(mInputs[3]->getImpl()->rawPtr()), - getInput(3)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Axes>())); - break; - case DataType::Int64: - std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Starts>())); - std::copy_n(static_cast<std::int64_t*>(mInputs[2]->getImpl()->rawPtr()), - getInput(2)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Ends>())); - std::copy_n(static_cast<std::int64_t*>(mInputs[3]->getImpl()->rawPtr()), - getInput(3)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Axes>())); - break; - case DataType::Int32: - std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Starts>())); - std::copy_n(static_cast<std::int32_t*>(mInputs[2]->getImpl()->rawPtr()), - getInput(2)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Ends>())); - std::copy_n(static_cast<std::int32_t*>(mInputs[3]->getImpl()->rawPtr()), - getInput(3)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Axes>())); - break; - default: - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Input DataType is not supported.", type()); - break; - } + std::shared_ptr<Tensor> fallback; + + if (getInput(1) && !getInput(1)->empty()) { + if (!this->template getAttr<SliceAttr::Starts>().empty()) { + Log::notice("Slice_Op: ignoring non-empty Starts attribute because input#1 takes precedence"); } - // Fill Steps attr if empty - if(this->template getAttr<SliceAttr::Steps>().empty()) { - // In case the input Steps is not provided, default value is 1 - this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(getInput(1)->size(), 1); - - if (getInput(4) && !getInput(4)->empty()) { - this->template getAttr<SliceAttr::Steps>().clear(); - this->template getAttr<SliceAttr::Steps>().reserve(getInput(1)->size()); - switch (mInputs[1]->dataType()) { - case DataType::Float64: - std::copy_n(static_cast<double*>(mInputs[4]->getImpl()->rawPtr()), - getInput(4)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Steps>())); - break; - case DataType::Float32: - std::copy_n(static_cast<float*>(mInputs[4]->getImpl()->rawPtr()), - getInput(4)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Steps>())); - break; - case DataType::Int64: - std::copy_n(static_cast<std::int64_t*>(mInputs[4]->getImpl()->rawPtr()), - getInput(4)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Steps>())); - break; - case DataType::Int32: - std::copy_n(static_cast<std::int32_t*>(mInputs[4]->getImpl()->rawPtr()), - getInput(4)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Steps>())); - break; - default: - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type()); - break; - } - } + if (!allowDataDependency) { + Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#1"); + return false; } - DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); - std::vector<DimSize_t> outDims = getInput(0)->dims(); - for (std::size_t i = 0; i < nbAxes; ++i) { - DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ? - static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) : - static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims())); - DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ? - static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) : - static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); - DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? - static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : - static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); - std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i]; - if(step == 0) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type()); - } - if(step * (static_cast<int64_t>(end) - static_cast<int64_t>(start)) < 0) { - if(step < 0) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type()); - } - else { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type()); - } + + this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size()); + const auto& starts = getInput(1)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(starts.getImpl()->hostPtr()), + starts.size(), + std::back_inserter(this->template getAttr<SliceAttr::Starts>())); + } + + AIDGE_ASSERT(!this->template getAttr<SliceAttr::Starts>().empty(), "Missing input#1 or Starts attribute"); + + if (getInput(2) && !getInput(2)->empty()) { + if (!this->template getAttr<SliceAttr::Ends>().empty()) { + Log::notice("Slice_Op: ignoring non-empty Ends attribute because input#2 takes precedence"); + } + + if (!allowDataDependency) { + Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#2"); + return false; + } + + this->template getAttr<SliceAttr::Ends>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Ends>().reserve(getInput(2)->size()); + const auto& ends = getInput(2)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(ends.getImpl()->hostPtr()), + ends.size(), + std::back_inserter(this->template getAttr<SliceAttr::Ends>())); + } + + AIDGE_ASSERT(!this->template getAttr<SliceAttr::Ends>().empty(), "Missing input#2 or Ends attribute"); + + if (getInput(3) && !getInput(3)->empty()) { + if (!this->template getAttr<SliceAttr::Axes>().empty()) { + Log::notice("Slice_Op: ignoring non-empty Axes attribute because input#3 takes precedence"); + } + + if (!allowDataDependency) { + Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#3"); + return false; + } + + this->template getAttr<SliceAttr::Axes>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Axes>().reserve(getInput(3)->size()); + const auto& axes = getInput(3)->refCastFrom(fallback, NativeType<int8_t>::type, "cpu"); + std::copy_n(static_cast<int8_t*>(axes.getImpl()->hostPtr()), + axes.size(), + std::back_inserter(this->template getAttr<SliceAttr::Axes>())); + } + + AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute"); + + if (getInput(4) && !getInput(4)->empty()) { + if (!this->template getAttr<SliceAttr::Steps>().empty()) { + Log::notice("Slice_Op: ignoring non-empty Steps attribute because input#4 takes precedence"); + } + + if (!allowDataDependency) { + Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#4"); + return false; + } + + this->template getAttr<SliceAttr::Steps>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Steps>().reserve(getInput(4)->size()); + const auto& steps = getInput(4)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(steps.getImpl()->hostPtr()), + steps.size(), + std::back_inserter(this->template getAttr<SliceAttr::Steps>())); + } + // Fill Steps attr if empty + if(this->template getAttr<SliceAttr::Steps>().empty()) { + // In case the input Steps is not provided, default value is 1 + this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(this->template getAttr<SliceAttr::Axes>().size(), 1); + } + + const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); + std::vector<DimSize_t> outDims = getInput(0)->dims(); + for (std::size_t i = 0; i < nbAxes; ++i) { + const DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ? + static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) : + static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims())); + const DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ? + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) : + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); + const DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); + const std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i]; + + AIDGE_ASSERT(step != 0, "Slice_Op: Step must be a non-zero value!"); + if(step * (static_cast<int64_t>(end) - static_cast<int64_t>(start)) < 0) { + if(step < 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type()); } - const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step))); - // Check if slice length is valid - if (sliceLength > getInput(0)->dims()[axis]) - { - AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type()); } - outDims[axis] = sliceLength; } - mOutputs[0]->resize(outDims); - return true; + + const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step))); + // Check if slice length is valid + if (sliceLength > getInput(0)->dims()[axis]) + { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice_Op: ROI of Slice operator out of bounds"); + } + outDims[axis] = sliceLength; } - return false; + mOutputs[0]->resize(outDims); + return true; } void Aidge::Slice_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {