From d12e93cd3aeaa540c9fc4c37c5eeec77410ebc78 Mon Sep 17 00:00:00 2001 From: Jerome Hue <jerome.hue@cea.fr> Date: Fri, 14 Feb 2025 10:09:35 +0100 Subject: [PATCH] Minor changes (doc, functions signatures) --- include/aidge/data/Spikegen.hpp | 12 +++++++++++- src/data/SpikeGen.cpp | 14 +++++--------- unit_tests/data/Test_Spikegen.cpp | 0 3 files changed, 16 insertions(+), 10 deletions(-) create mode 100644 unit_tests/data/Test_Spikegen.cpp diff --git a/include/aidge/data/Spikegen.hpp b/include/aidge/data/Spikegen.hpp index cbb2963f2..30fd5decb 100644 --- a/include/aidge/data/Spikegen.hpp +++ b/include/aidge/data/Spikegen.hpp @@ -12,9 +12,19 @@ #ifndef AIDGE_CORE_DATA_SPIKEGEN_H_ #define AIDGE_CORE_DATA_SPIKEGEN_H_ +#include <cstdint> + #include "aidge/data/Tensor.hpp" + namespace Aidge { -std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor); + +/* + * @brief Spike rate encoding of input data + */ +Tensor spikegenRate(std::shared_ptr<Tensor> tensor, std::uint32_t numSteps); + + +Tensor spikegenLatency(std::shared_ptr<Tensor> tensor); } diff --git a/src/data/SpikeGen.cpp b/src/data/SpikeGen.cpp index 977c64de9..91b1ec433 100644 --- a/src/data/SpikeGen.cpp +++ b/src/data/SpikeGen.cpp @@ -15,11 +15,10 @@ #include "aidge/data/Spikegen.hpp" namespace Aidge { -std::shared_ptr<Tensor> rateConvert(const Tensor& tensor) { +Tensor rateConvert(const Tensor& tensor) { - auto result = std::make_shared<Tensor>(tensor.clone()); + auto result = tensor.clone(); - // Bernoulli sampling std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution<float> dis(0.0f, 1.0f); @@ -29,17 +28,14 @@ std::shared_ptr<Tensor> rateConvert(const Tensor& tensor) { auto val = tensor.get<float>(i); val = (val < 0.0f) ? 0.0f : ((val > 1.0f) ? 1.0f : val); auto randomValue = dis(gen); - result->set(i, randomValue < val ? 1.0f : 0.0f); + result.set(i, randomValue < val ? 1.0f : 0.0f); } return result; } -std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor) { - auto newTensor = tensor->repeat(10); - - newTensor.print(); // DEBUG - +Tensor spikegenRate(std::shared_ptr<Tensor> tensor, std::uint32_t numSteps) { + auto newTensor = tensor->repeat(numSteps); return rateConvert(newTensor); } } // namespace Aidge diff --git a/unit_tests/data/Test_Spikegen.cpp b/unit_tests/data/Test_Spikegen.cpp new file mode 100644 index 000000000..e69de29bb -- GitLab