Commit d0b9adfa authored by Jerome Hue

Add a reset to zero for Leaky operator

Add a `LeakyReset` enum to differentiate between the subtraction reset
and the reset-to-zero behaviour.
parent 374b93b9
Merge request !344: feat: Reset to zero for Leaky
Pipeline #66596 failed
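
For reviewers, a minimal usage sketch of the extended factory, assuming only the declarations introduced in this commit (the instance name "lif0" and the parameter values are illustrative):

    // Hedged sketch: exercising the new resetType argument.
    // Omitting it keeps the historical subtraction reset (declared default).
    auto lifSub  = Aidge::Leaky(/*nbTimeSteps=*/10, /*beta=*/0.9f, /*threshold=*/1.0f);

    // Opting in to the new reset-to-zero behaviour.
    auto lifZero = Aidge::Leaky(/*nbTimeSteps=*/10, /*beta=*/0.9f, /*threshold=*/1.0f,
                                Aidge::LeakyReset::ToZero, "lif0");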
@@ -305,10 +305,17 @@ std::shared_ptr<Node> LSTM(DimSize_t in_channels,
                             bool noBias = false,
                             const std::string &name = "");
 
+enum class LeakyReset {
+    Subtraction,
+    ToZero
+};
+
 std::shared_ptr<MetaOperator_Op> LeakyOp();
 std::shared_ptr<Node> Leaky(const int nbTimeSteps,
                             const float beta,
                             const float threshold = 1.0,
+                            const LeakyReset resetType = LeakyReset::Subtraction,
                             const std::string &name = "");
 } // namespace Aidge
...
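
The two enumerators map to the two membrane-potential update rules spelled out in the operator comments further down. A scalar sketch of the intended semantics, kept deliberately independent of the Aidge graph API (the helper name leakyStep is illustrative only):

    // One timestep of the leaky integrate-and-fire update.
    // u_prev: previous potential, s_prev: previous spike (0 or 1),
    // x: weighted input, beta: decay rate, u_th: firing threshold.
    float leakyStep(float u_prev, float s_prev, float x,
                    float beta, float u_th, Aidge::LeakyReset mode) {
        float u;
        if (mode == Aidge::LeakyReset::Subtraction) {
            // U[t] = X[t] + beta * U[t-1] - S[t-1] * U_th
            u = x + beta * u_prev - s_prev * u_th;
        } else { // Aidge::LeakyReset::ToZero
            // U[t] = X[t] + (1 - S[t-1]) * beta * U[t-1]
            u = x + (1.0f - s_prev) * beta * u_prev;
        }
        return u; // in both modes the spike is then S[t] = (u - u_th > 0) ? 1 : 0
    }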
@@ -406,6 +406,7 @@ void declare_LeakyOp(py::module &m) {
           py::arg("nb_timesteps"),
           py::arg("beta"),
           py::arg("threshold") = 1.0,
+          py::arg("reset") = LeakyReset::Subtraction,
           py::arg("name") = "",
           R"mydelimiter(
           Initialize a Leaky neuron operator.
...
@@ -16,95 +16,100 @@ constexpr auto memorizeOpDataOutputRecIndex = 1;
 std::shared_ptr<Node> Leaky(const int nbTimeSteps,
                             const float beta,
                             const float threshold,
+                            const LeakyReset resetType,
                             const std::string &name) {
 
     auto microGraph = std::make_shared<GraphView>();
 
-    auto inputNode = Identity((!name.empty()) ? name + "_input" : "");
-    auto addNode = Add(!name.empty() ? name + "_add" : "");
-    auto mulNode = Mul(!name.empty() ? name + "_mul" : "");
-    auto subNode = Sub(!name.empty() ? name + "_sub" : "");
-    auto hsNode = Heaviside(0, !name.empty() ? name + "_hs" : "");
-    auto subNode2 = Sub(!name.empty() ? name + "_threshold" : "");
-    auto reset = Mul(!name.empty() ? name + "_reset" : "");
+    /*
+     * U[t] = Input[T] + beta * U[T-1] - S[T-1] * U_th
+     * with S[T] = | 1, if U[T] - U_th > 0
+     *             | 0 otherwise
+     */
+
+    auto input = Identity((!name.empty()) ? name + "_input" : "");
+    auto decay = Mul(!name.empty() ? name + "_mul" : "");
+    auto spike = Heaviside(0, !name.empty() ? name + "_hs" : "");
+    auto subNode2 = Sub(!name.empty() ? name + "_threshold" : "");
 
     auto betaTensor = std::make_shared<Tensor>(beta);
     auto uthTensor = std::make_shared<Tensor>(static_cast<float>(threshold));
-    uniformFiller<float>(uthTensor, threshold, threshold);
 
     auto decayRate = Producer(betaTensor, "leaky_beta", true);
     auto uth = Producer(uthTensor, "leaky_uth", true);
+    auto potentialMem = Memorize(nbTimeSteps, (!name.empty()) ? name + "_potential" : "");
+    auto spikeMem = Memorize(nbTimeSteps, (!name.empty()) ? name + "_spike" : "");
 
-    auto potentialMem =
-        Memorize(nbTimeSteps, (!name.empty()) ? name + "_potential" : "");
-    auto spikeMem =
-        Memorize(nbTimeSteps, (!name.empty()) ? name + "_spike" : "");
+    uniformFiller<float>(uthTensor, threshold, threshold);
 
-    // U[t] = Input[T] + beta * U[T-1] - S[T-1] * U_th
-    // with S[T] = | 1, if U[T] - U_th > 0
-    //             | 0 otherwise
+    // Common connections: decay = beta * U[T-1]
+    decayRate->addChild(decay, 0, 1);
+    potentialMem->addChild(decay, 1, 0);
 
-    // beta * U[T-1]
-    decayRate->addChild(/*otherNode=*/mulNode, /*outId=*/0, /*otherInId=*/1);
-    potentialMem->addChild(mulNode, 1, 0);
+    std::shared_ptr<Node> potentialNode; // Node containing the final potential value
 
-    // Input[T] + beta * U[T-1]
-    mulNode->addChild(/*otherNode=*/addNode, /*outId=*/0, /*otherInId=*/1);
-    inputNode->addChild(/*otherNode=*/addNode, /*outId=*/0, /*otherInId=*/0);
+    if (resetType == LeakyReset::Subtraction) {
+        auto decayPlusInput = Add(!name.empty() ? name + "_add" : "");
+        decay->addChild(decayPlusInput, 0, 1);
+        input->addChild(decayPlusInput, 0, 0);
 
-    // S[T-1] * U_th
-    spikeMem->addChild(reset,
-                       /*outId=*/memorizeOpDataOutputRecIndex,
-                       /*otherInId=*/0);
-    // TODO(#219) Handle hard/soft reset
-    uth->addChild(reset, 0, 1);
+        auto potentialSubReset = Sub(!name.empty() ? name + "_sub" : "");
+        auto reset = Mul(!name.empty() ? name + "_reset" : "");
 
-    // Input[T] + beta * U[T-1] - S[T-1] * U_th
-    addNode->addChild(subNode, 0, 0);
-    reset->addChild(subNode, 0, 1);
+        // reset = S[T-1] * U_th
+        spikeMem->addChild(reset, 1, 0);
+        uth->addChild(reset, 0, 1);
 
-    // U[t] = (Input[T] + beta * U[T-1]) - S[T-1]
-    subNode->addChild(potentialMem, 0, 0);
+        // U[t] = (Input[T] + beta * U[T-1]) - S[T-1] * U_th
+        decayPlusInput->addChild(potentialSubReset, 0, 0);
+        reset->addChild(potentialSubReset, 0, 1);
 
-    // U[T] - U_th
-    subNode->addChild(subNode2, 0, 0);
-    uth->addChild(subNode2, 0, 1);
+        potentialSubReset->addChild(potentialMem, 0, 0);
+        potentialNode = potentialSubReset;
+        microGraph->add({decayPlusInput, potentialSubReset, reset});
+    } else if (resetType == LeakyReset::ToZero) {
+        auto oneMinusSpike = Sub(!name.empty() ? name + "_one_minus_spike" : "");
+        auto one = Producer(std::make_shared<Tensor>(1.0f), "one", true);
+        auto finalMul = Mul(!name.empty() ? name + "_final" : "");
+        auto decayPlusInput = Add(!name.empty() ? name + "_add" : "");
 
-    // with S[T] = | 1, if U[T] - U_th > 0
-    subNode2->addChild(hsNode, 0, 0);
-    hsNode->addChild(spikeMem, 0, 0);
+        // 1 - S[t-1]
+        one->addChild(oneMinusSpike, 0, 0);
+        spikeMem->addChild(oneMinusSpike, 1, 1);
 
-    microGraph->add(inputNode);
-    microGraph->add({addNode,
-                     mulNode,
-                     potentialMem,
-                     decayRate,
-                     uth,
-                     spikeMem,
-                     hsNode,
-                     subNode,
-                     subNode2,
-                     reset},
-                    false);
+        // finalMul = (1 - S[t-1]) * (beta * U[t-1])
+        oneMinusSpike->addChild(finalMul, 0, 0);
+        decay->addChild(finalMul, 0, 1);
 
-    microGraph->setOrderedInputs(
-        {{inputNode, 0}, {potentialMem, 1}, {spikeMem, 1}});
+        // U[t] = (1 - S[t-1]) * (beta * U[t-1]) + WX[t]
+        finalMul->addChild(decayPlusInput, 0, 0);
+        input->addChild(decayPlusInput, 0, 1);
 
-    // NOTE: Outputs are NOT the memory nodes (as it is done in LSTM), to avoid
-    // producing data during init. This way, we can plug an operator after
-    // our node, and get correct results.
-    microGraph->setOrderedOutputs({//{potentialMem, memorizeOpDataOutputIndex},
-                                   //{spikeMem, memorizeOpDataOutputIndex}
-                                   {subNode, 0},
-                                   {hsNode, 0}});
+        decayPlusInput->addChild(potentialMem, 0, 0);
+        potentialNode = decayPlusInput;
+        microGraph->add({oneMinusSpike, one, finalMul, decayPlusInput});
+    }
 
-    auto metaOp = MetaOperator(/*type*/ "Leaky",
-                               /*graph*/ microGraph,
-                               /*forcedInputsCategory=*/{},
-                               /*name*/ "leaky");
+    // Threshold comparison: (U[t] - U_th)
+    potentialNode->addChild(subNode2, 0, 0);
+    uth->addChild(subNode2, 0, 1);
 
-    return metaOp;
+    // Heaviside: S[t] = 1 if U[t] - U_th > 0, else 0
+    subNode2->addChild(spike, 0, 0);
+    spike->addChild(spikeMem, 0, 0);
+
+    microGraph->add(input);
+    microGraph->add({decay, potentialMem, decayRate,
+                     uth, spikeMem, spike, subNode2}, false);
+
+    microGraph->setOrderedInputs(
+        {{input, 0}, {potentialMem, 1}, {spikeMem, 1}});
+
+    // Use potentialNode for the membrane potential output
+    microGraph->setOrderedOutputs({{potentialNode, 0}, {spike, 0}});
+
+    return MetaOperator("Leaky", microGraph, {}, name);
 }
 
 std::shared_ptr<MetaOperator_Op> LeakyOp() {
...
@@ -159,4 +159,17 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") {
         REQUIRE(myLeaky->nbOutputs() == 4);
         REQUIRE(true);
     }
+
+    SECTION("Leaky(Reset to zero)") {
+        auto myLeaky = Leaky(10, 1.0, 0.9, LeakyReset::ToZero);
+        auto op = std::static_pointer_cast<OperatorTensor>(myLeaky->getOperator());
+        auto inputs = myLeaky->inputs();
+
+        // Two Memorize nodes + the real data input
+        REQUIRE(myLeaky->nbInputs() == 3);
+        // Outputs for spike and membrane potential + 2 Memorize nodes
+        REQUIRE(myLeaky->nbOutputs() == 4);
+        REQUIRE(true);
+    }
 }
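
Beyond the input/output-count checks above, a further hedged test sketch (following the same Catch2 conventions as the existing sections, not part of this commit) could pin down that omitting the new argument is equivalent, interface-wise, to passing LeakyReset::Subtraction explicitly, since that is the declared default:

    SECTION("Leaky(default reset is Subtraction)") {
        // Sketch only: relies solely on the declared default argument.
        auto implicitReset = Leaky(10, 1.0, 1.0);
        auto explicitReset = Leaky(10, 1.0, 1.0, LeakyReset::Subtraction);

        // Structural sanity check: both variants expose the same interface.
        REQUIRE(implicitReset->nbInputs() == explicitReset->nbInputs());
        REQUIRE(implicitReset->nbOutputs() == explicitReset->nbOutputs());
    }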