Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • eclipse/aidge/aidge_core
  • hrouis/aidge_core
  • mszczep/aidge_core
  • oantoni/aidge_core
  • cguillon/aidge_core
  • jeromeh/aidge_core
  • axelfarr/aidge_core
  • cmoineau/aidge_core
  • noamzerah/aidge_core
  • lrakotoarivony/aidge_core
  • silvanosky/aidge_core
  • maab05/aidge_core
  • mick94/aidge_core
  • lucaslopez/aidge_core_ll
  • wboussella/aidge_core
  • farnez/aidge_core
  • mnewson/aidge_core
17 results
Show changes
......@@ -42,6 +42,14 @@ void Aidge::Operator::updateConsummerProducer(){
mImpl->updateConsummerProducer();
}
void Aidge::Operator::forward() { mImpl->forward(); }
void Aidge::Operator::runHooks() const {
for (auto& hook : mHooks) {
hook.second->call();
}
}
void Aidge::Operator::forward() {
mImpl->forward();
runHooks();
}
void Aidge::Operator::backward() { mImpl->backward(); }
......@@ -59,12 +59,12 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Step 2 : Branch existing producers & create the others
// link weights & bias
if (matmul->getParents(1)==nullptr) {
matmul->getParents(0)->addChild(fc, 0, 1);
if (matmul->getParent(1)==nullptr) {
matmul->getParent(0)->addChild(fc, 0, 1);
} else {
if (matmul->getParents(0)!=nullptr)
matmul->getParents(0)->addChild(fc, 0, 0);
matmul->getParents(1)->addChild(fc, 0, 1);
if (matmul->getParent(0)!=nullptr)
matmul->getParent(0)->addChild(fc, 0, 0);
matmul->getParent(1)->addChild(fc, 0, 1);
}
(producer_add_bias.first)->addChild(fc,0,2);
......
......@@ -330,4 +330,48 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
REQUIRE((r1->output(0))[0].first == r4);
}
}
TEST_CASE("[core/graph] GraphView(insertParent)") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(32, 64, {1, 1}, "conv3");
auto g = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g->add(conv1);
g->addChild(conv2, conv1, 0);
g->addChild(conv3, conv1, 0);
g->save("graphForwardDims");
g->forwardDims();
auto newConv = Conv(32, 32, {1, 1}, "newConv");
SECTION("Check insertParent conv2 then insertParent conv3") {
g->insertParent(conv2, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children = {conv3, newConv};
std::set<NodePtr> expectedNewConvChildren = {conv2};
REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren);
REQUIRE((conv1->getChildren()) == expectedConv1Children);
g->insertParent(conv3, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children2 = {newConv};
std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3};
REQUIRE(conv1->getOperator()->getOutput(0) != conv3->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren2);
REQUIRE((conv1->getChildren()) == expectedConv1Children2);
}
}
\ No newline at end of file
......@@ -20,10 +20,10 @@ using namespace Aidge;
TEST_CASE("[core/operators] GenericOp(add & get parameters)", "[Operator]") {
SECTION("INT") {
GenericOperator_Op Testop("TestOp", 1, 1, 1);
int value = 5;
const char* key = "intParam";
Testop.addParameter(key, value);
REQUIRE(Testop.getParameter<int>(key) == value);
Testop.addParameter(key, int(5));
int registeredVal = Testop.getParameter<int>(key);
REQUIRE(registeredVal == 5);
}
SECTION("LONG") {
GenericOperator_Op Testop("TestOp", 1, 1, 1);
......