From 93942ddc1bdc2cf144811972e094bb4a940fddcf Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 21 May 2024 11:45:48 +0200 Subject: [PATCH] Replaced swich case with refCastFrom() --- src/operator/Gather.cpp | 30 +++-------------- src/operator/Reshape.cpp | 30 +++-------------- src/operator/Slice.cpp | 72 +++++++++++----------------------------- 3 files changed, 29 insertions(+), 103 deletions(-) diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 4e5bd2573..adb250154 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -64,33 +64,13 @@ bool Aidge::Gather_Op::forwardDims(bool /*allowDataDependency*/) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Either indices input or attribute must be provided", type()); } this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims(); + std::shared_ptr<Tensor> fallback; this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size()); - switch (mInputs[1]->dataType()) { - case DataType::Float64: - std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<GatherAttr::Indices>())); - break; - case DataType::Float32: - std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<GatherAttr::Indices>())); - break; - case DataType::Int64: - std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<GatherAttr::Indices>())); - break; - case DataType::Int32: - std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<GatherAttr::Indices>())); - break; - default: - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type()); - break; - } + const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(indices.getImpl()->rawPtr()), + indices.size(), + std::back_inserter(this->template getAttr<GatherAttr::Indices>())); } std::vector<DimSize_t> outDims = getInput(0)->dims(); diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index 0cce7a5b9..084f621a6 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -48,33 +48,13 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type()); } if(!getInput(1)->empty()) { + std::shared_ptr<Tensor> fallback; this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size()); - switch (mInputs[1]->dataType()) { - case DataType::Float64: - std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); - break; - case DataType::Float32: - std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); - break; - case DataType::Int64: - std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); - break; - case DataType::Int32: - std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); - break; - default: - AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported."); - break; - } + const auto& shape = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(shape.getImpl()->rawPtr()), + shape.size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); } else { AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed"); diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 76cf64119..e0de68c54 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -127,61 +127,27 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) { AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType."); + std::shared_ptr<Tensor> fallback; this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size()); - this->template getAttr<SliceAttr::Ends>().clear(); - this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size()); - this->template getAttr<SliceAttr::Axes>().clear(); - this->template getAttr<SliceAttr::Axes>().reserve(getInput(1)->size()); - switch (mInputs[1]->dataType()) { - case DataType::Float64: - std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Starts>())); - std::copy_n(static_cast<double*>(mInputs[2]->getImpl()->rawPtr()), - getInput(2)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Ends>())); - std::copy_n(static_cast<double*>(mInputs[3]->getImpl()->rawPtr()), - getInput(3)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Axes>())); - break; - case DataType::Float32: - std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Starts>())); - std::copy_n(static_cast<float*>(mInputs[2]->getImpl()->rawPtr()), - getInput(2)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Ends>())); - std::copy_n(static_cast<float*>(mInputs[3]->getImpl()->rawPtr()), - getInput(3)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Axes>())); - break; - case DataType::Int64: - std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Starts>())); - std::copy_n(static_cast<std::int64_t*>(mInputs[2]->getImpl()->rawPtr()), - getInput(2)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Ends>())); - std::copy_n(static_cast<std::int64_t*>(mInputs[3]->getImpl()->rawPtr()), - getInput(3)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Axes>())); - break; - case DataType::Int32: - std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), - getInput(1)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Starts>())); - std::copy_n(static_cast<std::int32_t*>(mInputs[2]->getImpl()->rawPtr()), - getInput(2)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Ends>())); - std::copy_n(static_cast<std::int32_t*>(mInputs[3]->getImpl()->rawPtr()), - getInput(3)->size(), - std::back_inserter(this->template getAttr<SliceAttr::Axes>())); - break; - default: - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type()); - break; - } + const auto& starts = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(starts.getImpl()->rawPtr()), + starts.size(), + std::back_inserter(this->template getAttr<SliceAttr::Starts>())); + + this->template getAttr<SliceAttr::Ends>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Ends>().reserve(getInput(2)->size()); + const auto& ends = mInputs[2]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(ends.getImpl()->rawPtr()), + ends.size(), + std::back_inserter(this->template getAttr<SliceAttr::Ends>())); + + this->template getAttr<SliceAttr::Axes>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Axes>().reserve(getInput(3)->size()); + const auto& axes = mInputs[3]->refCastFrom(fallback, NativeType<int8_t>::type, "cpu"); + std::copy_n(static_cast<int8_t*>(axes.getImpl()->rawPtr()), + axes.size(), + std::back_inserter(this->template getAttr<SliceAttr::Axes>())); } DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); -- GitLab