From a1b0a9894bd3225a7d3b019f13d2cdbc19949a79 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Mon, 5 Feb 2024 14:35:42 +0100
Subject: [PATCH] Working concept

---
 include/aidge/backend/cpu.hpp                 |  1 +
 .../backend/cpu/operator/MemorizeImpl.hpp     | 43 ++++++++++
 .../backend/cpu/operator/ProducerImpl.hpp     |  1 -
 src/operator/MemorizeImpl.cpp                 | 80 +++++++++++++++++++
 src/operator/ProducerImpl.cpp                 | 10 ---
 unit_tests/scheduler/Test_Scheduler.cpp       | 53 +++++++++++-
 6 files changed, 175 insertions(+), 13 deletions(-)
 create mode 100644 include/aidge/backend/cpu/operator/MemorizeImpl.hpp
 create mode 100644 src/operator/MemorizeImpl.cpp

diff --git a/include/aidge/backend/cpu.hpp b/include/aidge/backend/cpu.hpp
index f7859805..a0d232f6 100644
--- a/include/aidge/backend/cpu.hpp
+++ b/include/aidge/backend/cpu.hpp
@@ -24,6 +24,7 @@
 #include "aidge/backend/cpu/operator/FCImpl.hpp"
 #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
 #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
+#include "aidge/backend/cpu/operator/MemorizeImpl.hpp"
 #include "aidge/backend/cpu/operator/MulImpl.hpp"
 #include "aidge/backend/cpu/operator/PadImpl.hpp"
 #include "aidge/backend/cpu/operator/PowImpl.hpp"
diff --git a/include/aidge/backend/cpu/operator/MemorizeImpl.hpp b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp
new file mode 100644
index 00000000..c003e7b5
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp
@@ -0,0 +1,43 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_
+#define AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Memorize.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+#include <memory>
+#include <vector>
+
+namespace Aidge {
+class MemorizeImpl_cpu : public OperatorImpl {
+public:
+    MemorizeImpl_cpu(const Memorize_Op& op) : OperatorImpl(op) {}
+
+    static std::unique_ptr<MemorizeImpl_cpu> create(const Memorize_Op& op) {
+        return std::make_unique<MemorizeImpl_cpu>(op);
+    }
+
+    NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
+    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
+    void updateConsummerProducer() override final;
+    void forward() override;
+};
+
+namespace {
+static Registrar<Memorize_Op> registrarMemorizeImpl_cpu("cpu", Aidge::MemorizeImpl_cpu::create);
+}
+}  // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/ProducerImpl.hpp b/include/aidge/backend/cpu/operator/ProducerImpl.hpp
index c1d27f7e..2e9c90c4 100644
--- a/include/aidge/backend/cpu/operator/ProducerImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ProducerImpl.hpp
@@ -29,7 +29,6 @@ public:
         return std::make_unique<ProducerImpl_cpu>(op);
     }
 
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
     void forward() override;
 };
 
diff --git a/src/operator/MemorizeImpl.cpp b/src/operator/MemorizeImpl.cpp
new file mode 100644
index 00000000..64cb3bcf
--- /dev/null
+++ b/src/operator/MemorizeImpl.cpp
@@ -0,0 +1,80 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <chrono>  // std::chrono::milliseconds
+#include <numeric> // std::accumulate
+#include <thread>  // std::this_thread::sleep_for
+#include <vector>
+
+#include "aidge/operator/Memorize.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+
+#include "aidge/backend/cpu/operator/MemorizeImpl.hpp"
+
+Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbRequiredData(
+    Aidge::IOIndex_t inputIdx) const
+{
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
+
+    if (scheduleStep == 0 && inputIdx == 0) {
+        // No data input is required for the initial step.
+        // However, initialization data is still required.
+        return 0;
+    }
+    else if (scheduleStep > 0 && inputIdx == 1) {
+        // No initialization data is required after the initial step.
+        return 0;
+    }
+    else {
+        return OperatorImpl::getNbRequiredData(inputIdx);
+    }
+}
+
+Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbProducedData(
+    Aidge::IOIndex_t outputIdx) const
+{
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
+    const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
+
+    if (outputIdx == 1 && scheduleStep >= endStep) {
+        return endStep * std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size();
+    }
+    else {
+        return scheduleStep * std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size();
+    }
+}
+
+void Aidge::MemorizeImpl_cpu::updateConsummerProducer() {
+    OperatorImpl::updateConsummerProducer();
+
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
+    const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
+    AIDGE_ASSERT(scheduleStep <= endStep, "cannot update consumer producer anymore, number of cycles exceeded");
+}
+
+void Aidge::MemorizeImpl_cpu::forward() {
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int forwardStep = op.template getAttr<MemorizeAttr::ForwardStep>();
+    const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
+    AIDGE_ASSERT(forwardStep <= endStep, "cannot forward anymore, number of cycles exceeded");
+
+    if (forwardStep == 0) {
+        op.getOutput(0)->getImpl()->copy(op.getInput(1)->getImpl()->rawPtr(), op.getInput(1)->size());
+    }
+    else {
+        op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size());
+    }
+}
diff --git a/src/operator/ProducerImpl.cpp b/src/operator/ProducerImpl.cpp
index 4c5883a9..2decd955 100644
--- a/src/operator/ProducerImpl.cpp
+++ b/src/operator/ProducerImpl.cpp
@@ -20,16 +20,6 @@
 
 #include "aidge/backend/cpu/operator/ProducerImpl.hpp"
 
-Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData(
-    Aidge::IOIndex_t outputIdx) const
-{
-    // Requires the whole tensors, regardless of available data on inputs
-    assert(outputIdx == 0 && "operator has only one output");
-    (void) outputIdx;
-
-    return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size();
-}
-
 void Aidge::ProducerImpl_cpu::forward()
 {
 }
diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp
index 8ea8e726..e1815725 100644
--- a/unit_tests/scheduler/Test_Scheduler.cpp
+++ b/unit_tests/scheduler/Test_Scheduler.cpp
@@ -205,5 +205,54 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") {
     SECTION("Test Residual graph") {
     }
 
-    SECTION("Test Recurrent graph") {}
-}
\ No newline at end of file
+    SECTION("Test Recurrent graph") {
+        std::shared_ptr<Tensor> in = std::make_shared<Tensor>(
+                Array2D<int, 2, 3>{{{1, 2, 3}, {4, 5, 6}}});
+        std::shared_ptr<Tensor> initTensor = std::make_shared<Tensor>(
+                Array2D<int, 2, 3>{{{0, 0, 0}, {1, 1, 1}}});
+        std::shared_ptr<Tensor> biasTensor = std::make_shared<Tensor>(
+                Array2D<int, 2, 3>{{{2, 0, 0}, {1, 0, 0}}});
+
+        auto add1 = Add(2, "add1");
+        auto mem = Memorize(3, "mem1");
+        auto add2 = Add(2, "add2");
+        auto bias = Producer(biasTensor, "bias");
+        auto init = Producer(initTensor, "init");
+        init->getOperator()->setBackend("cpu");
+        init->getOperator()->setDataType(Aidge::DataType::Int32);
+
+        std::shared_ptr<GraphView> g = Sequential({add1, mem, add2});
+        init->addChild(mem, 0, 1);
+        mem->addChild(add1, 1, 1);
+        bias->addChild(add2, 0, 1);
+        add1->getOperator()->setInput(0, in);
+        // Update GraphView inputs/outputs following previous connections:
+        g->add(mem);
+        g->add(add1);
+        g->add(add2);
+        //g->add(init);   // not working because of forwardDims()
+        // TODO: FIXME:
+        // forwardDims() starts with inputNodes(). If the initializer
+        // of Memorize is inside the graph, forwardDims() will get stuck
+        // to the node taking the recursive connection because Memorize
+        // output dims must first be computed from the initializer.
+        g->add(bias);
+
+        g->setBackend("cpu");
+        g->setDataType(Aidge::DataType::Int32);
+        g->save("graphRecurrent");
+        g->forwardDims();
+        SequentialScheduler scheduler(g);
+        REQUIRE_NOTHROW(scheduler.forward(true, true));
+        scheduler.saveSchedulingDiagram("schedulingRecurrent");
+
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(
+                Array2D<int, 2, 3>{{{5, 6, 9}, {14, 16, 19}}});
+        std::shared_ptr<Tensor> result =
+                std::static_pointer_cast<Tensor>(g->getNode("add2")->getOperator()->getRawOutput(0));
+        result->print();
+        expectedOutput->print();
+        bool equal = (*result == *expectedOutput);
+        REQUIRE(equal);
+    }
+}
-- 
GitLab