From 93942ddc1bdc2cf144811972e094bb4a940fddcf Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 21 May 2024 11:45:48 +0200
Subject: [PATCH] Replaced swich case with refCastFrom()

---
 src/operator/Gather.cpp  | 30 +++--------------
 src/operator/Reshape.cpp | 30 +++--------------
 src/operator/Slice.cpp   | 72 +++++++++++-----------------------------
 3 files changed, 29 insertions(+), 103 deletions(-)

diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp
index 4e5bd2573..adb250154 100644
--- a/src/operator/Gather.cpp
+++ b/src/operator/Gather.cpp
@@ -64,33 +64,13 @@ bool Aidge::Gather_Op::forwardDims(bool /*allowDataDependency*/) {
                 AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Either indices input or attribute must be provided", type());
             }
             this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims();
+            std::shared_ptr<Tensor> fallback;
             this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs
             this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size());
-            switch (mInputs[1]->dataType()) {
-                case DataType::Float64:
-                    std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
-                                getInput(1)->size(),
-                                std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
-                    break;
-                case DataType::Float32:
-                    std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
-                                getInput(1)->size(),
-                                std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
-                    break;
-                case DataType::Int64:
-                    std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
-                                getInput(1)->size(),
-                                std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
-                    break;
-                case DataType::Int32:
-                    std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
-                                getInput(1)->size(),
-                                std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
-                    break;
-                default:
-                    AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
-                    break;
-            }
+            const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
+            std::copy_n(static_cast<int64_t*>(indices.getImpl()->rawPtr()),
+                        indices.size(),
+                        std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
         }
         std::vector<DimSize_t> outDims = getInput(0)->dims();
 
diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp
index 0cce7a5b9..084f621a6 100644
--- a/src/operator/Reshape.cpp
+++ b/src/operator/Reshape.cpp
@@ -48,33 +48,13 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
                 AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type());
             }
             if(!getInput(1)->empty()) {
+                std::shared_ptr<Tensor> fallback;
                 this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
                 this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
-                switch (mInputs[1]->dataType()) {
-                    case DataType::Float64:
-                        std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
-                                    getInput(1)->size(),
-                                    std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
-                        break;
-                    case DataType::Float32:
-                        std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
-                                    getInput(1)->size(),
-                                    std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
-                        break;
-                    case DataType::Int64:
-                        std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
-                                    getInput(1)->size(),
-                                    std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
-                        break;
-                    case DataType::Int32:
-                        std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
-                                    getInput(1)->size(),
-                                    std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
-                        break;
-                    default:
-                        AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported.");
-                        break;
-                }
+                const auto& shape = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
+                std::copy_n(static_cast<int64_t*>(shape.getImpl()->rawPtr()),
+                            shape.size(),
+                            std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
             }
             else {
                 AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed");
diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp
index 76cf64119..e0de68c54 100644
--- a/src/operator/Slice.cpp
+++ b/src/operator/Slice.cpp
@@ -127,61 +127,27 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
 
             AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");
 
+            std::shared_ptr<Tensor> fallback;
             this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs
             this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
-            this->template getAttr<SliceAttr::Ends>().clear();
-            this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size());
-            this->template getAttr<SliceAttr::Axes>().clear();
-            this->template getAttr<SliceAttr::Axes>().reserve(getInput(1)->size());
-            switch (mInputs[1]->dataType()) {
-                case DataType::Float64:
-                    std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
-                                getInput(1)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
-                    std::copy_n(static_cast<double*>(mInputs[2]->getImpl()->rawPtr()),
-                                getInput(2)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
-                    std::copy_n(static_cast<double*>(mInputs[3]->getImpl()->rawPtr()),
-                                getInput(3)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
-                    break;
-                case DataType::Float32:
-                    std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
-                                getInput(1)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
-                    std::copy_n(static_cast<float*>(mInputs[2]->getImpl()->rawPtr()),
-                                getInput(2)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
-                    std::copy_n(static_cast<float*>(mInputs[3]->getImpl()->rawPtr()),
-                                getInput(3)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
-                    break;
-                case DataType::Int64:
-                    std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
-                                getInput(1)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
-                    std::copy_n(static_cast<std::int64_t*>(mInputs[2]->getImpl()->rawPtr()),
-                                getInput(2)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
-                    std::copy_n(static_cast<std::int64_t*>(mInputs[3]->getImpl()->rawPtr()),
-                                getInput(3)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
-                    break;
-                case DataType::Int32:
-                    std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
-                                getInput(1)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
-                    std::copy_n(static_cast<std::int32_t*>(mInputs[2]->getImpl()->rawPtr()),
-                                getInput(2)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
-                    std::copy_n(static_cast<std::int32_t*>(mInputs[3]->getImpl()->rawPtr()),
-                                getInput(3)->size(),
-                                std::back_inserter(this->template getAttr<SliceAttr::Axes>()));                                
-                    break;
-                default:
-                    AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
-                    break;
-            }
+            const auto& starts = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
+            std::copy_n(static_cast<int64_t*>(starts.getImpl()->rawPtr()),
+                        starts.size(),
+                        std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
+
+            this->template getAttr<SliceAttr::Ends>().clear(); // If both are provided input would override attrs
+            this->template getAttr<SliceAttr::Ends>().reserve(getInput(2)->size());
+            const auto& ends = mInputs[2]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
+            std::copy_n(static_cast<int64_t*>(ends.getImpl()->rawPtr()),
+                        ends.size(),
+                        std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
+
+            this->template getAttr<SliceAttr::Axes>().clear(); // If both are provided input would override attrs
+            this->template getAttr<SliceAttr::Axes>().reserve(getInput(3)->size());
+            const auto& axes = mInputs[3]->refCastFrom(fallback, NativeType<int8_t>::type, "cpu");
+            std::copy_n(static_cast<int8_t*>(axes.getImpl()->rawPtr()),
+                        axes.size(),
+                        std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
         }
 
         DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
-- 
GitLab