Commit 142e6e34 authored by Jerome Hue

Use nodes instead of Tensors

parent d4d09c91
Pipeline #70072 failed
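
In short, the Leaky forward test no longer wires the recording nodes through the MetaOperator's micro-graph, and no longer feeds the hidden state through associateInput() on the operator; the Pop, Producer, and Stack nodes are now connected directly to the Leaky node with Node::addChild. A minimal before/after sketch of that pattern, using the identifiers from the diff below:

    // Before: reach into the micro-graph and set the operator's inputs directly
    pop->addChild(op->getMicroGraph()->getOrderedInputs()[0].first, 0, 0);
    op->associateInput(1, memoryInit);

    // After: connect plain nodes, with the initial state provided by a Producer node
    popNode->addChild(leakyNode, 0, 0);
    memoryInitNode->addChild(leakyNode, 0, 1);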
@@ -745,7 +745,9 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         REQUIRE(
             approxEq<float>(*(fc2Op->getOutput(0)), *(expectedOutputfc2ts2)));
     }
+}
+
+TEST_CASE("[cpu/operator] MetaOperator", "[Leaky][CPU]") {
     SECTION("Leaky(forward)") {
         std::random_device rd;
@@ -764,25 +766,15 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         std::generate(dims.begin(), dims.end(), [&]() { return dimSizeDist(gen); });
         const auto nbTimeSteps = dims[0];

-        auto myLeaky = Leaky(nbTimeSteps, beta, threshold, LeakyReset::Subtraction, "leaky");
-        auto op = std::static_pointer_cast<MetaOperator_Op>(myLeaky->getOperator());
+        auto leakyNode = Leaky(nbTimeSteps, beta, threshold, LeakyReset::Subtraction, "leaky");
+        auto leakyOp = std::static_pointer_cast<MetaOperator_Op>(leakyNode->getOperator());

         auto memoryRecord = Stack(nbTimeSteps, "mem_rec");
         auto spikeRecord = Stack(nbTimeSteps, "spk_rec");
-        auto pop = Pop("input");
+        auto popNode = Pop("input");

-        auto leakyOutputs = op->getMicroGraph()->getOrderedOutputs();
-        auto leakyInputs = op->getMicroGraph()->getOrderedInputs();
-        pop->addChild(leakyInputs[0].first, 0, 0);
-        leakyOutputs[1].first->addChild(memoryRecord,0,0);
-        leakyOutputs[0].first->addChild(spikeRecord,0,0);
-        auto myGraph = std::make_shared<GraphView>();
-        myGraph->add(op->getMicroGraph());
-        myGraph->add({pop, memoryRecord, spikeRecord});
-        REQUIRE(myLeaky->nbInputs() == 3);
-        REQUIRE(myLeaky->inputCategory(0) == InputCategory::Data);
-        REQUIRE(myLeaky->nbOutputs() == 4);
+        REQUIRE(leakyNode->nbInputs() == 3);
+        REQUIRE(leakyNode->inputCategory(0) == InputCategory::Data);
+        REQUIRE(leakyNode->nbOutputs() == 4);

         const auto nb_elements =
             std::accumulate(dims.cbegin(),
@@ -830,19 +822,28 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         memoryInit->setBackend("cpu");
         memoryInit->resize(std::vector<std::size_t>(dims.begin() + 1, dims.end()));
         memoryInit->zeros();
+        auto memoryInitNode = Producer(memoryInit);

-        pop->getOperator()->associateInput(0, inputTensor);
-        op->associateInput(1, memoryInit);
-        op->associateInput(2, memoryInit);
+        popNode->getOperator()->associateInput(0, inputTensor);
+        popNode->addChild(leakyNode, 0, 0);
+        memoryInitNode->addChild(leakyNode, 0, 1);
+        memoryInitNode->addChild(leakyNode, 0, 2);
+        leakyNode->addChild(memoryRecord, 1, 0);
+        leakyNode->addChild(spikeRecord, 0, 0);

-        myGraph->compile("cpu", DataType::Float32);
-        auto scheduler = SequentialScheduler(myGraph);
+        auto g = std::make_shared<GraphView>();
+        g->add({popNode, leakyNode, memoryRecord, spikeRecord, memoryInitNode});
+        g->setDataType(DataType::Float32);
+        g->setBackend("cpu");
+        auto scheduler = SequentialScheduler(g);

         REQUIRE_NOTHROW(scheduler.generateScheduling());
         REQUIRE_NOTHROW(scheduler.forward(true));

         // Compare expected output with actual output
         auto memOp =
             std::static_pointer_cast<OperatorTensor>(spikeRecord->getOperator());
+        //memOp->getOutput(0)->print();
         REQUIRE(approxEq<float>(*(memOp->getOutput(0)), *(expectedOutput)));
     }
 }
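
For reference, the post-commit test wiring, consolidated from the added lines above, looks roughly like the sketch below. Surrounding definitions (nbTimeSteps, beta, threshold, inputTensor, memoryInit, expectedOutput) are assumed from the unchanged parts of the test, and the comments on the input/output indices are an interpretation of the connections shown in the diff rather than text from the commit:

    auto leakyNode      = Leaky(nbTimeSteps, beta, threshold, LeakyReset::Subtraction, "leaky");
    auto popNode        = Pop("input");                  // feeds the input tensor one time step at a time
    auto memoryInitNode = Producer(memoryInit);          // zero-initialized hidden state
    auto memoryRecord   = Stack(nbTimeSteps, "mem_rec"); // collects Leaky output 1 over the time steps
    auto spikeRecord    = Stack(nbTimeSteps, "spk_rec"); // collects Leaky output 0 over the time steps

    popNode->getOperator()->associateInput(0, inputTensor);
    popNode->addChild(leakyNode, 0, 0);         // data input
    memoryInitNode->addChild(leakyNode, 0, 1);  // initial state, inputs 1 and 2
    memoryInitNode->addChild(leakyNode, 0, 2);
    leakyNode->addChild(memoryRecord, 1, 0);
    leakyNode->addChild(spikeRecord, 0, 0);

    auto g = std::make_shared<GraphView>();
    g->add({popNode, leakyNode, memoryRecord, spikeRecord, memoryInitNode});
    g->setDataType(DataType::Float32);
    g->setBackend("cpu");

    auto scheduler = SequentialScheduler(g);
    scheduler.generateScheduling();
    scheduler.forward(true);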