diff --git a/include/aidge/data/Spikegen.hpp b/include/aidge/data/Spikegen.hpp index cbb2963f20aaa94f03e73c62f4d0fa47af52bbbc..30fd5decba675d168d7872f03221bfa6664f7f71 100644 --- a/include/aidge/data/Spikegen.hpp +++ b/include/aidge/data/Spikegen.hpp @@ -12,9 +12,19 @@ #ifndef AIDGE_CORE_DATA_SPIKEGEN_H_ #define AIDGE_CORE_DATA_SPIKEGEN_H_ +#include <cstdint> + #include "aidge/data/Tensor.hpp" + namespace Aidge { -std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor); + +/* + * @brief Spike rate encoding of input data + */ +Tensor spikegenRate(std::shared_ptr<Tensor> tensor, std::uint32_t numSteps); + + +Tensor spikegenLatency(std::shared_ptr<Tensor> tensor); } diff --git a/src/data/SpikeGen.cpp b/src/data/SpikeGen.cpp index 977c64de9ad985fa1c78a9537a4ee515d4344f10..91b1ec433c6a56c9097b7b693c05bc3f5cec4f2f 100644 --- a/src/data/SpikeGen.cpp +++ b/src/data/SpikeGen.cpp @@ -15,11 +15,10 @@ #include "aidge/data/Spikegen.hpp" namespace Aidge { -std::shared_ptr<Tensor> rateConvert(const Tensor& tensor) { +Tensor rateConvert(const Tensor& tensor) { - auto result = std::make_shared<Tensor>(tensor.clone()); + auto result = tensor.clone(); - // Bernoulli sampling std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution<float> dis(0.0f, 1.0f); @@ -29,17 +28,14 @@ std::shared_ptr<Tensor> rateConvert(const Tensor& tensor) { auto val = tensor.get<float>(i); val = (val < 0.0f) ? 0.0f : ((val > 1.0f) ? 1.0f : val); auto randomValue = dis(gen); - result->set(i, randomValue < val ? 1.0f : 0.0f); + result.set(i, randomValue < val ? 1.0f : 0.0f); } return result; } -std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor) { - auto newTensor = tensor->repeat(10); - - newTensor.print(); // DEBUG - +Tensor spikegenRate(std::shared_ptr<Tensor> tensor, std::uint32_t numSteps) { + auto newTensor = tensor->repeat(numSteps); return rateConvert(newTensor); } } // namespace Aidge diff --git a/unit_tests/data/Test_Spikegen.cpp b/unit_tests/data/Test_Spikegen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391