From 5c480cffebd20cd497476cef233c5f1eefef762b Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 26 Feb 2025 14:48:17 +0000
Subject: [PATCH] [upd] ConstantOfShape kernel to use Tensor as inputs and
 avoid redundant size computation

---
 .../cpu/operator/ConstantOfShapeImpl.hpp        |  8 +++-----
 .../operator/ConstantOfShapeImpl_kernels.hpp    | 17 ++++-------------
 src/operator/ConstantOfShapeImpl.cpp            |  9 +++------
 3 files changed, 10 insertions(+), 24 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/ConstantOfShapeImpl.hpp b/include/aidge/backend/cpu/operator/ConstantOfShapeImpl.hpp
index 83e7e030..b595ec93 100644
--- a/include/aidge/backend/cpu/operator/ConstantOfShapeImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ConstantOfShapeImpl.hpp
@@ -12,23 +12,21 @@
 #ifndef AIDGE_CPU_OPERATOR_CONSTANTOFSHAPEIMPL_H_
 #define AIDGE_CPU_OPERATOR_CONSTANTOFSHAPEIMPL_H_
 
-#include <cstddef>
 #include <memory>
-#include <vector>
 
 #include "aidge/backend/cpu/operator/OperatorImpl.hpp"
 #include "aidge/operator/ConstantOfShape.hpp"
 #include "aidge/utils/Registrar.hpp"
-#include "aidge/utils/Types.h"
 
 namespace Aidge {
+
+class Tensor;
 // Operator implementation entry point for the backend
 using ConstantOfShapeImpl_cpu = OperatorImpl_cpu<ConstantOfShape_Op,
-    void(const std::vector<DimSize_t>, const Tensor&, void *)>;
+    void(const std::shared_ptr<Tensor>&, const Tensor&)>;
 
 // Implementation entry point registration to Operator
 REGISTRAR(ConstantOfShape_Op, "cpu", Aidge::ConstantOfShapeImpl_cpu::create);
 } // namespace Aidge
 
 #endif /* _AIDGE_CPU_OPERATOR_CONSTANTOFSHAPEIMPL_H_ */
-
diff --git a/include/aidge/backend/cpu/operator/ConstantOfShapeImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConstantOfShapeImpl_kernels.hpp
index 18ab9c0a..c42cc76a 100644
--- a/include/aidge/backend/cpu/operator/ConstantOfShapeImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConstantOfShapeImpl_kernels.hpp
@@ -30,20 +30,11 @@
 namespace Aidge {
 template <class O>
 void ConstantOfShapeimpl_cpu_forward_kernel(
-    const std::vector<DimSize_t> output_dims, const Tensor &value,
-    void *output_) {
+    const std::shared_ptr<Tensor>& output_, const Tensor &value) {
 
-  O *output = static_cast<O *>(output_);
-  O val;
-  std::copy(static_cast<O *>(value.getImpl()->hostPtr()),
-            static_cast<O *>(value.getImpl()->hostPtr()) +
-                static_cast<NbElts_t>(1),
-            &val);
-  const size_t output_size = std::accumulate(
-      output_dims.begin(), output_dims.end(), 1, std::multiplies<DimSize_t>());
-  for (size_t i = 0; i < output_size; ++i) {
-    output[i] = val;
-  }
+  O* output = static_cast<O*>(output_->getImpl()->hostPtr());
+  const O val = *reinterpret_cast<O*>(value.getImpl()->hostPtr());
+  std::fill_n(output, output_->size(), val);
 }
 
 // Kernels registration to implementation entry point
diff --git a/src/operator/ConstantOfShapeImpl.cpp b/src/operator/ConstantOfShapeImpl.cpp
index 16e4b762..1d41160b 100644
--- a/src/operator/ConstantOfShapeImpl.cpp
+++ b/src/operator/ConstantOfShapeImpl.cpp
@@ -13,15 +13,14 @@
 
 #include <functional>
 #include <memory>
-#include <vector>
+#include <stdexcept>   // std::runtime_error
 
 #include "aidge/backend/cpu/operator/ConstantOfShapeImpl_kernels.hpp"
-#include "aidge/data/Data.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/operator/ConstantOfShape.hpp"
+#include "aidge/backend/OperatorImpl.hpp"  // Aidge::getBestMatch, Aidge::getRequiredSpec
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Registrar.hpp"
-#include "aidge/utils/Types.h"
 
 template <>
 void Aidge::ConstantOfShapeImpl_cpu::forward() {
@@ -33,9 +32,7 @@ void Aidge::ConstantOfShapeImpl_cpu::forward() {
     const auto impl = Registrar<ConstantOfShapeImpl_cpu>::create(getBestMatch(getRequiredSpec()));
 
     // Call kernel
-    impl.forward(op_.getOutput(0)->dims(),
-             op_.value(), 
-             op_.getOutput(0)->getImpl()->rawPtr());
+    impl.forward(op_.getOutput(0), op_.value());
 }
 
 template <>
-- 
GitLab