From d0b9adfad77ab95e37e71065b42091265c6f1700 Mon Sep 17 00:00:00 2001
From: Jerome Hue <jerome.hue@cea.fr>
Date: Wed, 19 Feb 2025 17:09:29 +0100
Subject: [PATCH 1/2] Add a reset-to-zero mode to the Leaky operator

Add a `LeakyReset` enum to differentiate between the subtraction reset
and the reset-to-zero behaviours.
---
 include/aidge/operator/MetaOperatorDefs.hpp   |   7 +
 .../operator/pybind_MetaOperatorDefs.cpp      |   1 +
 src/operator/MetaOperatorDefs/Leaky.cpp       | 153 +++++++++---------
 unit_tests/operator/Test_MetaOperator.cpp     |  13 ++
 4 files changed, 100 insertions(+), 74 deletions(-)

diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp
index c4ceccf53..975fcffaa 100644
--- a/include/aidge/operator/MetaOperatorDefs.hpp
+++ b/include/aidge/operator/MetaOperatorDefs.hpp
@@ -305,10 +305,17 @@ std::shared_ptr<Node> LSTM(DimSize_t in_channels,
                            bool noBias = false,
                            const std::string &name = "");
 
+enum class LeakyReset {
+    Subtraction,
+    ToZero
+};
+
+
 std::shared_ptr<MetaOperator_Op> LeakyOp();
 std::shared_ptr<Node> Leaky(const int nbTimeSteps,
                             const float beta,
                             const float threshold = 1.0,
+                            const LeakyReset resetType = LeakyReset::Subtraction,
                             const std::string &name = "");
 
 } // namespace Aidge
diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp
index 2b2cdea12..c55ec533a 100644
--- a/python_binding/operator/pybind_MetaOperatorDefs.cpp
+++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp
@@ -406,6 +406,7 @@ void declare_LeakyOp(py::module &m) {
           py::arg("nb_timesteps"),
           py::arg("beta"),
           py::arg("threshold") = 1.0,
+          py::arg("reset") = LeakyReset::Subtraction,
           py::arg("name") = "",
     R"mydelimiter(
         Initialize a Leaky neuron operator.
diff --git a/src/operator/MetaOperatorDefs/Leaky.cpp b/src/operator/MetaOperatorDefs/Leaky.cpp
index c5e8ab3f1..b5dc65cca 100644
--- a/src/operator/MetaOperatorDefs/Leaky.cpp
+++ b/src/operator/MetaOperatorDefs/Leaky.cpp
@@ -16,95 +16,100 @@ constexpr auto memorizeOpDataOutputRecIndex = 1;
 std::shared_ptr<Node> Leaky(const int nbTimeSteps,
                             const float beta,
                             const float threshold,
+                            const LeakyReset resetType,
                             const std::string &name) {
 
     auto microGraph = std::make_shared<GraphView>();
 
-    auto inputNode = Identity((!name.empty()) ? name + "_input" : "");
-    auto addNode = Add(!name.empty() ? name + "_add" : "");
-    auto mulNode = Mul(!name.empty() ? name + "_mul" : "");
-    auto subNode = Sub(!name.empty() ? name + "_sub" : "");
-    auto hsNode = Heaviside(0, !name.empty() ? name + "_hs" : "");
-    auto subNode2 = Sub(!name.empty() ? name + "_threshold" : "");
-    auto reset = Mul(!name.empty() ? name + "_reset" : "");
+    /*
+     * Subtraction reset: U[t] = Input[t] + beta*U[t-1] - S[t-1]*U_th
+     * To-zero reset:     U[t] = Input[t] + (1 - S[t-1])*beta*U[t-1]
+     * with S[t] = 1 if U[t] - U_th > 0, 0 otherwise
+     */
+
 
+    auto input = Identity((!name.empty()) ? name + "_input" : "");
+    auto decay = Mul(!name.empty() ? name + "_mul" : "");
+    auto spike = Heaviside(0, !name.empty() ? name + "_hs" : "");
+    auto subNode2 = Sub(!name.empty() ? name + "_threshold" : "");
     auto betaTensor = std::make_shared<Tensor>(beta);
     auto uthTensor = std::make_shared<Tensor>(static_cast<float>(threshold));
-    uniformFiller<float>(uthTensor, threshold, threshold);
-
     auto decayRate = Producer(betaTensor, "leaky_beta", true);
     auto uth = Producer(uthTensor, "leaky_uth", true);
+    auto potentialMem = Memorize(nbTimeSteps, (!name.empty()) ? name + "_potential" : "");
+    auto spikeMem = Memorize(nbTimeSteps, (!name.empty()) ? name + "_spike" : "");
+    uniformFiller<float>(uthTensor, threshold, threshold);
 
-    auto potentialMem =
-        Memorize(nbTimeSteps, (!name.empty()) ? name + "_potential" : "");
-    auto spikeMem =
-        Memorize(nbTimeSteps, (!name.empty()) ? name + "_spike" : "");
-
-    // U[t] = Input[T] + beta * U[T-1] - S[T-1] * U_th
-    // with S[T] = | 1, if U[T] - U_th > 0
-    //             | 0 otherwise
-
-    // beta * U[T-1]
-    decayRate->addChild(/*otherNode=*/mulNode, /*outId=*/0, /*otherInId=*/1);
-    potentialMem->addChild(mulNode, 1, 0);
-
-    // Input[T] + beta * U[T-1]
-    mulNode->addChild(/*otherNode=*/addNode, /*outId=*/0, /*otherInId=*/1);
-    inputNode->addChild(/*otherNode=*/addNode, /*outId=*/0, /*otherInId=*/0);
-
-    // S[T-1] * U_th
-    spikeMem->addChild(reset,
-                       /*outId=*/memorizeOpDataOutputRecIndex,
-                       /*otherInId=*/0);
-
-    // TODO(#219) Handle hard/soft reset
-    uth->addChild(reset, 0, 1);
+    // Common connections
+    decayRate->addChild(decay, 0, 1);
+    potentialMem->addChild(decay, 1, 0);
+
+    std::shared_ptr<Node> potentialNode; // Node containing the final potential value
+
+    if (resetType == LeakyReset::Subtraction) {
+        auto decayPlusInput = Add(!name.empty() ? name + "_add" : "");
+        decay->addChild(decayPlusInput, 0, 1);
+        input->addChild(decayPlusInput, 0, 0);
+
+        auto potentialSubReset = Sub(!name.empty() ? name + "_sub" : "");
+        auto reset = Mul(!name.empty() ? name + "_reset" : "");
+
+        spikeMem->addChild(reset, 1, 0);
+        uth->addChild(reset, 0, 1);
+
+        decayPlusInput->addChild(potentialSubReset, 0, 0);
+        reset->addChild(potentialSubReset, 0, 1);
+
+        potentialSubReset->addChild(potentialMem, 0, 0);
+        
+        potentialNode = potentialSubReset;
+        microGraph->add({decayPlusInput, potentialSubReset, reset});
+        
+    } else if (resetType == LeakyReset::ToZero) {
+        auto oneMinusSpike = Sub(!name.empty() ? name + "_one_minus_spike" : "");
+        auto one = Producer(std::make_shared<Tensor>(1.0f), "one", true);
+        auto finalMul = Mul(!name.empty() ? name + "_final" : "");
+        auto decayPlusInput = Add(!name.empty() ? name + "_add" : "");
+
+        one->addChild(oneMinusSpike, 0, 0);
+        spikeMem->addChild(oneMinusSpike, 1, 1);
+
+        oneMinusSpike->addChild(finalMul, 0, 0);
+        decay->addChild(finalMul, 0, 1);
+
+        // finalMul = (1 - S[t-1]) * decay -- both inputs are already
+        // connected above (oneMinusSpike -> input 0, decay -> input 1),
+        // so no further wiring of finalMul is needed here.
+
+        // (1-S[t-1]) * (decay) + WX[t]
+        finalMul->addChild(decayPlusInput, 0, 0);
+        input->addChild(decayPlusInput, 0, 1);
+        
+        decayPlusInput->addChild(potentialMem, 0, 0);
+        potentialNode = decayPlusInput;
+        
+        microGraph->add({oneMinusSpike, one, finalMul, decayPlusInput});
+    }
+
+    // Threshold comparison : (U[t] - Uth)
+    potentialNode->addChild(subNode2, 0, 0);
+    uth->addChild(subNode2, 0, 1);
 
-    // Input[T] + beta * U[T-1] - S[T-1] * U_th
-    addNode->addChild(subNode, 0, 0);
-    reset->addChild(subNode, 0, 1);
+    // heaviside
+    subNode2->addChild(spike, 0, 0);
+    spike->addChild(spikeMem, 0, 0);
 
-    // U[t] = (Input[T] + beta * U[T-1]) - S[T-1]
-    subNode->addChild(potentialMem, 0, 0);
+    microGraph->add(input);
+    microGraph->add({decay, potentialMem, decayRate, 
+                     uth, spikeMem, spike, subNode2}, false);
 
-    // U[T] - U_th
-    subNode->addChild(subNode2, 0, 0);
-    uth->addChild(subNode2, 0, 1);
+    microGraph->setOrderedInputs(
+        {{input, 0}, {potentialMem, 1}, {spikeMem, 1}});
 
-    // with S[T] = | 1, if U[T] - U_th > 0
-    subNode2->addChild(hsNode, 0, 0);
-    hsNode->addChild(spikeMem, 0, 0);
-
-    microGraph->add(inputNode);
-    microGraph->add({addNode,
-                     mulNode,
-                     potentialMem,
-                     decayRate,
-                     uth,
-                     spikeMem,
-                     hsNode,
-                     subNode,
-                     subNode2,
-                     reset},
-                    false);
+    // Use potentialNode for membrane potential output
+    microGraph->setOrderedOutputs({{potentialNode, 0}, {spike, 0}});
 
-    microGraph->setOrderedInputs(
-        {{inputNode, 0}, {potentialMem, 1}, {spikeMem, 1}});
-
-    // NOTE: Outputs are NOT the memory nodes (as it is done in LSTM), to avoid
-    // producing data during init. This way, we can plug an operator after
-    // our node, and get correct results.
-    microGraph->setOrderedOutputs({//{potentialMem, memorizeOpDataOutputIndex},
-                                   //{spikeMem, memorizeOpDataOutputIndex}
-                                   {subNode, 0},
-                                   {hsNode, 0}});
-
-    auto metaOp = MetaOperator(/*type*/ "Leaky",
-                               /*graph*/ microGraph,
-                               /*forcedInputsCategory=*/{},
-                               /*name*/ "leaky");
-
-    return metaOp;
+    return MetaOperator("Leaky", microGraph, {}, name);
 }
 
 std::shared_ptr<MetaOperator_Op> LeakyOp() {
diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp
index 042b04f01..9c951b268 100644
--- a/unit_tests/operator/Test_MetaOperator.cpp
+++ b/unit_tests/operator/Test_MetaOperator.cpp
@@ -159,4 +159,17 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") {
         REQUIRE(myLeaky->nbOutputs() == 4);
     	REQUIRE(true);
     }
+
+    SECTION("Leaky(Reset to zero)") {
+        auto myLeaky = Leaky(10, 1.0, 0.9, LeakyReset::ToZero);
+        auto op = std::static_pointer_cast<OperatorTensor>(myLeaky->getOperator());
+
+        auto inputs = myLeaky->inputs();
+
+        // Two memorize nodes + real data input
+        REQUIRE(myLeaky->nbInputs() == 3);
+        // Outputs for spike and membrane potential + 2 Memorize nodes
+        REQUIRE(myLeaky->nbOutputs() == 4);
+    	REQUIRE(true);
+    }
 }
-- 
GitLab


From f0e41abc437e1a19fd2d8bb41287ed8fffe8b6c3 Mon Sep 17 00:00:00 2001
From: Jerome Hue <jerome.hue@cea.fr>
Date: Thu, 27 Feb 2025 14:28:35 +0100
Subject: [PATCH 2/2] Add missing python binding for LeakyReset

The binding for the `LeakyReset` enumeration was missing, which caused
an error when running the Python tests.
---
 python_binding/operator/pybind_MetaOperatorDefs.cpp | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp
index c55ec533a..9e266cfe2 100644
--- a/python_binding/operator/pybind_MetaOperatorDefs.cpp
+++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp
@@ -400,6 +400,18 @@ void declare_LSTMOp(py::module &m) {
     )mydelimiter");
 }
 
+void declare_LeakyResetEnum(py::module &m) {
+    py::enum_<LeakyReset>(m, "leaky_reset", R"doc(
+        Enumeration for the Leaky neuron reset mode.
+
+        Subtraction: Membrane potential is subtracted by threshold upon spiking.
+        ToZero     : Membrane potential is forced to 0 upon spiking.
+    )doc")
+        .value("subtraction", LeakyReset::Subtraction)
+        .value("to_zero", LeakyReset::ToZero)
+        .export_values();
+}
+
 
 void declare_LeakyOp(py::module &m) {
     m.def("Leaky", &Leaky,
@@ -441,6 +453,7 @@ void init_MetaOperatorDefs(py::module &m) {
   declare_PaddedMaxPoolingOp<2>(m);
 //   declare_PaddedMaxPoolingOp<3>(m);
   declare_LSTMOp(m);
+  declare_LeakyResetEnum(m);
   declare_LeakyOp(m);
 
   py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperatorOp", py::multiple_inheritance())
-- 
GitLab