Skip to content
Snippets Groups Projects
Commit b38d3d23 authored by Jerome Hue's avatar Jerome Hue
Browse files

First draft of spikegen rate convert function

parent fd46b935
No related branches found
No related tags found
No related merge requests found
/********************************************************************************
* Copyright (c) 2025 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_DATA_SPIKEGEN_H_
#define AIDGE_CORE_DATA_SPIKEGEN_H_
// Spikegen algorithm :
//
// time_data = data.repeat(time_steps)
// spike_data = rate_conv(time_data)
// return spike_data
#include "aidge/data/Tensor.hpp"
namespace Aidge {
std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor);
}
#endif
/********************************************************************************
* Copyright (c) 2025 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <random>
#include "aidge/data/Spikegen.hpp"
namespace Aidge {
std::shared_ptr<Tensor> rateConvert(const Tensor& tensor) {
auto result = std::make_shared<Tensor>(tensor.clone());
// Bernoulli sampling
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(0.0f, 1.0f);
// Clip values between 0 and 1, equivalent to torch.clamp(min=0, max=1)
for (size_t i = 0; i < tensor.size(); i++) {
auto val = tensor.get<float>(i);
val = (val < 0.0f) ? 0.0f : ((val > 1.0f) ? 1.0f : val);
auto randomValue = dis(gen);
result->set(i, randomValue < val ? 1.0f : 0.0f);
}
return result;
}
std::shared_ptr<Tensor> spikegenRate(std::shared_ptr<Tensor> tensor) {
auto newTensor = tensor->repeat(10);
newTensor.print(); // DEBUG
return rateConvert(newTensor);
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment