diff --git a/include/aidge/data/Spikegen.hpp b/include/aidge/data/Spikegen.hpp index 1267fda9dffd4a49fd37e2f1c65a51072db4a8b4..cbb2963f20aaa94f03e73c62f4d0fa47af52bbbc 100644 --- a/include/aidge/data/Spikegen.hpp +++ b/include/aidge/data/Spikegen.hpp @@ -1,10 +1,21 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + #ifndef AIDGE_CORE_DATA_SPIKEGEN_H_ #define AIDGE_CORE_DATA_SPIKEGEN_H_ -// Spikegen algorithm : -// -// time_data = data.repeat(time_steps) -// spike_data = rate_conv(time_data) -// return spike_data +#include "aidge/data/Tensor.hpp" +namespace Aidge { +std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor); +} + #endif diff --git a/src/data/SpikeGen.cpp b/src/data/SpikeGen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..977c64de9ad985fa1c78a9537a4ee515d4344f10 --- /dev/null +++ b/src/data/SpikeGen.cpp @@ -0,0 +1,45 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> +#include <random> + +#include "aidge/data/Spikegen.hpp" + +namespace Aidge { +std::shared_ptr<Tensor> rateConvert(const Tensor& tensor) { + + auto result = std::make_shared<Tensor>(tensor.clone()); + + // Bernoulli sampling + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<float> dis(0.0f, 1.0f); + + // Clip values between 0 and 1, equivalent to torch.clamp(min=0, max=1) + for (size_t i = 0; i < tensor.size(); i++) { + auto val = tensor.get<float>(i); + val = (val < 0.0f) ? 0.0f : ((val > 1.0f) ? 1.0f : val); + auto randomValue = dis(gen); + result->set(i, randomValue < val ? 1.0f : 0.0f); + } + + return result; +} + +std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor) { + auto newTensor = tensor->repeat(10); + + newTensor.print(); // DEBUG + + return rateConvert(newTensor); +} +} // namespace Aidge