From c8d3a3e4a9483ca4eb2d2f412cc4904a6f599456 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 30 Nov 2023 13:50:32 +0000
Subject: [PATCH] [Add] working prototype for horizontal tiling

---
 include/aidge/recipies/Recipies.hpp |  2 +-
 src/recipies/HorizontalTiling.cpp   | 93 +++++++++++++++++++++++++++++
 2 files changed, 94 insertions(+), 1 deletion(-)
 create mode 100644 src/recipies/HorizontalTiling.cpp

diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp
index 26f4cc9da..a17ead8f8 100644
--- a/include/aidge/recipies/Recipies.hpp
+++ b/include/aidge/recipies/Recipies.hpp
@@ -84,7 +84,7 @@ void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);
  */
 void fuseBatchNorm(std::shared_ptr<GraphView> graphView);
 
-// std::set<std::shared_ptr<Node>> getHorizontalTiling(const std::shared_ptr<Node>& node, const DimIdx_t axis, const std::size_t nbSlices);
+std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<Node>& node, const DimIdx_t axis, const std::size_t nbSlices);
 // void horizontalTiling(std::shared_ptr<Node> node, DimIdx_t dim, std::size_t nbSlices);
 // std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
 // void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
diff --git a/src/recipies/HorizontalTiling.cpp b/src/recipies/HorizontalTiling.cpp
new file mode 100644
index 000000000..d8eb01593
--- /dev/null
+++ b/src/recipies/HorizontalTiling.cpp
@@ -0,0 +1,93 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <set>
+#include <memory>
+#include <vector>
+#include <utility>
+
+#include "aidge/recipies/Recipies.hpp"
+
+#include "aidge/graph/Node.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/data/Data.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/operator/Add.hpp"
+#include "aidge/operator/Concat.hpp"
+#include "aidge/operator/Slice.hpp"
+
+// TODO: assert Operator uses Tensors when implemented
+std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std::shared_ptr<Aidge::Node>& node,
+                                                            const Aidge::DimIdx_t axis,
+                                                            const std::size_t nbSlices)
+{
+    if (node->getOperator()->type() != "Conv") {
+        AIDGE_INTERNAL_ASSERT("Operator should be a Convolution.");
+    }
+    const auto& op = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
+    if (op->nbOutputs() != 1 || op->nbData() > 1) {
+        AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now.");
+    }
+    if (!op->outputDimsForwarded()) {
+        AIDGE_INTERNAL_ASSERT("Dimensions must be forwarded before any tiling");
+    }
+    // start by doing a tiling with strict dimensions division
+    const auto& outTensor = op->getOutput(0);
+    if (op->getOutput(0)->dims()[axis] % nbSlices != 0) {
+        AIDGE_INTERNAL_ASSERT("axis should be a multiple of nbSlices");
+    }
+
+    // dimensions of a Slice
+    std::vector<DimSize_t> outputDims = outTensor->dims();
+    outputDims[axis] /= nbSlices;
+
+    std::vector<DimSize_t> currentFirstDims = std::vector<DimSize_t>(outTensor->nbDims(), 0);
+
+    std::set<std::shared_ptr<Aidge::Node>> res;
+    auto concat = Concat(nbSlices, axis);
+    res.insert(concat);
+
+    // check slice sizes
+    // const auto inputDims = op->computeReceptiveField(currentFirstDims[axis], outputDims, 0);
+    // std::vector<bool> shareTensor(node->nbInputs(), false);
+    // for (DimSize_t inputID = 0; inputID < node->nbInputs(); ++inputID) {
+    //     const auto inTensor = std::dynamic_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputID));
+    //     if (inTensor->dims() == inputDims[inputID].second)
+    //         shareTensor[inputID] = true;
+    // }
+
+    std::vector<std::shared_ptr<Node>> clonedInputs = std::vector<std::shared_ptr<Node>>(node->nbInputs(), nullptr);
+    for (std::size_t i = node->nbData(); i < node ->nbInputs(); ++i) {
+        clonedInputs[i] = node -> getParent(i) -> cloneSharedOperators();
+        clonedInputs[i] -> setName(node -> name() + "_0");
+        res.insert(clonedInputs[i]);
+    }
+
+    for (; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis]) {
+        const auto inputDims = op->computeReceptiveField(outTensor->getIdx(currentFirstDims), outputDims, 0);
+        auto newNode = node -> clone(); // no input associated to clones
+        newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis]));
+        clonedInputs[1] -> addChild(newNode, 0, 1);
+        clonedInputs[2] -> addChild(newNode, 0, 2);
+        // Slice for input and each parameter
+        auto slice = Slice(inputDims[0].first, inputDims[0].second, "Slice_" + std::to_string(currentFirstDims[axis]));
+        slice -> addChild(newNode, 0, 0);
+        newNode -> addChild(concat, 0, currentFirstDims[axis]);
+
+        res.insert(slice);
+        res.insert(newNode);
+    }
+
+    return res;
+}
\ No newline at end of file
-- 
GitLab