Skip to content
Snippets Groups Projects
Commit d12e93cd authored by Jerome Hue's avatar Jerome Hue Committed by Maxence Naud
Browse files

Minor changes (doc, functions signatures)

parent cada2936
No related branches found
No related tags found
1 merge request!351feat: add rate spikegen for snns
......@@ -12,9 +12,19 @@
#ifndef AIDGE_CORE_DATA_SPIKEGEN_H_
#define AIDGE_CORE_DATA_SPIKEGEN_H_
#include <cstdint>
#include "aidge/data/Tensor.hpp"
namespace Aidge {
std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor);
/*
* @brief Spike rate encoding of input data
*/
Tensor spikegenRate(std::shared_ptr<Tensor> tensor, std::uint32_t numSteps);
Tensor spikegenLatency(std::shared_ptr<Tensor> tensor);
}
......
......@@ -15,11 +15,10 @@
#include "aidge/data/Spikegen.hpp"
namespace Aidge {
std::shared_ptr<Tensor> rateConvert(const Tensor& tensor) {
Tensor rateConvert(const Tensor& tensor) {
auto result = std::make_shared<Tensor>(tensor.clone());
auto result = tensor.clone();
// Bernoulli sampling
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(0.0f, 1.0f);
......@@ -29,17 +28,14 @@ std::shared_ptr<Tensor> rateConvert(const Tensor& tensor) {
auto val = tensor.get<float>(i);
val = (val < 0.0f) ? 0.0f : ((val > 1.0f) ? 1.0f : val);
auto randomValue = dis(gen);
result->set(i, randomValue < val ? 1.0f : 0.0f);
result.set(i, randomValue < val ? 1.0f : 0.0f);
}
return result;
}
std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor) {
auto newTensor = tensor->repeat(10);
newTensor.print(); // DEBUG
Tensor spikegenRate(std::shared_ptr<Tensor> tensor, std::uint32_t numSteps) {
auto newTensor = tensor->repeat(numSteps);
return rateConvert(newTensor);
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment