From c85ae1703e5c1d1d78f19d1ebaaede1db6ba1bf4 Mon Sep 17 00:00:00 2001
From: Jerome Hue <jerome.hue@cea.fr>
Date: Thu, 14 Nov 2024 12:19:29 +0100
Subject: [PATCH] feat: Stack operator

---
 include/aidge/operator/Stack.hpp       | 88 ++++++++++++++++++++++++++
 src/operator/Stack.cpp                 | 86 +++++++++++++++++++++++++
 unit_tests/operator/Test_StackImpl.cpp | 47 ++++++++++++++
 3 files changed, 221 insertions(+)
 create mode 100644 include/aidge/operator/Stack.hpp
 create mode 100644 src/operator/Stack.cpp
 create mode 100644 unit_tests/operator/Test_StackImpl.cpp

diff --git a/include/aidge/operator/Stack.hpp b/include/aidge/operator/Stack.hpp
new file mode 100644
index 000000000..24ac075cf
--- /dev/null
+++ b/include/aidge/operator/Stack.hpp
@@ -0,0 +1,88 @@
+#ifndef AIDGE_CORE_OPERATOR_STACK_H_
+#define AIDGE_CORE_OPERATOR_STACK_H_
+
+#include <memory>
+#include <string>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/StaticAttributes.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+
+class StackOpImpl : public OperatorImpl {
+  public:
+    StackOpImpl(const Operator &op, const std::string &backend = "")
+        : OperatorImpl(op, backend) {}
+    void forward() override;
+};
+
+enum class StackAttr { MaxElements };
+
+class StackOp
+    : public OperatorTensor,
+      public Registrable<
+          StackOp,
+          std::string,
+          std::function<std::unique_ptr<OperatorImpl>(const StackOp &)>> {
+
+  private:
+    using Attributes_ = StaticAttributes<StackAttr, std::uint32_t>;
+    template <StackAttr e> using attr = typename Attributes_::template attr<e>;
+    const std::shared_ptr<Attributes_> mAttributes;
+
+  public:
+    static const std::string s_type;
+
+    StackOp(std::uint32_t maxElements);
+
+    /**
+     * @brief Copy-constructor. Copy the operator attributes and its output
+     * tensor(s), but not its input tensors (the new operator has no input
+     * associated).
+     * @param op Operator to copy.
+     */
+    StackOp(const StackOp &op);
+
+    /**
+     * @brief Clone the operator using its copy-constructor.
+     * @see Operator::StackOp
+     */
+    std::shared_ptr<Operator> clone() const override;
+
+    void setBackend(const std::string &name,
+                    DeviceIdx_t device = 0) override final;
+
+    std::set<std::string> getAvailableBackends() const override;
+
+    bool forwardDims(bool allowDataDependency = false) override final;
+    void forward() override;
+
+    inline std::shared_ptr<Attributes> attributes() const override {
+        return mAttributes;
+    }
+    inline std::uint32_t &maxElements() const {
+        return mAttributes->template getAttr<StackAttr::MaxElements>();
+    }
+
+    static const std::vector<std::string> getInputsName() {
+        return {"data_input"};
+    }
+    static const std::vector<std::string> getOutputsName() {
+        return {"data_output"};
+    }
+};
+
+std::shared_ptr<Node> stack(std::uint32_t maxElements,
+                            const std::string &name = "");
+} // namespace Aidge
+
+namespace {
+template <>
+const char *const EnumStrings<Aidge::StackAttr>::data[] = {"max_elements"};
+}
+
+#endif
diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp
new file mode 100644
index 000000000..9bdfadd08
--- /dev/null
+++ b/src/operator/Stack.cpp
@@ -0,0 +1,86 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/operator/Stack.hpp"
+
+#include <memory>
+#include <string>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/StaticAttributes.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+
+const std::string StackOp::s_type = "Stack";
+
+void StackOpImpl::forward() {
+    const StackOp &op = dynamic_cast<const StackOp &>(mOp);
+    assert(op.getInput(0) && "missing input #0");
+    //*op.getOutput(0) = op.getInput(0)->extract({op.forwardStep()});
+}
+
+StackOp::StackOp(std::uint32_t maxElements)
+    : OperatorTensor(s_type, {InputCategory::Data}, 1),
+      mAttributes(std::make_shared<Attributes_>(
+          attr<StackAttr::MaxElements>(maxElements))) {
+    mImpl = std::make_shared<StackOpImpl>(*this);
+}
+
+StackOp::StackOp(const Aidge::StackOp &op)
+    : OperatorTensor(op), mAttributes(op.mAttributes) {
+    if (!op.backend().empty()) {
+        SET_IMPL_MACRO(StackOp, *this, op.backend());
+    } else {
+        mImpl = std::make_shared<StackOpImpl>(*this);
+    }
+}
+
+std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const {
+    return std::make_shared<StackOp>(*this);
+}
+
+bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) {
+    if (inputsAssociated()) {
+        auto inputDims = getInput(0)->dims();
+        inputDims.insert(inputDims.begin(), maxElements());
+        getOutput(0)->resize(inputDims);
+        return true;
+    }
+
+    return false;
+}
+
+void StackOp::setBackend(const std::string &name, DeviceIdx_t device) {
+    if (Registrar<StackOp>::exists({name})) {
+        SET_IMPL_MACRO(StackOp, *this, name);
+    } else {
+        mImpl = std::make_shared<StackOpImpl>(*this);
+    }
+    mOutputs[0]->setBackend(name, device);
+}
+
+std::set<std::string> StackOp::getAvailableBackends() const {
+    return Registrar<StackOp>::getKeys();
+}
+
+void StackOp::forward() {
+    Operator::forward();
+}
+
+std::shared_ptr<Node> stack(std::uint32_t maxElements,
+                            const std::string &name) {
+    return std::make_shared<Node>(std::make_shared<StackOp>(maxElements),
+                                  name);
+}
+} // namespace Aidge
diff --git a/unit_tests/operator/Test_StackImpl.cpp b/unit_tests/operator/Test_StackImpl.cpp
new file mode 100644
index 000000000..7652c03a7
--- /dev/null
+++ b/unit_tests/operator/Test_StackImpl.cpp
@@ -0,0 +1,47 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/generators/catch_generators_random.hpp>
+#include <catch2/matchers/catch_matchers_string.hpp>
+#include <cstddef>
+#include <memory>
+#include <random> 
+#include <vector>
+
+#include "aidge/operator/Stack.hpp"
+#include "aidge/data/Tensor.hpp"
+
+using Catch::Matchers::Equals;
+
+namespace Aidge {
+TEST_CASE("[core/operator] Stack(forwardDims)", "[Stack][forwardDims]") {
+
+    auto rd = Catch::Generators::Detail::getSeed;
+    std::mt19937 gen(rd());
+    std::uniform_int_distribution<std::size_t> dimsDist(1, 10);
+
+    auto maxElementsToStack = dimsDist(gen);
+
+    auto stackNode = stack(maxElementsToStack);
+    auto op = std::dynamic_pointer_cast<StackOp>(stackNode->getOperator());
+    std::shared_ptr<Tensor> t0 = std::make_shared<Tensor>(Aidge::Array1D<int,3>{{4,5,6}});
+
+    // input #0 should be associated with a Tensor
+    REQUIRE_THROWS_WITH(op->forwardDims(), Equals("Stack: input #0 should be associated with a Tensor"));
+
+    op->associateInput(0, t0);
+    REQUIRE_NOTHROW(op->forwardDims());
+
+    auto dims = op->getOutput(0)->dims();
+    REQUIRE(dims[0] == maxElementsToStack);
+}
+}  // namespace Aidge
-- 
GitLab