diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index f781e5e2e80f7bb265e796e1e76f65b9d6efeee8..7b0b80d816eba8000e782e0e5238c2550dd4eed9 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -745,7 +745,9 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { REQUIRE( approxEq<float>(*(fc2Op->getOutput(0)), *(expectedOutputfc2ts2))); } +} +TEST_CASE("[cpu/operator] MetaOperator", "[Leaky][CPU]") { SECTION("Leaky(forward)") { std::random_device rd; @@ -764,25 +766,15 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { std::generate(dims.begin(), dims.end(), [&]() { return dimSizeDist(gen); }); const auto nbTimeSteps = dims[0]; - auto myLeaky = Leaky(nbTimeSteps, beta, threshold, LeakyReset::Subtraction, "leaky"); - auto op = std::static_pointer_cast<MetaOperator_Op>(myLeaky->getOperator()); + auto leakyNode = Leaky(nbTimeSteps, beta, threshold, LeakyReset::Subtraction, "leaky"); + auto leakyOp = std::static_pointer_cast<MetaOperator_Op>(leakyNode->getOperator()); auto memoryRecord = Stack(nbTimeSteps, "mem_rec"); auto spikeRecord = Stack(nbTimeSteps, "spk_rec"); - auto pop = Pop("input"); - - auto leakyOutputs = op->getMicroGraph()->getOrderedOutputs(); - auto leakyInputs = op->getMicroGraph()->getOrderedInputs(); - pop->addChild(leakyInputs[0].first, 0, 0); - leakyOutputs[1].first->addChild(memoryRecord,0,0); - leakyOutputs[0].first->addChild(spikeRecord,0,0); - - auto myGraph = std::make_shared<GraphView>(); - myGraph->add(op->getMicroGraph()); - myGraph->add({pop, memoryRecord, spikeRecord}); + auto popNode = Pop("input"); - REQUIRE(myLeaky->nbInputs() == 3); - REQUIRE(myLeaky->inputCategory(0) == InputCategory::Data); - REQUIRE(myLeaky->nbOutputs() == 4); + REQUIRE(leakyNode->nbInputs() == 3); + REQUIRE(leakyNode->inputCategory(0) == InputCategory::Data); + REQUIRE(leakyNode->nbOutputs() == 4); const auto nb_elements = std::accumulate(dims.cbegin(), @@ -830,19 +822,28 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { memoryInit->setBackend("cpu"); memoryInit->resize(std::vector<std::size_t>(dims.begin() + 1, dims.end())); memoryInit->zeros(); + auto memoryInitNode = Producer(memoryInit); - pop->getOperator()->associateInput(0, inputTensor); - op->associateInput(1, memoryInit); - op->associateInput(2, memoryInit); + popNode->getOperator()->associateInput(0, inputTensor); + popNode->addChild(leakyNode,0, 0); + memoryInitNode->addChild(leakyNode, 0, 1); + memoryInitNode->addChild(leakyNode, 0, 2); + leakyNode->addChild(memoryRecord, 1, 0); + leakyNode->addChild(spikeRecord, 0, 0); - myGraph->compile("cpu", DataType::Float32); - auto scheduler = SequentialScheduler(myGraph); + auto g = std::make_shared<GraphView>(); + g->add({popNode, leakyNode, memoryRecord, spikeRecord, memoryInitNode}); + g->setDataType(DataType::Float32); + g->setBackend("cpu"); + + auto scheduler = SequentialScheduler(g); REQUIRE_NOTHROW(scheduler.generateScheduling()); REQUIRE_NOTHROW(scheduler.forward(true)); // Compare expected output with actual output auto memOp = std::static_pointer_cast<OperatorTensor>(spikeRecord->getOperator()); + //memOp->getOutput(0)->print(); REQUIRE(approxEq<float>(*(memOp->getOutput(0)), *(expectedOutput))); } }