Commit 07996a12 authored by Olivier BICHLER

Merge branch 'fix_mem' into 'dev'

Fix scheduling with Memorize

See merge request !119
parents 5d703003 450a7d79
Pipeline #63129 failed
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <memory>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/backend/cpu.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
namespace Aidge {

TEST_CASE("[cpu/operator] Memorize(forward)", "[Memorize][CPU]") {
    SECTION("Test simple") {
        std::shared_ptr<Tensor> inputTensor =
                std::make_shared<Tensor>(Array1D<int, 1>{{1}});

        auto input = Producer({1}, "input");
        auto init = Producer({1}, "init");
        auto add = Add("add");
        auto mem = Memorize(3, "mem");
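
        // Wire the recurrence add = input + mem: "init" seeds the Memorize state,
        // which is fed back into the second input of "add" at every step.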
        input->addChild(add, 0, 0);
        init->addChild(mem, 0, 1);
        add->addChild(mem, 0, 0);
        mem->addChild(/*otherNode=*/add, /*outId=*/1, /*otherInId=*/1);

        input->getOperator()->setOutput(0, inputTensor);
        init->getOperator()->setOutput(0, inputTensor);

        auto g = getConnectedGraphView(input);
        g->setDataType(Aidge::DataType::Int32);
        g->setBackend("cpu");
        g->forwardDims();
        g->save("simple_graph");

        SequentialScheduler scheduler(g);
        REQUIRE_NOTHROW(scheduler.forward());
        scheduler.saveSchedulingDiagram("simple");
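
        // With 3 Memorize steps: 1 + 1 = 2, then 1 + 2 = 3, then 1 + 3 = 4.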
        const auto expectedOutput = std::make_shared<Tensor>(Array1D<int, 1>{{4}});
        std::shared_ptr<Tensor> other =
                std::static_pointer_cast<OperatorTensor>(mem->getOperator())->getOutput(0);
        other->print();
        REQUIRE((*other == *expectedOutput));
    }
}
} // namespace Aidge
@@ -18,6 +18,10 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/Pop.hpp"
#include "aidge/operator/Stack.hpp"
#include "aidge/operator/Identity.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/ParallelScheduler.hpp"
@@ -438,4 +442,69 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward]")
    predictedOutput->setGrad(targetOutput);
    REQUIRE_NOTHROW(scheduler.backward());
}
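
// Helper: wraps a Memorize-based accumulator (hidden_state += input at each of
// seqLength steps) into a single "Accumulate" MetaOperator.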
std::shared_ptr<Node> Accumulate(int seqLength, const std::string& name) {
    auto input = Identity((!name.empty()) ? name + "_input" : "");
    auto hiddenState = Memorize(seqLength, (!name.empty()) ? name + "_hidden_state" : "");
    auto add = Add((!name.empty()) ? name + "_add" : "");

    input->addChild(add, 0, 0);
    add->addChild(hiddenState, 0, 0);
    hiddenState->addChild(/*otherNode=*/add, /*outId=*/1, /*otherInId=*/1);

    std::shared_ptr<GraphView> microGraph = std::make_shared<GraphView>();
    microGraph->add(input);
    microGraph->add({hiddenState, add});
    microGraph->setOrderedInputs({{input, 0}, {hiddenState, 1}});
    microGraph->setOrderedOutputs({{hiddenState, 0}});

    auto metaOp = MetaOperator("Accumulate", microGraph, {}, name);
    return metaOp;
}
TEST_CASE("[cpu/scheduler] Accumulate", "[scheduler]") {
std::shared_ptr<Tensor> Input = std::make_shared<Tensor>(
Array3D<float, 2, 3, 2>{{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}},
{{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}});
std::shared_ptr<Tensor> MemInit =
std::make_shared<Tensor>(Array2D<float, 3, 2>{
{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}});
auto meta = Accumulate(2, "accumulate");
auto op = std::static_pointer_cast<MetaOperator_Op>(meta->getOperator());
auto pop_i = Pop("pop_input");
auto pop_o = Identity("pop_output"); // NOTE: Could be Identity/Stack/Whatever node you want, this is is not the problem here
pop_i->getOperator()->associateInput(0, Input);
pop_i->addChild(op->getMicroGraph()->getOrderedInputs()[0].first, 0, 0);
op->getMicroGraph()->getOrderedOutputs()[0].first->addChild(pop_o, 0, 0);
//pop_i->addChild(meta, 0, 0);
//meta->addChild(pop_o, 0, 0);
//op->associateInput(1, MemInit);
op->getMicroGraph()->getNode("accumulate_hidden_state")->getOperator()->associateInput(1, MemInit);

    // Build the graph.
    auto myGraph = std::make_shared<GraphView>();
    myGraph->add(pop_i);
    myGraph->add(op->getMicroGraph());
    //myGraph->add(meta);
    myGraph->add(pop_o);
    myGraph->compile("cpu", DataType::Float32);
    myGraph->save("accumulate_graph", true);

    // Schedule and run
    auto scheduler = SequentialScheduler(myGraph);
    scheduler.generateScheduling();
    scheduler.saveStaticSchedulingDiagram("accumulate_scheduling");
    REQUIRE_NOTHROW(scheduler.forward(true));
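
    // Expected result: the element-wise sum of the two time steps of Input,
    // i.e. {{1,2},{3,4},{5,6}} + {{2,3},{4,5},{6,7}} = {{3,5},{7,9},{11,13}}.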
    std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(
        Array2D<float, 3, 2>{{{3.0, 5.0}, {7.0, 9.0}, {11.0, 13.0}}});

    std::shared_ptr<Tensor> output =
        std::static_pointer_cast<OperatorTensor>(pop_o->getOperator())->getOutput(0);
    REQUIRE(*output == *expectedOutput);
}
} // namespace Aidge