From d4d09c91b42e17b0628db6af9fb0b056c04c4235 Mon Sep 17 00:00:00 2001
From: Jerome Hue <jerome.hue@cea.fr>
Date: Fri, 7 Mar 2025 14:20:11 +0100
Subject: [PATCH] chore: Clean and improve the Leaky MetaOperator test

---
 unit_tests/operator/Test_MetaOperator.cpp | 166 +++++++---------------
 1 file changed, 55 insertions(+), 111 deletions(-)

diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp
index de720f5b..f781e5e2 100644
--- a/unit_tests/operator/Test_MetaOperator.cpp
+++ b/unit_tests/operator/Test_MetaOperator.cpp
@@ -750,155 +750,99 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
 
         std::random_device rd;
         std::mt19937 gen(rd());
-        std::uniform_real_distribution<float> valueDist(
-            0.1f,
-            1.1f); // Random float distribution between 0 and 1
-        std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(2),
-                                                               std::size_t(4));
-        std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(3),
-                                                              std::size_t(3));
+        std::uniform_real_distribution<float> valueDist(0.1f,1.1f);
+        std::uniform_int_distribution<std::size_t> dimSizeDist(2,4);
+        std::uniform_int_distribution<std::size_t> nbDimsDist(3,3); // fixed to 3.
         std::uniform_int_distribution<int> boolDist(0, 1);
         std::uniform_real_distribution<float> betaDist(0,1);
+        std::uniform_real_distribution<float> thresholDist(0.1,3);
 
-        const std::size_t nbDims = nbDimsDist(gen);
-        Log::info("Nbdims : {}", nbDims);
-        std::vector<std::size_t> dims;
-        for (std::size_t i = 0; i < nbDims; ++i) {
-            dims.push_back(dimSizeDist(gen));
-        }
-        Log::info("timesteps : {}", dims[0]);
-        Log::info("dimensions : ");
-        for (auto dim : dims) {
-            Log::info("{}", dim);
-        }
-
-        const auto nbTimeSteps = dims[0];
         const auto beta = betaDist(gen);
+        const auto threshold = thresholDist(gen);
+        const auto nbDims = nbDimsDist(gen);
+        std::vector<std::size_t> dims(nbDims);
+        std::generate(dims.begin(), dims.end(), [&]() { return dimSizeDist(gen); });
+        const auto nbTimeSteps = dims[0];
 
-        auto myLeaky = Leaky(nbTimeSteps, beta, 1.0, LeakyReset::Subtraction, "leaky");
-        auto op =
-            std::static_pointer_cast<MetaOperator_Op>(myLeaky->getOperator());
-        // auto stack = Stack(2);
-        auto mem_rec = Stack(nbTimeSteps, "mem_rec");
-        auto spk_rec = Stack(nbTimeSteps, "spk_rec");
-        auto pop = Pop("popinput");
+        auto myLeaky = Leaky(nbTimeSteps, beta, threshold, LeakyReset::Subtraction, "leaky");
+        auto op = std::static_pointer_cast<MetaOperator_Op>(myLeaky->getOperator());
+        auto memoryRecord = Stack(nbTimeSteps, "mem_rec");
+        auto spikeRecord = Stack(nbTimeSteps, "spk_rec");
+        auto pop = Pop("input");
 
-        // Here we test LSTM as it is was flatten in the graph.
-        // We just borrow its micro-graph into our larger myGraph graph.
-        auto myGraph = std::make_shared<GraphView>();
+        auto leakyOutputs = op->getMicroGraph()->getOrderedOutputs();
+        auto leakyInputs = op->getMicroGraph()->getOrderedInputs();
+        pop->addChild(leakyInputs[0].first, 0, 0);
+        leakyOutputs[1].first->addChild(memoryRecord,0,0);
+        leakyOutputs[0].first->addChild(spikeRecord,0,0);
 
-        pop->addChild(op->getMicroGraph()->getOrderedInputs()[0].first, 0, 0);
-        // 0 for mem 1 for stack
-        op->getMicroGraph()->getOrderedOutputs()[1].first->addChild(mem_rec,
-                                                                    0,
-                                                                    0);
-        op->getMicroGraph()->getOrderedOutputs()[0].first->addChild(spk_rec,
-                                                                    0,
-                                                                    0);
-        for (auto node : op->getMicroGraph()->getOrderedOutputs()) {
-            Log::info("name  of output {}", node.first->name());
-        }
-
-        myGraph->add(pop);
+        auto myGraph = std::make_shared<GraphView>();
         myGraph->add(op->getMicroGraph());
-        myGraph->add(mem_rec);
-        myGraph->add(spk_rec);
-        myGraph->save("mg", true, true);
+        myGraph->add({pop, memoryRecord, spikeRecord});
 
-        // 3 outputs
         REQUIRE(myLeaky->nbInputs() == 3);
         REQUIRE(myLeaky->inputCategory(0) == InputCategory::Data);
-        // Two spikes connected to nothing, + the Add node real output
         REQUIRE(myLeaky->nbOutputs() == 4);
 
-        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
-            Array3D<float, 2, 3, 2>{{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}},
-                                     {{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}});
-
-        // std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(
-        //     Array3D<float, 2, 3, 2>{{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}},
-        //                              {{2.0, 3.0}, {4.0, 5.0},
-        //                              {6.0, 7.0}}}});
-
-        // Generate input
-        std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>();
-        T0->setDataType(DataType::Float32);
-        T0->setBackend("cpu");
-
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>();
-        expectedOutput->setDataType(DataType::Float32);
-        expectedOutput->setBackend("cpu");
-
         const auto nb_elements =
             std::accumulate(dims.cbegin(),
                             dims.cend(),
                             std::size_t(1),
                             std::multiplies<std::size_t>());
-        float *input = new float[nb_elements];
-        float *result = new float[nb_elements];
+        const auto nbElementsPerTimeStep = nb_elements / dims[0];
 
-        for (std::size_t i = 0; i < nb_elements; ++i) {
-            input[i] = valueDist(gen);
-        }
-        T0->resize(dims);
-        T0->getImpl()->setRawPtr(input, nb_elements);
-        T0->print();
 
-        // Elements popped at each time step
-        auto nbElementsPerTimeStep = nb_elements / dims[0];
+        // Compute the expected result using ad-hoc implementation
 
         // Init
-        for (int i = 0; i < nbElementsPerTimeStep; ++i) {
-            result[i] = input[i];
-        }
-
-        // Reccurence
-        for (int i = 1; i < dims[0]; ++i) {
-            auto offset = nbElementsPerTimeStep * i;
-            auto prev = nbElementsPerTimeStep * (i - 1);
-            for (int j = 0; j < nbElementsPerTimeStep; ++j) {
-                auto reset = (result[prev + j] > 1.0 ? 1 : 0);
-                result[offset + j] =
-                    result[prev + j] * beta + input[offset + j] - reset;
+        auto *input = new float[nb_elements];
+        std::generate_n(input, nb_elements, [&]() { return valueDist(gen); });
+        auto *result = new float[nb_elements];
+        std::copy(input, input + nbElementsPerTimeStep, result);
+
+        // Recurrence calculation for each timestep
+        for (int timestep = 1; timestep < nbTimeSteps; ++timestep) {
+            const auto currentOffset = nbElementsPerTimeStep * timestep;
+            const auto previousOffset = nbElementsPerTimeStep * (timestep - 1);
+
+            for (int element = 0; element < nbElementsPerTimeStep; ++element) {
+                const auto previousValue = result[previousOffset + element];
+                const auto resetValue = (previousValue > threshold) ? threshold : 0;
+
+                result[currentOffset + element] =
+                    previousValue * beta + input[currentOffset + element] - resetValue;
             }
         }
 
+        auto expectedOutput = std::make_shared<Tensor>(DataType::Float32);
+        expectedOutput->setBackend("cpu");
         expectedOutput->resize(dims);
         expectedOutput->getImpl()->setRawPtr(result, nb_elements);
-        Log::info("Expected ouptut : ");
-        expectedOutput->print();
 
-        std::shared_ptr<Tensor> myInit =
-            std::make_shared<Tensor>(Array2D<float, 3, 3>{
-                {{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
 
-        auto initMemdims =
-            std::vector<std::size_t>(dims.begin() + 1, dims.end());
-        Log::info("dimensions : ");
-        for (auto dim : initMemdims) {
-            Log::info("{}", dim);
-        }
-        std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
-            Array2D<float, 3, 2>{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}});
+        // Compute the real result using our operator implementation
+        auto inputTensor = std::make_shared<Tensor>(DataType::Float32);
+        inputTensor->setBackend("cpu");
+        inputTensor->resize(dims);
+        inputTensor->getImpl()->setRawPtr(input, nb_elements);
 
-        std::shared_ptr<Tensor> myInitR =
-            std::make_shared<Tensor>(initMemdims);
-        myInitR->setDataType(DataType::Float32);
-        myInitR->setBackend("cpu");
-        uniformFiller<float>(myInitR, 0, 0);
+        auto memoryInit = std::make_shared<Tensor>(DataType::Float32);
+        memoryInit->setBackend("cpu");
+        memoryInit->resize(std::vector<std::size_t>(dims.begin() + 1, dims.end()));
+        memoryInit->zeros();
 
-        pop->getOperator()->associateInput(0, T0);
-        op->associateInput(1, myInitR);
-        op->associateInput(2, myInitR);
+        pop->getOperator()->associateInput(0, inputTensor);
+        op->associateInput(1, memoryInit);
+        op->associateInput(2, memoryInit);
 
         myGraph->compile("cpu", DataType::Float32);
-
         auto scheduler = SequentialScheduler(myGraph);
         REQUIRE_NOTHROW(scheduler.generateScheduling());
         REQUIRE_NOTHROW(scheduler.forward(true));
 
+        // Compare expected output with actual output
         auto memOp =
-            std::static_pointer_cast<OperatorTensor>(spk_rec->getOperator());
+            std::static_pointer_cast<OperatorTensor>(spikeRecord->getOperator());
         REQUIRE(approxEq<float>(*(memOp->getOutput(0)), *(expectedOutput)));
     }
 }
-- 
GitLab