Skip to content
Snippets Groups Projects
Commit 3532ead8 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

Merge branch 'better_inputs_to_attr' of...

Merge branch 'better_inputs_to_attr' of gitlab.eclipse.org:eclipse/aidge/aidge_core into fix/add_missing_attr
parents fb95aeb4 697c8533
No related branches found
No related tags found
No related merge requests found
......@@ -125,9 +125,9 @@ static Registrar<Tensor> registrarTensorImpl_cpu_Float32(
static Registrar<Tensor> registrarTensorImpl_cpu_Float16(
{"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int64(
{"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<long>::create);
{"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int32(
{"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int>::create);
{"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int16(
{"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_UInt16(
......
......@@ -79,6 +79,7 @@ public:
return std::make_shared<Gather_Op>(*this);
}
bool dimsForwarded() const override final;
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
......
......@@ -76,6 +76,7 @@ public:
return std::make_shared<Reshape_Op>(*this);
}
bool dimsForwarded() const override final;
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
......
......@@ -72,6 +72,7 @@ public:
*/
std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); }
bool dimsForwarded() const override final;
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string &name, DeviceIdx_t device = 0) override;
......
......@@ -51,64 +51,61 @@ void Aidge::Gather_OpImpl::forward() {
const std::string Aidge::Gather_Op::Type = "Gather";
bool Aidge::Gather_Op::forwardDims(bool /*allowDataDependency*/) {
bool Aidge::Gather_Op::dimsForwarded() const {
if (getInput(1) && !getInput(1)->empty()) {
// output dims are data dependent
return false;
}
return OperatorTensor::dimsForwarded();
}
bool Aidge::Gather_Op::forwardDims(bool allowDataDependency) {
// check data input has been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
}
if (!getInput(0)->empty()) {
if (this->template getAttr<GatherAttr::Indices>().empty())
{
if(getInput(1)->empty()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Either indices input or attribute must be provided", type());
}
this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims();
this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs
this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
break;
}
if (getInput(0)->empty()) {
return false;
}
if (getInput(1) && !getInput(1)->empty()) {
if (!this->template getAttr<GatherAttr::Indices>().empty()) {
Log::notice("Gather_Op: ignoring non-empty Indices attribute because input#1 takes precedence");
}
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
this->template getAttr<GatherAttr::Axis>():
this->template getAttr<GatherAttr::Axis>()+outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
if( !this->template getAttr<GatherAttr::GatheredShape>().empty())
{
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx),
this->template getAttr<GatherAttr::GatheredShape>().begin(),
this->template getAttr<GatherAttr::GatheredShape>().end());
if (!allowDataDependency) {
Log::warn("Gather_Op: unable to forwardDims() because output dims are data dependent on input#1");
return false;
}
mOutputs[0]->resize(outDims);
return true;
std::shared_ptr<Tensor> fallback;
this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims();
this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs
this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size());
const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(indices.getImpl()->hostPtr()),
indices.size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
}
return false;
AIDGE_ASSERT(!this->template getAttr<GatherAttr::Indices>().empty(), "Missing input#1 or Indices attribute");
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
this->template getAttr<GatherAttr::Axis>():
this->template getAttr<GatherAttr::Axis>()+outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
if( !this->template getAttr<GatherAttr::GatheredShape>().empty())
{
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx),
this->template getAttr<GatherAttr::GatheredShape>().begin(),
this->template getAttr<GatherAttr::GatheredShape>().end());
}
mOutputs[0]->resize(outDims);
return true;
}
void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
......
......@@ -182,7 +182,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
void Aidge::OperatorTensor::forward() {
if (!dimsForwarded()) {
forwardDims();
// Allow data dependent forwardDims at this point (data is available)
forwardDims(true);
}
Operator::forward();
......
......@@ -30,87 +30,78 @@ void Aidge::Reshape_OpImpl::forward() {
const std::string Aidge::Reshape_Op::Type = "Reshape";
bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
bool Aidge::Reshape_Op::dimsForwarded() const {
if (getInput(1) && !getInput(1)->empty()) {
// output dims are data dependent
return false;
}
return OperatorTensor::dimsForwarded();
}
bool Aidge::Reshape_Op::forwardDims(bool allowDataDependency) {
// check input has been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
}
if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims;
// variables to handle a negative dimension
bool foundNegativeDimension = false;
std::size_t outSize = 1;
DimIdx_t negativeIndex = 0;
// Fill shape attr if empty
if (this->template getAttr<ReshapeAttr::Shape>().empty()) {
if (!getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type());
}
if(!getInput(1)->empty()) {
this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported.");
break;
}
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed");
}
if (getInput(0)->empty()) {
return false;
}
if (getInput(1) && !getInput(1)->empty()) {
if (!this->template getAttr<ReshapeAttr::Shape>().empty()) {
Log::notice("Reshape_Op: ignoring non-empty Shape attribute because input#1 takes precedence");
}
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 0) {
if (foundNegativeDimension) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
}
foundNegativeDimension = true;
dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i);
}
else if (dimSize == 0 && !this->template getAttr<ReshapeAttr::AllowZero>())
{
dimSize = getInput(0) -> dims()[i];
}
outDims.push_back(static_cast<DimSize_t>(dimSize));
if (dimSize != 0) {
outSize *= static_cast<DimSize_t>(dimSize);
}
if (!allowDataDependency) {
Log::warn("Reshape_Op: unable to forwardDims() because output dims are data dependent on input#1");
return false;
}
if (foundNegativeDimension) {
outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
std::shared_ptr<Tensor> fallback;
this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
const auto& shape = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(shape.getImpl()->hostPtr()),
shape.size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
}
AIDGE_ASSERT(!this->template getAttr<ReshapeAttr::Shape>().empty(), "Missing input#1 or Shape attribute");
std::vector<DimSize_t> outDims;
// variables to handle a negative dimension
bool foundNegativeDimension = false;
std::size_t outSize = 1;
DimIdx_t negativeIndex = 0;
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{
int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 0) {
if (foundNegativeDimension) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
}
foundNegativeDimension = true;
dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i);
}
else if (dimSize == 0 && !this->template getAttr<ReshapeAttr::AllowZero>())
{
dimSize = getInput(0) -> dims()[i];
}
outDims.push_back(static_cast<DimSize_t>(dimSize));
if (dimSize != 0) {
outSize *= static_cast<DimSize_t>(dimSize);
}
}
mOutputs[0]->resize(outDims);
return true;
if (foundNegativeDimension) {
outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
}
return false;
mOutputs[0]->resize(outDims);
return true;
}
void Aidge::Reshape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
......
......@@ -28,150 +28,147 @@
const std::string Aidge::Slice_Op::Type = "Slice";
bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
bool Aidge::Slice_Op::dimsForwarded() const {
if ((getInput(1) && !getInput(1)->empty())
|| (getInput(2) && !getInput(2)->empty())
|| (getInput(3) && !getInput(3)->empty()))
{
// output dims are data dependent
return false;
}
return OperatorTensor::dimsForwarded();
}
bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) {
// check inputs have been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
}
if(!getInput(0)->empty())
{
if(this->template getAttr<SliceAttr::Starts>().empty() || this->template getAttr<SliceAttr::Ends>().empty() || this->template getAttr<SliceAttr::Axes>().empty())
{
if(getInput(1)->empty() || getInput(2)->empty() || getInput(3)->empty()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Starts, Ends and Axes must be provided either as input or attributes", type());
}
if (getInput(0)->empty()) {
return false;
}
AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");
this->template getAttr<SliceAttr::Starts>().clear();
this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
this->template getAttr<SliceAttr::Ends>().clear();
this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size());
this->template getAttr<SliceAttr::Axes>().clear();
this->template getAttr<SliceAttr::Axes>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
std::copy_n(static_cast<double*>(mInputs[2]->getImpl()->rawPtr()),
getInput(2)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
std::copy_n(static_cast<double*>(mInputs[3]->getImpl()->rawPtr()),
getInput(3)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
std::copy_n(static_cast<float*>(mInputs[2]->getImpl()->rawPtr()),
getInput(2)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
std::copy_n(static_cast<float*>(mInputs[3]->getImpl()->rawPtr()),
getInput(3)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
std::copy_n(static_cast<std::int64_t*>(mInputs[2]->getImpl()->rawPtr()),
getInput(2)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
std::copy_n(static_cast<std::int64_t*>(mInputs[3]->getImpl()->rawPtr()),
getInput(3)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
std::copy_n(static_cast<std::int32_t*>(mInputs[2]->getImpl()->rawPtr()),
getInput(2)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
std::copy_n(static_cast<std::int32_t*>(mInputs[3]->getImpl()->rawPtr()),
getInput(3)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Input DataType is not supported.", type());
break;
}
std::shared_ptr<Tensor> fallback;
if (getInput(1) && !getInput(1)->empty()) {
if (!this->template getAttr<SliceAttr::Starts>().empty()) {
Log::notice("Slice_Op: ignoring non-empty Starts attribute because input#1 takes precedence");
}
// Fill Steps attr if empty
if(this->template getAttr<SliceAttr::Steps>().empty()) {
// In case the input Steps is not provided, default value is 1
this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(getInput(1)->size(), 1);
if (getInput(4) && !getInput(4)->empty()) {
this->template getAttr<SliceAttr::Steps>().clear();
this->template getAttr<SliceAttr::Steps>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[4]->getImpl()->rawPtr()),
getInput(4)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[4]->getImpl()->rawPtr()),
getInput(4)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[4]->getImpl()->rawPtr()),
getInput(4)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[4]->getImpl()->rawPtr()),
getInput(4)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
break;
}
}
if (!allowDataDependency) {
Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#1");
return false;
}
DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) :
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i];
if(step == 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type());
}
if(step * (static_cast<int64_t>(end) - static_cast<int64_t>(start)) < 0) {
if(step < 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type());
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type());
}
this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs
this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
const auto& starts = getInput(1)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(starts.getImpl()->hostPtr()),
starts.size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
}
AIDGE_ASSERT(!this->template getAttr<SliceAttr::Starts>().empty(), "Missing input#1 or Starts attribute");
if (getInput(2) && !getInput(2)->empty()) {
if (!this->template getAttr<SliceAttr::Ends>().empty()) {
Log::notice("Slice_Op: ignoring non-empty Ends attribute because input#2 takes precedence");
}
if (!allowDataDependency) {
Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#2");
return false;
}
this->template getAttr<SliceAttr::Ends>().clear(); // If both are provided input would override attrs
this->template getAttr<SliceAttr::Ends>().reserve(getInput(2)->size());
const auto& ends = getInput(2)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(ends.getImpl()->hostPtr()),
ends.size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
}
AIDGE_ASSERT(!this->template getAttr<SliceAttr::Ends>().empty(), "Missing input#2 or Ends attribute");
if (getInput(3) && !getInput(3)->empty()) {
if (!this->template getAttr<SliceAttr::Axes>().empty()) {
Log::notice("Slice_Op: ignoring non-empty Axes attribute because input#3 takes precedence");
}
if (!allowDataDependency) {
Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#3");
return false;
}
this->template getAttr<SliceAttr::Axes>().clear(); // If both are provided input would override attrs
this->template getAttr<SliceAttr::Axes>().reserve(getInput(3)->size());
const auto& axes = getInput(3)->refCastFrom(fallback, NativeType<int8_t>::type, "cpu");
std::copy_n(static_cast<int8_t*>(axes.getImpl()->hostPtr()),
axes.size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
}
AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute");
if (getInput(4) && !getInput(4)->empty()) {
if (!this->template getAttr<SliceAttr::Steps>().empty()) {
Log::notice("Slice_Op: ignoring non-empty Steps attribute because input#4 takes precedence");
}
if (!allowDataDependency) {
Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#4");
return false;
}
this->template getAttr<SliceAttr::Steps>().clear(); // If both are provided input would override attrs
this->template getAttr<SliceAttr::Steps>().reserve(getInput(4)->size());
const auto& steps = getInput(4)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(steps.getImpl()->hostPtr()),
steps.size(),
std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
}
// Fill Steps attr if empty
if(this->template getAttr<SliceAttr::Steps>().empty()) {
// In case the input Steps is not provided, default value is 1
this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(this->template getAttr<SliceAttr::Axes>().size(), 1);
}
const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
const DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) :
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
const DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
const DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
const std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i];
AIDGE_ASSERT(step != 0, "Slice_Op: Step must be a non-zero value!");
if(step * (static_cast<int64_t>(end) - static_cast<int64_t>(start)) < 0) {
if(step < 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type());
}
const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step)));
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type());
}
outDims[axis] = sliceLength;
}
mOutputs[0]->resize(outDims);
return true;
const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step)));
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice_Op: ROI of Slice operator out of bounds");
}
outDims[axis] = sliceLength;
}
return false;
mOutputs[0]->resize(outDims);
return true;
}
void Aidge::Slice_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment