Skip to content
Snippets Groups Projects
Commit 5877c3d3 authored by Jerome Hue's avatar Jerome Hue
Browse files

Add a reset to zero for Leaky operator

Add a `LeakyReset` enum to differentiate between the subtraction reset
and the reset-to-zero behaviours.
parent 3ee75f49
No related branches found
No related tags found
No related merge requests found
Pipeline #66083 failed
This commit is part of merge request !344. Comments created here will be created in the context of that merge request.
......@@ -305,10 +305,17 @@ std::shared_ptr<Node> LSTM(DimSize_t in_channels,
bool noBias = false,
const std::string &name = "");
// Reset policy applied to the membrane potential of a Leaky neuron
// after it emits a spike.
enum class LeakyReset {
Subtraction, // Soft reset: subtract the threshold from the potential (U[t] -= S[t-1] * U_th)
ToZero // Hard reset: force the potential back to zero after a spike
};
std::shared_ptr<MetaOperator_Op> LeakyOp();
// Build a Leaky neuron meta-operator node.
// nbTimeSteps: number of recurrence steps for the internal Memorize nodes.
// beta: decay factor applied to the previous potential U[T-1].
// threshold: firing threshold U_th (default 1.0).
// resetType: reset policy applied after a spike (default: Subtraction).
// name: optional base name for the created nodes.
std::shared_ptr<Node> Leaky(const int nbTimeSteps,
const float beta,
const float threshold = 1.0,
const LeakyReset resetType = LeakyReset::Subtraction,
const std::string &name = "");
} // namespace Aidge
......
......@@ -406,6 +406,7 @@ void declare_LeakyOp(py::module &m) {
py::arg("nb_timesteps"),
py::arg("beta"),
py::arg("threshold") = 1.0,
py::arg("reset") = LeakyReset::Subtraction,
py::arg("name") = "",
R"mydelimiter(
Initialize a Leaky neuron operator.
......
......@@ -16,95 +16,100 @@ constexpr auto memorizeOpDataOutputRecIndex = 1;
// Build a Leaky (leaky integrate-and-fire style) neuron as a meta-operator.
//
//   U[t] = Input[T] + beta * U[T-1] - reset(S[T-1])
//   with S[T] = | 1, if U[T] - U_th > 0
//               | 0  otherwise
//
// The reset term depends on `resetType`:
//   - Subtraction: U[t] = Input[T] + beta * U[T-1] - S[T-1] * U_th
//   - ToZero:      U[t] = Input[T] + (1 - S[T-1]) * beta * U[T-1]
//
// Params:
//   nbTimeSteps - number of recurrence steps for the Memorize nodes.
//   beta        - decay factor applied to the previous potential.
//   threshold   - firing threshold U_th.
//   resetType   - reset policy applied after a spike.
//   name        - optional base name for all created nodes.
// Returns: a Node wrapping the micro-graph as a "Leaky" MetaOperator.
std::shared_ptr<Node> Leaky(const int nbTimeSteps,
                            const float beta,
                            const float threshold,
                            const LeakyReset resetType,
                            const std::string &name) {

    auto microGraph = std::make_shared<GraphView>();

    // Nodes common to both reset policies.
    auto input = Identity((!name.empty()) ? name + "_input" : "");
    auto decay = Mul(!name.empty() ? name + "_mul" : "");
    auto spike = Heaviside(0, !name.empty() ? name + "_hs" : "");
    auto subNode2 = Sub(!name.empty() ? name + "_threshold" : "");

    auto betaTensor = std::make_shared<Tensor>(beta);
    auto uthTensor = std::make_shared<Tensor>(static_cast<float>(threshold));
    auto decayRate = Producer(betaTensor, "leaky_beta", true);
    auto uth = Producer(uthTensor, "leaky_uth", true);
    uniformFiller<float>(uthTensor, threshold, threshold);

    auto potentialMem =
        Memorize(nbTimeSteps, (!name.empty()) ? name + "_potential" : "");
    auto spikeMem =
        Memorize(nbTimeSteps, (!name.empty()) ? name + "_spike" : "");

    // Common connections: beta * U[T-1]
    decayRate->addChild(decay, 0, 1);
    potentialMem->addChild(decay, 1, 0);

    // Node holding the final membrane-potential value, set per reset policy.
    std::shared_ptr<Node> potentialNode;

    if (resetType == LeakyReset::Subtraction) {
        // U[t] = (Input[T] + beta * U[T-1]) - S[T-1] * U_th
        auto decayPlusInput = Add(!name.empty() ? name + "_add" : "");
        decay->addChild(decayPlusInput, 0, 1);
        input->addChild(decayPlusInput, 0, 0);

        auto potentialSubReset = Sub(!name.empty() ? name + "_sub" : "");
        auto reset = Mul(!name.empty() ? name + "_reset" : "");
        // S[T-1] * U_th
        spikeMem->addChild(reset, 1, 0);
        uth->addChild(reset, 0, 1);
        decayPlusInput->addChild(potentialSubReset, 0, 0);
        reset->addChild(potentialSubReset, 0, 1);
        potentialSubReset->addChild(potentialMem, 0, 0);
        potentialNode = potentialSubReset;
        microGraph->add({decayPlusInput, potentialSubReset, reset});
    } else if (resetType == LeakyReset::ToZero) {
        // U[t] = (1 - S[T-1]) * beta * U[T-1] + Input[T]
        auto oneMinusSpike =
            Sub(!name.empty() ? name + "_one_minus_spike" : "");
        auto one = Producer(std::make_shared<Tensor>(1.0f), "one", true);
        auto finalMul = Mul(!name.empty() ? name + "_final" : "");
        auto decayPlusInput = Add(!name.empty() ? name + "_add" : "");

        // 1 - S[T-1]
        one->addChild(oneMinusSpike, 0, 0);
        spikeMem->addChild(oneMinusSpike, 1, 1);
        // finalMul = (1 - S[T-1]) * (beta * U[T-1])
        oneMinusSpike->addChild(finalMul, 0, 0);
        decay->addChild(finalMul, 0, 1);
        // (1 - S[T-1]) * (beta * U[T-1]) + Input[T]
        finalMul->addChild(decayPlusInput, 0, 0);
        input->addChild(decayPlusInput, 0, 1);
        decayPlusInput->addChild(potentialMem, 0, 0);
        potentialNode = decayPlusInput;
        microGraph->add({oneMinusSpike, one, finalMul, decayPlusInput});
    }

    // Threshold comparison: U[t] - U_th
    potentialNode->addChild(subNode2, 0, 0);
    uth->addChild(subNode2, 0, 1);

    // S[T] = Heaviside(U[t] - U_th)
    subNode2->addChild(spike, 0, 0);
    spike->addChild(spikeMem, 0, 0);

    microGraph->add(input);
    microGraph->add({decay, potentialMem, decayRate, uth, spikeMem, spike,
                     subNode2},
                    false);

    microGraph->setOrderedInputs(
        {{input, 0}, {potentialMem, 1}, {spikeMem, 1}});

    // NOTE: Outputs are NOT the memory nodes (as it is done in LSTM), to avoid
    // producing data during init. This way, we can plug an operator after
    // our node, and get correct results.
    microGraph->setOrderedOutputs({{potentialNode, 0}, {spike, 0}});

    return MetaOperator("Leaky", microGraph, {}, name);
}
std::shared_ptr<MetaOperator_Op> LeakyOp() {
......
......@@ -159,4 +159,17 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") {
REQUIRE(myLeaky->nbOutputs() == 4);
REQUIRE(true);
}
SECTION("Leaky(Reset to zero)") {
    // Build a Leaky node using the hard (reset-to-zero) policy.
    auto myLeaky = Leaky(10, 1.0, 0.9, LeakyReset::ToZero);

    // Two Memorize recurrent inputs + the real data input.
    REQUIRE(myLeaky->nbInputs() == 3);
    // Potential and spike outputs + one output per Memorize node.
    REQUIRE(myLeaky->nbOutputs() == 4);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment