From 21d62a07ac2c17bd3872b9a7e9092f3e00278e40 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Fri, 23 Feb 2024 10:58:18 +0000
Subject: [PATCH] [Add] Prototype of a faster kernel for arithmetic operators in
 DivImpl
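
Rework the Div forward pass so it no longer computes a multi-dimensional
index for every output element. Input shapes are padded to a common rank,
the largest contiguous trailing block shared by both inputs is identified,
and per-axis offset steps are precomputed for broadcasting; a simple
element-wise kernel is then called once per contiguous block.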

---
 .../aidge/backend/cpu/operator/DivImpl.hpp    |  13 +-
 .../cpu/operator/DivImpl_forward_kernels.hpp  |  32 ++--
 src/operator/DivImpl.cpp                      | 148 ++++++++++++---
 3 files changed, 152 insertions(+), 41 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/DivImpl.hpp b/include/aidge/backend/cpu/operator/DivImpl.hpp
index 9d161401..710e288d 100644
--- a/include/aidge/backend/cpu/operator/DivImpl.hpp
+++ b/include/aidge/backend/cpu/operator/DivImpl.hpp
@@ -12,20 +12,20 @@
 #ifndef AIDGE_CPU_OPERATOR_DIVIMPL_H_
 #define AIDGE_CPU_OPERATOR_DIVIMPL_H_
 
+#include <memory>
+#include <tuple>
+#include <vector>
+
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/Div.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
-#include <memory>
-#include <vector>
 
 namespace Aidge {
-// class Div_Op;
 
 // compute kernel registry for forward and backward
 class DivImplForward_cpu
-    : public Registrable<DivImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
+    : public Registrable<DivImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const std::size_t, const void*, const void*,void*)> {
 };
 class DivImplBackward_cpu
     : public Registrable<DivImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
@@ -40,7 +40,8 @@ public:
     }
 
     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
-    void forward() override;
+
+    void forward() override final;
 };
 
 namespace {
diff --git a/include/aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp
index 494fb5ad..3cdcefa9 100644
--- a/include/aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp
@@ -12,16 +12,23 @@
 #ifndef AIDGE_CPU_OPERATOR_DIVIMPL_FORWARD_KERNEL_H_
 #define AIDGE_CPU_OPERATOR_DIVIMPL_FORWARD_KERNEL_H_
 
+#include <cstddef>     // std::size_t
+#include <functional>  // std::multiplies
+#include <numeric>     // std::accumulate
+
 #include "aidge/utils/Registrar.hpp"
 
 #include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/operator/DivImpl.hpp"
 
 namespace Aidge {
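+// Kernel for elementwise division of contiguous data: each input holds
+// either a single value (scalar broadcast) or output1size_ contiguous
+// values; dimension-wise broadcasting is handled by the caller.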
 template <class I1, class I2, class O>
-void DivImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims,
-                                const std::vector<std::size_t>& input2Dims,
-                                const std::vector<std::size_t>& outputDims,
+void DivImpl_cpu_forward_kernel(const std::size_t input1size_,
+                                const std::size_t input2size_,
+                                const std::size_t output1size_,
                                 const void* input1_,
                                 const void* input2_,
                                 void* output_) {
@@ -30,22 +37,15 @@ void DivImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims,
     const I2* input_2 = static_cast<const I2*>(input2_);
     O* output = static_cast<O*>(output_);
 
-    size_t totalElements = 1;
-    for (size_t dimSize : outputDims) {
-        totalElements *= dimSize;
+    // input and output values are assumed to be contiguous in memory
+    for (std::size_t i = 0; i < output1size_; ++i) {
+        const std::size_t in1_id = (input1size_ != 1) ? i : 0;
+        const std::size_t in2_id = (input2size_ != 1) ? i : 0;
+        output[i] = static_cast<O>(input_1[in1_id] / input_2[in2_id]);
     }
+}
 
-	for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex) 
-	{
-		std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex);
-
-		std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
-		std::size_t idx2 = getFlattenedIndex(input2Dims, indexes);
 
-        // TODO assert if input_2 is bad?
-        output[oIndex] = input_1[idx1] / input_2[idx2];
-    }
-}
 
 namespace {
 static Registrar<DivImplForward_cpu> registrarDivImplForward_cpu_Float32(
diff --git a/src/operator/DivImpl.cpp b/src/operator/DivImpl.cpp
index fc6207cf..292a3b56 100644
--- a/src/operator/DivImpl.cpp
+++ b/src/operator/DivImpl.cpp
@@ -9,19 +9,19 @@
  *
  ********************************************************************************/
 
-#include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
+#include <cstddef>     // std::size_t
+#include <cstdint>     // std::int32_t
+#include <functional>  // std::multiplies
+#include <memory>
+#include <numeric>     // std::accumulate
 #include <vector>
 
-#include "aidge/operator/Div.hpp"
-#include "aidge/utils/Types.h"
 #include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
-
 #include "aidge/backend/cpu/operator/DivImpl.hpp"
 #include "aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/utils/Types.h"
 
 Aidge::NbElts_t Aidge::DivImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
@@ -29,22 +29,132 @@ Aidge::NbElts_t Aidge::DivImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_
 }
 
 void Aidge::DivImpl_cpu::forward() {
+    // Example: [5,2,1,7] & [2,6,7]
+    // 1. Pad to the same number of dimensions -> [5,2,1,7] & [1,2,6,7]
+    // 2. Find the highest equal dimension -> 3
+    //    Exception: if the first diverging dimension is the last one, then -> 4 (dims.size())
+    // 3. Compute the highest number of contiguous data -> 7
+    // 4. Compute stride and offset step for the broadcast mechanism
+    // 5. Call a simple kernel for each contiguous block (worked example below)
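+    //
+    //    Worked example for the dims above:
+    //      contiguousIdx = 3, contiguous block size = 7 for both inputs
+    //      stride_post0 = {2,1,1},   stride_post1 = {12,6,1}
+    //      stride_step0 = {1,1,0},   stride_step1 = {-11,1,1}
+    //      -> 5*2*6 = 60 kernel calls, each on 7 contiguous values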
+
     // Find the correct kernel type
     auto kernelFunc = Registrar<DivImplForward_cpu>::create({
         std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
-    const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
-                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
-    const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
-                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims());
-
-    // Call kernel
-    kernelFunc(inputDims0,
-        inputDims1,
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
-        getCPUPtr(mOp.getRawInput(0)),
-        getCPUPtr(mOp.getRawInput(1)),
-        getCPUPtr(mOp.getRawOutput(0)));
+    // Compute compatible input dimensions
+    std::vector<std::size_t>        dims0   = static_cast<const Div_Op&>(mOp).getInput(0)->dims();
+    std::vector<std::size_t>        dims1   = static_cast<const Div_Op&>(mOp).getInput(1)->dims();
+    const std::vector<std::size_t>& outDims = static_cast<const Div_Op&>(mOp).getOutput(0)->dims();
+
+    // if (dims0 == dims1) {
+    //     const std::size_t input0_contiguous_size = std::accumulate(dims0.cbegin(), dims0.cend(), std::size_t(1), std::multiplies<std::size_t>());
+    //     kernelFunc(input0_contiguous_size, input0_contiguous_size, input0_contiguous_size,
+    //                 getCPUPtr(mOp.getRawInput(0)),
+    //                 getCPUPtr(mOp.getRawInput(1)),
+    //                 getCPUPtr(mOp.getRawOutput(0)));
+    //     return;
+    // }
+
+    if (dims0.size() > dims1.size()) {
+        dims1.insert(dims1.cbegin(), dims0.size() - dims1.size(), std::size_t(1));
+    }
+    else if (dims1.size() > dims0.size()) {
+        dims0.insert(dims0.cbegin(), dims1.size() - dims0.size(), std::size_t(1));
+    }
+
+    const std::size_t nbDims = dims0.size();
+
+    // Find the highest equal dimension
+    std::size_t contiguousIdx = nbDims - 1;
+    for (; contiguousIdx+1 > 0; --contiguousIdx) {
+        if (dims0[contiguousIdx] != dims1[contiguousIdx]) {
+            if (contiguousIdx == (nbDims -1)) {
+                if (dims0[contiguousIdx] == 1) {
+                    while ((contiguousIdx+1 > 0) && (dims0[contiguousIdx] == 1)) {
+                        --contiguousIdx;
+                    }
+                }
+                else {
+                    while ((contiguousIdx+1 > 0) && (dims1[contiguousIdx] == 1)) {
+                        --contiguousIdx;
+                    }
+                }
+            }
+            break;
+        }
+    }
+    ++contiguousIdx;
+
+    // Compute the highest number of contiguous data for each Tensor
+    const std::size_t input0_contiguous_size = std::accumulate(dims0.cbegin()+contiguousIdx, dims0.cend(), std::size_t(1), std::multiplies<std::size_t>());
+    const std::size_t input1_contiguous_size = std::accumulate(dims1.cbegin()+contiguousIdx, dims1.cend(), std::size_t(1), std::multiplies<std::size_t>());
+    const std::size_t output_contiguous_size = std::accumulate(outDims.cbegin()+contiguousIdx, outDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+
+    // strides used to iterate through the data because of broadcasting
+    std::size_t *stride_post0 = nullptr;
+    std::size_t *stride_post1 = nullptr;
+    std::int32_t *stride_step0 = nullptr;
+    std::int32_t *stride_step1 = nullptr;
+    if (contiguousIdx > 0) {
+        stride_post0 = new std::size_t[contiguousIdx];
+        stride_post0[contiguousIdx - 1] = 1;
+        stride_post1 = new std::size_t[contiguousIdx];
+        stride_post1[contiguousIdx - 1] = 1;
+        for (std::size_t i = contiguousIdx - 2; i != static_cast<std::size_t>(-1); --i) {
+            stride_post0[i] = stride_post0[i+1]*dims0[i+1];
+            stride_post1[i] = stride_post1[i+1]*dims1[i+1];
+        }
+        stride_step0 = new std::int32_t[contiguousIdx];
+        stride_step1 = new std::int32_t[contiguousIdx];
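+        // step applied to an offset when the given dimension is incremented:
+        // +1 for a real dimension (next contiguous block); 1 - stride_post
+        // for a broadcast dimension (size 1), rewinding to the block start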
+        for (std::size_t i = 0; i != contiguousIdx; ++i) {
+            stride_step0[i] = (dims0[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post0[i]) : 1;
+            stride_step1[i] = (dims1[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post1[i]) : 1;
+        }
+    }
+
+    // offsets into the input/output arrays
+    std::size_t offsetIn0 = 0;
+    std::size_t offsetIn1 = 0;
+    std::size_t offsetOut = 0;
+
+
+    std::size_t dim = contiguousIdx - 1;
+    const std::size_t nbStacks = std::accumulate(outDims.cbegin(), outDims.cbegin() + contiguousIdx, std::size_t(1), std::multiplies<std::size_t>());
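+    // iterate over the outer (non-contiguous) dimensions like an odometer:
+    // after each kernel call, find the dimension being incremented (all inner
+    // dimensions having wrapped) and advance the offsets by its stride step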
+    for (std::size_t stack = 0; stack < nbStacks;) {
+        kernelFunc(input0_contiguous_size, input1_contiguous_size, output_contiguous_size,
+                    getCPUPtr(mOp.getRawInput(0), offsetIn0*input0_contiguous_size),
+                    getCPUPtr(mOp.getRawInput(1), offsetIn1*input1_contiguous_size),
+                    getCPUPtr(mOp.getRawOutput(0), offsetOut*output_contiguous_size));
+        if (++stack < nbStacks) {
+            std::size_t tmp_stack = stack;
+            while (tmp_stack % outDims[dim] == 0) {
+                tmp_stack /= outDims[dim];
+                --dim;
+            }
+            offsetIn0 += stride_step0[dim];
+            offsetIn1 += stride_step1[dim];
+            ++offsetOut;
+            dim = contiguousIdx - 1;
+        }
+    }
+    if (contiguousIdx > 0) {
+        delete[] stride_post0;
+        delete[] stride_post1;
+        delete[] stride_step0;
+        delete[] stride_step1;
+    }
 }
-- 
GitLab