From 585f986f1e281c7b16b7471461d4aa0e4c2b4c2b Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 22 Nov 2023 16:00:54 +0000 Subject: [PATCH] Update tests and some files --- src/operator/MetaOperator.cpp | 21 +- src/recipies/FuseBatchNorm.cpp | 91 ++++--- src/recipies/FuseMulAdd.cpp | 32 ++- unit_tests/graph/Test_Connector.cpp | 132 +++++----- unit_tests/graph/Test_GraphView.cpp | 282 ++++++++++----------- unit_tests/graph/Test_get.cpp | 16 +- unit_tests/graphRegex/Test_FsmMatch.cpp | 10 +- unit_tests/graphRegex/Test_GraphRegex.cpp | 24 +- unit_tests/operator/Test_MetaOperator.cpp | 15 +- unit_tests/recipies/Test_FuseBatchNorm.cpp | 8 +- unit_tests/recipies/Test_FuseMulAdd.cpp | 4 +- unit_tests/recipies/Test_LabelGraph.cpp | 38 +-- 12 files changed, 332 insertions(+), 341 deletions(-) diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index c1f58c686..28085759f 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -15,17 +15,12 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, std::vector<NodePtr> inputNodes, std::vector<NodePtr> outputNodes) - : Operator(type), + : OperatorTensor(type, graph->dataInputs().size(), (graph->inputs().size() - graph->dataInputs().size()), graph->outputs().size()), mGraph(graph) { - mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size()); - for (std::size_t i = 0; i < mInputs.size(); ++i) { - mInputs[i] = std::make_shared<Tensor>(); - } - mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->outputs().size()); - for (std::size_t i = 0; i < mOutputs.size(); ++i) { - mOutputs[i] = std::make_shared<Tensor>(); - } + // for (std::size_t i = 0; i < mInputs.size(); ++i) { + // mInputs[i] = std::make_shared<Tensor>(); + // } // Fill inputsNodes and outputsNodes when there is no ambiguity if (inputNodes.empty()) { @@ -46,14 +41,14 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph"); const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); - + int inputIdx = 0; // input idx relative to the current node for (const auto& in : inputNodeinputs) { if (in.first == nullptr || !mGraph->inView(in.first)) { // The input is not connected inside the micro-graph // (no connection to this input or connection outside the micro-graph) // => it is therefore an input for the meta-operator - mInputOps.push_back(std::make_pair(inputNode->getOperator(), inputIdx)); + mInputOps.push_back(std::make_pair(std::dynamic_pointer_cast<OperatorTensor>(inputNode->getOperator()), inputIdx)); } ++inputIdx; @@ -67,7 +62,7 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< outputNode->outputs(); for (size_t outputIdx = 0; outputIdx < outputNodeoutputs.size(); ++outputIdx) { - mOutputOps.push_back(std::make_pair(outputNode->getOperator(), outputIdx)); + mOutputOps.push_back(std::make_pair(std::dynamic_pointer_cast<OperatorTensor>(outputNode->getOperator()), outputIdx)); } } @@ -114,7 +109,7 @@ void Aidge::MetaOperator_Op::updateConsummerProducer() { // Lazy initialization mScheduler = std::make_shared<SequentialScheduler>(mGraph); } - + // TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule. // It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()" mScheduler->generateScheduling(); diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index 4b02692c2..ffb4599d8 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -12,10 +12,10 @@ #include <cassert> #include <memory> #include <string> + #include "aidge/operator/FC.hpp" #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/Conv.hpp" - #include "aidge/recipies/Recipies.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" @@ -26,70 +26,68 @@ //Graph Regex #include "aidge/graphRegex/GraphRegex.hpp" -using namespace Aidge; - -void Aidge::fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm){ +void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr<Aidge::Node> batchnormNode) { + // TODO: Find a way to remove the template + // A feature map with 2 dimensions is assumed + const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator()); + const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator()); + const std::shared_ptr<Tensor> scale = batchOp->getInput(1); + const std::shared_ptr<Tensor> shift = batchOp->getInput(2); + const std::shared_ptr<Tensor> b_mean = batchOp->getInput(3); + const std::shared_ptr<Tensor> b_var = batchOp->getInput(4); - std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second); - std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second); - std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second); - std::shared_ptr<Tensor> b_var = batchnorm->input(4).first->getOperator()->getOutput(batchnorm->input(4).second); + const float epsilon = batchOp -> getAttr<float>("Epsilon"); + const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels"); + const DimSize_t channelsSize = convOp -> getAttr<DimSize_t>("InChannels"); + const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims"); - // TODO : Find a way to remove the template - const float epsilon = std::static_pointer_cast<BatchNorm_Op<2>>(batchnorm->getOperator())->getAttr<float>("Epsilon"); - DimSize_t convOutDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("OutChannels"); - - - assert(scale->size() == convOutDims); - assert(shift->size() == convOutDims); - assert(b_mean->size() == convOutDims); - assert(b_var->size() == convOutDims); + assert(scale->size() == convNbOutChannels); + assert(shift->size() == convNbOutChannels); + assert(b_mean->size() == convNbOutChannels); + assert(b_var->size() == convNbOutChannels); assert(epsilon > 0.0); // TODO : no no_bias attribute ? + + float meanVariance = 0.0; unsigned int count = 0; - for (std::size_t output = 0; output < convOutDims; ++output) { - // TODO : get suppose datatype is float .. - if (b_var->get<float>(output) > 1.0e-12) { - meanVariance += b_var->get<float>(output); + for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) { + // TODO: get() assumed dataType is float... + if (b_var->get<float>(outChId) > 1.0e-12) { + meanVariance += b_var->get<float>(outChId); ++count; } else { - printf("Zero-variance: %s [%lu]\n", conv->name().c_str(), output); + printf("Zero-variance: %s [%lu]\n", convNode->name().c_str(), outChId); } } if (count > 0) meanVariance /= count; else { - printf("variance < 1e-12 for all outputs! Is the network correctly trained?\n"); + printf("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n"); } - const DimSize_t channelsSize = std::dynamic_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("InChannels"); - - // TODO : suppose we have Conv2D ... - const std::array<DimSize_t, 2> kernelDims = std::dynamic_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<std::array<DimSize_t, 2>>("KernelDims"); - - std::shared_ptr<Tensor> weight = conv->input(1).first->getOperator()->getOutput(conv->input(1).second); - std::shared_ptr<Tensor> bias = conv->input(2).first->getOperator()->getOutput(conv->input(2).second); + std::shared_ptr<Tensor> weight = convOp -> getInput(1); + std::shared_ptr<Tensor> bias = convOp -> getInput(2); - for (std::size_t output = 0; output < convOutDims; ++output) { + for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) { // Corrected for zero-variance issue: // "A Quantization-Friendly Separable Convolution for MobileNets" // https://arxiv.org/pdf/1803.08607.pdf // to help post-training quantization - const float factor = scale->get<float>(output) - / std::sqrt(epsilon + ((b_var->get<float>(output) > 1.0e-12 || count == 0) - ? b_var->get<float>(output) : meanVariance)); + const float factor = scale->get<float>(outChId) + / std::sqrt(epsilon + ((b_var->get<float>(outChId) > 1.0e-12 || count == 0) + ? b_var->get<float>(outChId) : meanVariance)); // Weights adjustments for (std::size_t channel = 0; channel < channelsSize; ++channel) { // TODO : Suppose kerneldims = 2 for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){ for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){ - std::vector<DimSize_t> currentIdx = {output, channel, k0, k1}; + std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1}; // TODO : suppose weights are float float weightValue = weight->get<float>(currentIdx); weight->set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights @@ -98,25 +96,25 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batch } // TODO : check if noBias==true is set, then set biasValue to 0 - float biasValue = bias->get<float>(output); + float biasValue = bias->get<float>(outChId); - biasValue = shift->get<float>(output) + (biasValue - b_mean->get<float>(output)) * factor; + biasValue = shift->get<float>(outChId) + (biasValue - b_mean->get<float>(outChId)) * factor; - bias->set<float>(output, biasValue); + bias->set<float>(outChId, biasValue); } GraphView::replace(std::set<std::shared_ptr<Node>>({ - batchnorm, - batchnorm->input(1).first, - batchnorm->input(2).first, - batchnorm->input(3).first, - batchnorm->input(4).first + batchnormNode, + batchnormNode->input(1).first, + batchnormNode->input(2).first, + batchnormNode->input(3).first, + batchnormNode->input(4).first }), {}); } -void Aidge::fuseBatchNorm(std::shared_ptr<MatchSolution> solution){ +void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) { assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n"); assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n"); @@ -129,7 +127,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<MatchSolution> solution){ } -void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){ +void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) { std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); @@ -143,5 +141,4 @@ void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){ fuseBatchNorm(solution); } - -} +} \ No newline at end of file diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index a268b7fef..d37f47496 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -21,28 +21,28 @@ #include "aidge/operator/Producer.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" +#include "aidge/operator/MatMul.hpp" //Graph Regex #include "aidge/graphRegex/GraphRegex.hpp" -using namespace Aidge; -void Aidge::fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add){//std::set<std::shared_ptr<Node>> nodes){ +void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) { //std::set<std::shared_ptr<Node>> nodes){ // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) - assert((matmul->type() == "MatMul" && add->type() == "Add") && "Wrong type for the nodes to replace"); + assert((matmulNode->type() == "MatMul" && addNode->type() == "Add") && "Wrong type for the nodes to replace"); // Step 1 : Create FC // Fetch the output dimension throught the bias size - std::shared_ptr<Node> bias = (add->getParent(1)) ? add->getParent(1)->cloneSharedOperators() : nullptr; + std::shared_ptr<Node> bias = (addNode->getParent(1)) ? addNode->getParent(1)->cloneSharedOperators() : nullptr; - if (!(matmul->getParent(1))) { + if (!(matmulNode->getParent(1))) { AIDGE_INTERNAL_ASSERT("No weight detected to produce the fuseMulAdd recipe."); } - std::shared_ptr<Node> weight = matmul->getParent(1)->cloneSharedOperators(); - DimSize_t outSize = weight->getOperator()->output(0).dims<2>()[1]; + std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators(); + const DimSize_t outSize = std::dynamic_pointer_cast<MatMul_Op>(matmulNode->getOperator()) -> getAttr<DimSize_t>("OutChannels"); // Instanciate FC //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); @@ -61,25 +61,25 @@ void Aidge::fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add){/ // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory? auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)}); - GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, newNodes); + GraphView::replace({matmulNode, addNode, addNode->getParent(1), matmulNode->getParent(1)}, newNodes); } -void Aidge::fuseMulAdd(std::shared_ptr<MatchSolution> solution){ +void Aidge::fuseMulAdd(std::shared_ptr<Aidge::MatchSolution> solution){ assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n"); - for (const auto& matmul : solution->at("MatMul")) { - for (const auto& add : solution->at("Add")) { - fuseMulAdd(matmul,add); + for (const auto& matmulNode : solution->at("MatMul")) { + for (const auto& addNode : solution->at("Add")) { + fuseMulAdd(matmulNode,addNode); } } } -void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ +void Aidge::fuseMulAdd(std::shared_ptr<Aidge::GraphView> graphView){ std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); @@ -90,10 +90,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ for (const auto& solution : regex->match(graphView)) { fuseMulAdd(solution); - - } - -} + } +} \ No newline at end of file diff --git a/unit_tests/graph/Test_Connector.cpp b/unit_tests/graph/Test_Connector.cpp index ef70521d0..a7cee610e 100644 --- a/unit_tests/graph/Test_Connector.cpp +++ b/unit_tests/graph/Test_Connector.cpp @@ -26,19 +26,19 @@ TEST_CASE("[core/graph] Connector(Constructor)") { REQUIRE(x.node() == nullptr); } SECTION("0 output") { - std::shared_ptr<Node> node = GenericOperator("Producer",1,1,0); + std::shared_ptr<Node> node = GenericOperator("Producer", 1, 0, 0); Connector x = Connector(node); REQUIRE(x.index() == gk_IODefaultIndex); REQUIRE(x.node() == node); } SECTION("1 output") { - std::shared_ptr<Node> node = GenericOperator("ReLU",1,1,1); + std::shared_ptr<Node> node = GenericOperator("ReLU", 1, 0, 1); Connector x = Connector(node); REQUIRE(x.index() == 0); REQUIRE(x.node() == node); } SECTION("Several outputs") { - std::shared_ptr<Node> node = GenericOperator("Split",1,1,2); + std::shared_ptr<Node> node = GenericOperator("Split", 1, 0, 2); Connector x = Connector(node); REQUIRE(x.index() == gk_IODefaultIndex); REQUIRE(x.node() == node); @@ -47,30 +47,30 @@ TEST_CASE("[core/graph] Connector(Constructor)") { TEST_CASE("Connector connections Node", "[Connector]") { SECTION("0 input / 0 output") { - std::shared_ptr<Node> fic = GenericOperator("Display",0,0,0); + std::shared_ptr<Node> fic = GenericOperator("Display", 0, 0, 0); Connector x; x = (*fic)({}); REQUIRE(x.node() == fic); } SECTION("1 input / 0 output") { - std::shared_ptr<Node> fic = GenericOperator("Loss",1,1,0); + std::shared_ptr<Node> fic = GenericOperator("Loss", 1, 0, 0); Connector x; x = (*fic)({x}); REQUIRE(x.node() == fic); } SECTION("0 input / 1 output") { // Producers - std::shared_ptr<Node> fic = GenericOperator("Producer",0,0,1); + std::shared_ptr<Node> fic = GenericOperator("Producer", 0, 0, 1); Connector x = (*fic)({}); REQUIRE(x.node() == fic); } SECTION("1 input / 1 output") { - std::shared_ptr<Node> fic = GenericOperator("Conv",1,1,1); + std::shared_ptr<Node> fic = GenericOperator("Conv", 1, 0, 1); Connector x(GenericOperator("Producer",0,0,1)); x = (*fic)({x}); REQUIRE(x.node() ==fic); } SECTION("2+ inputs / 1 output") { // ElemWise - std::shared_ptr<Node> fic = GenericOperator("fictive",3,3,1); + std::shared_ptr<Node> fic = GenericOperator("fictive", 3, 0, 1); Connector x1(GenericOperator("fictive",0,0,1)); Connector x2(GenericOperator("fictive",0,0,1)); Connector x3(GenericOperator("fictive",0,0,1)); @@ -78,9 +78,9 @@ TEST_CASE("Connector connections Node", "[Connector]") { REQUIRE(x.node() ==fic); } SECTION("1 input / 2+ outputs") { // Slice - std::shared_ptr<Node> fic = GenericOperator("fictive",1,1,3); + std::shared_ptr<Node> fic = GenericOperator("fictive", 1, 0, 3); - Connector x(GenericOperator("fictive2",0,0,1)); + Connector x(GenericOperator("fictive2", 0, 0, 1)); Connector y; REQUIRE_NOTHROW(y = (*fic)({x})); REQUIRE(y[0].node() == fic); @@ -91,16 +91,16 @@ TEST_CASE("Connector connections Node", "[Connector]") { TEST_CASE("GraphGeneration from Connector", "[GraphView]") { - auto node01 = GenericOperator("Conv",0,0,1,"g_conv1"); - auto node02 = GenericOperator("ReLU",1,1,1,"g_relu"); - auto node03 = GenericOperator("g_maxpool1", 1,1,1); - auto node04 = GenericOperator("g_conv2_par1",1,1,1); - auto node05 = GenericOperator("g_relu2_par1", 1,1,1); - auto node06 = GenericOperator("g_conv2_par2", 1,1,1); - auto node07 = GenericOperator("g_relu2_par2", 1,1,1); - auto node08 = GenericOperator("g_concat", 2,2,1); - auto node09 = GenericOperator("g_conv3", 1, 1,1); - auto node10 = GenericOperator("g_matmul1", 2,2,1); + auto node01 = GenericOperator("Conv", 0, 0, 1,"g_conv1"); + auto node02 = GenericOperator("ReLU", 1, 0, 1,"g_relu"); + auto node03 = GenericOperator("g_maxpool1", 1, 0, 1); + auto node04 = GenericOperator("g_conv2_par1", 1, 0, 1); + auto node05 = GenericOperator("g_relu2_par1", 1, 0, 1); + auto node06 = GenericOperator("g_conv2_par2", 1, 0, 1); + auto node07 = GenericOperator("g_relu2_par2", 1, 0, 1); + auto node08 = GenericOperator("g_concat", 2, 0, 1); + auto node09 = GenericOperator("g_conv3", 1, 0, 1); + auto node10 = GenericOperator("g_matmul1", 2, 0, 1); Connector a = (*node01)({}); Connector x = (*node02)({a}); x = (*node03)({x}); @@ -118,38 +118,38 @@ TEST_CASE("GraphGeneration from Connector", "[GraphView]") { TEST_CASE("Connector connection GraphView", "[Connector]") { SECTION("1 input") { Connector x = Connector(); - auto prod = GenericOperator("Producer",0,0,1); + auto prod = GenericOperator("Producer", 0, 0, 1); auto g = Residual({ - GenericOperator("g_conv1", 1,1,1), - GenericOperator("g_relu", 1,1,1), - GenericOperator("g_maxpool1", 1,1,1), + GenericOperator("g_conv1", 1, 0, 1), + GenericOperator("g_relu", 1, 0, 1), + GenericOperator("g_maxpool1", 1, 0, 1), Parallel({ - Sequential({GenericOperator("g_conv2_par1",1,1,1), GenericOperator("g_relu2_par1", 1,1,1)}), - Sequential({GenericOperator("g_conv2_par2", 1,1,1), GenericOperator("g_relu2_par2", 1,1,1)}) + Sequential({GenericOperator("g_conv2_par1", 1, 0, 1), GenericOperator("g_relu2_par1", 1, 0, 1)}), + Sequential({GenericOperator("g_conv2_par2", 1, 0, 1), GenericOperator("g_relu2_par2", 1, 0, 1)}) }), - GenericOperator("g_concat", 2,2,1), - GenericOperator("g_conv3", 1, 1,1), - GenericOperator("g_matmul1", 2,2,1) + GenericOperator("g_concat", 2, 0, 1), + GenericOperator("g_conv3", 1, 0, 1), + GenericOperator("g_matmul1", 2, 0, 1) }); x = (*prod)({}); x = (*g)({x}); std::shared_ptr<GraphView> g2 = generateGraph({x}); std::shared_ptr<GraphView> g3 = g; g3->add(prod); - REQUIRE(*g3== *g2); + REQUIRE(*g3 == *g2); } SECTION("2+ inputs") { - Connector x = (*GenericOperator("Producer",0,0,1))({}); - Connector y = (*GenericOperator("Producer",0,0,1))({}); - Connector z = (*GenericOperator("Producer",0,0,1))({}); - auto g = Sequential({GenericOperator("ElemWise", 3,3,1), + Connector x = (*GenericOperator("Producer", 0, 0, 1))({}); + Connector y = (*GenericOperator("Producer", 0, 0, 1))({}); + Connector z = (*GenericOperator("Producer", 0, 0, 1))({}); + auto g = Sequential({GenericOperator("ElemWise", 3, 0, 1), Parallel({ - Sequential({GenericOperator("g_conv2_par1",1,1,1), GenericOperator("g_relu2_par1", 1,1,1)}), - Sequential({GenericOperator("g_conv2_par2", 1,1,1), GenericOperator("g_relu2_par2", 1,1,1)}), - Sequential({GenericOperator("g_conv2_par3", 1,1,1), GenericOperator("g_relu2_par3", 1,1,1)}) + Sequential({GenericOperator("g_conv2_par1", 1, 0, 1), GenericOperator("g_relu2_par1", 1, 0, 1)}), + Sequential({GenericOperator("g_conv2_par2", 1, 0, 1), GenericOperator("g_relu2_par2", 1, 0, 1)}), + Sequential({GenericOperator("g_conv2_par3", 1, 0, 1), GenericOperator("g_relu2_par3", 1, 0, 1)}) }), - GenericOperator("g_concat", 3,3,1), - GenericOperator("g_conv3", 1, 1,1) + GenericOperator("g_concat", 3, 0, 1), + GenericOperator("g_conv3", 1, 0, 1) }); x = (*g)({x, y, z}); @@ -162,12 +162,12 @@ TEST_CASE("Connector connection GraphView", "[Connector]") { TEST_CASE("Connector Mini-graph", "[Connector]") { Connector x = Connector(); Connector y = Connector(); - x = (*GenericOperator("Producer",0,0,1))({}); - y = (*GenericOperator("Producer",0,0,1))({}); + x = (*GenericOperator("Producer", 0, 0, 1))({}); + y = (*GenericOperator("Producer", 0, 0, 1))({}); for (int i = 0; i<5; ++i) { - x = (*GenericOperator("Conv",1,1,1))({x}); + x = (*GenericOperator("Conv", 1, 0, 1))({x}); } - y = (*GenericOperator("ElemWise",2,2,1))({y, x}); + y = (*GenericOperator("ElemWise", 2, 0, 1))({y, x}); std::shared_ptr<GraphView> g = generateGraph({y}); g->save("TestGraph"); } @@ -180,16 +180,16 @@ TEST_CASE("Structural descrition - Sequential", "[GraphView]") { // REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>()); // } SECTION("1-element Sequence") { - std::shared_ptr<Node> fic = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic = GenericOperator("node1", 1, 0, 1); std::shared_ptr<GraphView> g2 = Sequential({fic}); REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic})); REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic})); REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic})); } SECTION("several-elements simple Sequence") { - std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1, 0, 1); std::shared_ptr<GraphView> g2 = Sequential({fic1, fic2, fic3}); REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1})); @@ -206,37 +206,37 @@ TEST_CASE("Structural description - Parallel", "[GraphView]") { // REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>()); // } SECTION("1-element Parallel") { - std::shared_ptr<Node> fic = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic = GenericOperator("node1", 1, 0, 1); std::shared_ptr<GraphView> g2 = Parallel({fic}); REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic})); REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic})); REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic})); } SECTION("several-elements simple Parallel") { - std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1, 0, 1); std::shared_ptr<GraphView> g2 = Parallel({fic1, fic2, fic3}); REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); } SECTION("1 Graph in Parallel") { - std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1, 0, 1); std::shared_ptr<GraphView> g2 = Parallel({Sequential({fic1, fic2, fic3})}); REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1})); REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic3})); } SECTION("several Sequential in Parallel") { - std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic4 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic5 = GenericOperator("node1", 1,1,1); - std::shared_ptr<Node> fic6 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic4 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic5 = GenericOperator("node1", 1, 0, 1); + std::shared_ptr<Node> fic6 = GenericOperator("node1", 1, 0, 1); std::shared_ptr<GraphView> g2 = Parallel({Sequential({fic1, fic2, fic3}),Sequential({fic4, fic5, fic6})}); REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3, fic4, fic5, fic6})); REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1, fic4})); @@ -245,13 +245,13 @@ TEST_CASE("Structural description - Parallel", "[GraphView]") { } TEST_CASE("Strucutral Description - Complex Graph", "[GraphView]") { - std::shared_ptr<Node> firstLayer = GenericOperator("first", 1,1,1); + std::shared_ptr<Node> firstLayer = GenericOperator("first", 1, 0, 1); auto g = Sequential({firstLayer, - GenericOperator("l2",1,1,1), - Parallel({Sequential({GenericOperator("conv1",1,1,1), GenericOperator("relu1",1,1,1)}), - Sequential({GenericOperator("conv2",1,1,1), GenericOperator("relu2",1,1,1)})}), - GenericOperator("concat",2,2,1), - GenericOperator("lastLayer",1,1,1)}); + GenericOperator("l2", 1, 0, 1), + Parallel({Sequential({GenericOperator("conv1",1, 0, 1), GenericOperator("relu1", 1, 0, 1)}), + Sequential({GenericOperator("conv2", 1, 0, 1), GenericOperator("relu2", 1, 0, 1)})}), + GenericOperator("concat", 2, 0, 1), + GenericOperator("lastLayer", 1, 0, 1)}); REQUIRE(g->getNodes().size() == 8U); REQUIRE(g->inputNodes() == std::set<std::shared_ptr<Node>>({firstLayer})); } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 3b4fb167b..bb726bd4d 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -40,13 +40,13 @@ TEST_CASE("[core/graph] GraphView(add)") { g->add(GOp1); std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 1, "Gop2"); g->add(GOp2); - std::shared_ptr<Node> GOp3 = GenericOperator("Fictive", 1, 1, 0, "Gop3"); + std::shared_ptr<Node> GOp3 = GenericOperator("Fictive", 1, 0, 0, "Gop3"); g->add(GOp3); std::shared_ptr<Node> GOp4 = GenericOperator("Fictive", 0, 1, 0, "Gop4"); g->add(GOp4); - std::shared_ptr<Node> GOp5 = GenericOperator("Fictive", 1, 1, 1, "Gop5"); + std::shared_ptr<Node> GOp5 = GenericOperator("Fictive", 1, 0, 1, "Gop5"); g->add(GOp5); - std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 2, 1, "Gop6"); + std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 1, 1, "Gop6"); g->add(GOp6); } @@ -75,11 +75,11 @@ TEST_CASE("[core/graph] GraphView(add)") { SECTION("another GraphView") { std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph-1"); std::shared_ptr<GraphView> g2 = std::make_shared<GraphView>("TestGraph-2"); - auto conv = GenericOperator("Conv", 1, 1, 1, "c"); - auto conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); - auto conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); - auto conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); - auto conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + auto conv = GenericOperator("Conv", 1, 0, 1, "c"); + auto conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); + auto conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); + auto conv3 = GenericOperator("Conv", 1, 0, 1, "c3"); + auto conv4 = GenericOperator("Conv", 1, 0, 1, "c4"); conv->addChild(conv1); conv1->addChild(conv2); conv2->addChild(conv3); @@ -96,13 +96,13 @@ TEST_CASE("[core/graph] GraphView(add)") { TEST_CASE("[core/graph] GraphView(addChild)") { std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); - std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); - std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); - std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5"); - std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); - std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3"); + std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 0, 1, "c3.5"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 0, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 0, 1, "c5"); g1->add(conv); SECTION("add(node)") { @@ -177,12 +177,12 @@ TEST_CASE("[core/graph] GraphView(outputs)") { TEST_CASE("[core/graph] GraphView(save)") { std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); - std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); - std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); - std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); - std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 0, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 0, 1, "c5"); g1->add(conv); g1->addChild(conv1, "c"); @@ -197,9 +197,9 @@ TEST_CASE("[core/graph] GraphView(save)") { TEST_CASE("[core/graph] GraphView(resetConnections)") { SECTION("disconnect data iput") { - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 3, 1, "c1"); - std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1"); std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2"); conv->addChild(conv1); @@ -222,9 +222,9 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { } SECTION("disconnect data iput + learnable parameters") { - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 3, 1, "c1"); - std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1"); std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2"); conv->addChild(conv1); @@ -259,21 +259,21 @@ TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") { g->forwardDims(); SECTION("Check input-output connections") { - REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); - REQUIRE(conv1->getOperator()->getInput(1) == g->getNode("conv1_w")->getOperator()->getOutput(0)); - REQUIRE(conv1->getOperator()->getInput(2) == g->getNode("conv1_b")->getOperator()->getOutput(0)); - REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); - REQUIRE(conv2->getOperator()->getInput(1) == g->getNode("conv2_w")->getOperator()->getOutput(0)); - REQUIRE(conv2->getOperator()->getInput(2) == g->getNode("conv2_b")->getOperator()->getOutput(0)); - REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); - REQUIRE(conv3->getOperator()->getInput(1) == g->getNode("conv3_w")->getOperator()->getOutput(0)); - REQUIRE(conv3->getOperator()->getInput(2) == g->getNode("conv3_b")->getOperator()->getOutput(0)); + REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0)); + REQUIRE(conv1->getOperator()->getRawInput(1) == g->getNode("conv1_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv1->getOperator()->getRawInput(2) == g->getNode("conv1_b")->getOperator()->getRawOutput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0)); + REQUIRE(conv2->getOperator()->getRawInput(1) == g->getNode("conv2_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv2->getOperator()->getRawInput(2) == g->getNode("conv2_b")->getOperator()->getRawOutput(0)); + REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0)); + REQUIRE(conv3->getOperator()->getRawInput(1) == g->getNode("conv3_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv3->getOperator()->getRawInput(2) == g->getNode("conv3_b")->getOperator()->getRawOutput(0)); } SECTION("Check forwarded dims") { - REQUIRE(std::static_pointer_cast<Tensor>(conv1->getOperator()->getOutput(0)) + REQUIRE(std::static_pointer_cast<Tensor>(conv1->getOperator()->getRawOutput(0)) ->dims() == std::vector<DimSize_t>({16, 32, 222, 222})); - REQUIRE(std::static_pointer_cast<Tensor>(conv2->getOperator()->getOutput(0)) + REQUIRE(std::static_pointer_cast<Tensor>(conv2->getOperator()->getRawOutput(0)) ->dims() == std::vector<DimSize_t>({16, 64, 220, 220})); } } @@ -286,10 +286,10 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w"); auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b"); - auto other1 = GenericOperator("Other", 1, 1, 1, "other1"); - auto other2 = GenericOperator("Other", 1, 1, 1, "other2"); - auto matmul = GenericOperator("MatMul", 1, 2, 1, "matmul"); - auto add = GenericOperator("Add", 1, 2, 1, "add"); + auto other1 = GenericOperator("Other", 1, 0, 1, "other1"); + auto other2 = GenericOperator("Other", 1, 0, 1, "other2"); + auto matmul = GenericOperator("MatMul", 1, 1, 1, "matmul"); + auto add = GenericOperator("Add", 1, 1, 1, "add"); otherInput->addChild(other1); other1->addChild(matmul); matmul->addChild(add); @@ -303,7 +303,7 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { std::set<std::shared_ptr<Node>> nodeToReplace = std::set<std::shared_ptr<Node>>({matmulWeight, addBias, matmul, add}); // create replacing graph - std::shared_ptr<Node> myFC = GenericOperator("FC", 1, 3, 1, "fc"); + std::shared_ptr<Node> myFC = GenericOperator("FC", 1, 2, 1, "fc"); auto newMatmulWeight = matmulWeight->cloneSharedOperators(); newMatmulWeight->addChild(myFC, 0, 1); auto newAddBias = addBias->cloneSharedOperators(); @@ -319,9 +319,9 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { SECTION("replace with nothing") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); auto r1 = GenericOperator("relu", 0, 0, 1); - auto r2 = GenericOperator("relu", 1, 1, 1); - auto r3 = GenericOperator("relu", 1, 1, 1); - auto r4 = GenericOperator("relu", 1, 1, 0); + auto r2 = GenericOperator("relu", 1, 0, 1); + auto r3 = GenericOperator("relu", 1, 0, 1); + auto r4 = GenericOperator("relu", 1, 0, 0); r1->addChild(r2); r2->addChild(r3); r3->addChild(r4); @@ -337,20 +337,20 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { SECTION("replace for tiling") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph"); auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); - auto other1 = GenericOperator("Other", 1, 1, 1, "other1"); - auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv"); - auto other2 = GenericOperator("Other", 1, 1, 1, "other2"); + auto other1 = GenericOperator("Other", 1, 0, 1, "other1"); + auto myConv = GenericOperator("Conv", 1, 0, 1, "myConv"); + auto other2 = GenericOperator("Other", 1, 0, 1, "other2"); otherInput->addChild(other1); other1->addChild(myConv); myConv->addChild(other2); g->add({other1, myConv, other2}); // create tiled Conv - auto conv1 = GenericOperator("Conv", 1, 1, 1, "myConv1"); - auto conv2 = GenericOperator("Conv", 1, 1, 1, "myConv2"); - auto conv3 = GenericOperator("Conv", 1, 1, 1, "myConv3"); - auto conv4 = GenericOperator("Conv", 1, 1, 1, "myConv4"); - auto concat = GenericOperator("Concat", 4, 4, 1, "myConcat"); + auto conv1 = GenericOperator("Conv", 1, 0, 1, "myConv1"); + auto conv2 = GenericOperator("Conv", 1, 0, 1, "myConv2"); + auto conv3 = GenericOperator("Conv", 1, 0, 1, "myConv3"); + auto conv4 = GenericOperator("Conv", 1, 0, 1, "myConv4"); + auto concat = GenericOperator("Concat", 4, 0, 1, "myConcat"); conv1->addChild(concat); conv2->addChild(concat); conv3->addChild(concat); @@ -368,12 +368,12 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { SECTION("Change every Nodes in a GraphView") { auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0"); auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0"); - auto matmul0 = GenericOperator("MatMul", 1, 2, 1, "matmul0"); - auto add0 = GenericOperator("Add", 1, 2, 1, "add0"); + auto matmul0 = GenericOperator("MatMul", 1, 1, 1, "matmul0"); + auto add0 = GenericOperator("Add", 1, 1, 1, "add0"); auto matmulWeight1 = GenericOperator("Producer", 0, 0, 1, "matmul_w1"); auto addBias1 = GenericOperator("Producer", 0, 0, 1, "add_b1"); - auto matmul1 = GenericOperator("MatMul", 1, 2, 1, "matmul1"); - auto add1 = GenericOperator("Add", 1, 2, 1, "add1"); + auto matmul1 = GenericOperator("MatMul", 1, 1, 1, "matmul1"); + auto add1 = GenericOperator("Add", 1, 1, 1, "add1"); matmulWeight0 -> addChild(matmul0, 0, 1); addBias0 -> addChild(add0, 0, 1); @@ -389,8 +389,8 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { auto newAddBias0 = addBias0->cloneSharedOperators(); auto newMatmulWeight1 = matmulWeight1->cloneSharedOperators(); auto newAddBias1 = addBias1->cloneSharedOperators(); - auto fc0 = GenericOperator("FC", 1, 3, 1, "fc0"); - auto fc1 = GenericOperator("FC", 1, 3, 1, "fc1"); + auto fc0 = GenericOperator("FC", 1, 2, 1, "fc0"); + auto fc1 = GenericOperator("FC", 1, 2, 1, "fc1"); newMatmulWeight0 -> addChild(fc0, 0, 1); newAddBias0 -> addChild(fc0, 0, 2); @@ -417,15 +417,15 @@ TEST_CASE("[GraphView] clone") { g1->save("clone_g1"); SECTION("Check input-output connections") { - REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); - REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); - REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); - REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); - REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); - REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); - REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); - REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); - REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0)); + REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0)); + REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0)); + REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0)); + REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0)); } auto g2 = g1->clone(); @@ -461,27 +461,27 @@ TEST_CASE("[GraphView] clone") { } SECTION("Check new connections") { - REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0)); - REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) != g2->getNode("conv1_w")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) != g2->getNode("conv1_b")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != g2->getNode("conv2")->getOperator()->getInput(0)); - REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) != g2->getNode("conv2_w")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) != g2->getNode("conv2_b")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0)); - REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) != g2->getNode("conv3_w")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) != g2->getNode("conv3_b")->getOperator()->getOutput(0)); + REQUIRE(dataProvider->getOperator()->getRawOutput(0) != g2->getNode("conv1")->getOperator()->getRawInput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(1) != g2->getNode("conv1_w")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(2) != g2->getNode("conv1_b")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getRawOutput(0) != g2->getNode("conv2")->getOperator()->getRawInput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(1) != g2->getNode("conv2_w")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(2) != g2->getNode("conv2_b")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getRawOutput(0) != g2->getNode("conv3")->getOperator()->getRawInput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(1) != g2->getNode("conv3_w")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(2) != g2->getNode("conv3_b")->getOperator()->getRawOutput(0)); } SECTION("Check input-output connections") { - REQUIRE(dataProvider2->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + REQUIRE(dataProvider2->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0)); } } @@ -498,15 +498,15 @@ TEST_CASE("[GraphView] cloneSharedProducers") { g1->save("cloneSharedProducers_g1"); SECTION("Check input-output connections") { - REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); - REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); - REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); - REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); - REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); - REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); - REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); - REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); - REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0)); + REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0)); + REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0)); + REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0)); + REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0)); } auto g2 = g1->cloneSharedProducers(); @@ -542,27 +542,27 @@ TEST_CASE("[GraphView] cloneSharedProducers") { } SECTION("Check new connections") { - REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0)); - REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != g2->getNode("conv2")->getOperator()->getInput(0)); - REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0)); - REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); - REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + REQUIRE(dataProvider->getOperator()->getRawOutput(0) != g2->getNode("conv1")->getOperator()->getRawInput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getRawOutput(0) != g2->getNode("conv2")->getOperator()->getRawInput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getRawOutput(0) != g2->getNode("conv3")->getOperator()->getRawInput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0)); } SECTION("Check input-output connections") { - REQUIRE(dataProvider2->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + REQUIRE(dataProvider2->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0)); } } @@ -579,15 +579,15 @@ TEST_CASE("[GraphView] cloneSharedOperators") { g1->save("cloneSharedOperators_g1"); SECTION("Check input-output connections") { - REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); - REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); - REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); - REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); - REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); - REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); - REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); - REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); - REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0)); + REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0)); + REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0)); + REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0)); + REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0)); + REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0)); } auto g2 = g1->cloneSharedOperators(); @@ -619,15 +619,15 @@ TEST_CASE("[GraphView] cloneSharedOperators") { } SECTION("Check input-output connections") { - REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); - REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + REQUIRE(dataProvider->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0)); } } @@ -653,10 +653,10 @@ TEST_CASE("[core/graph] GraphView(insertParent)") { std::set<NodePtr> expectedConv1Children = {conv3, newConv}; std::set<NodePtr> expectedNewConvChildren = {conv2}; - REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); - REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0)); - REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0)); - REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) == newConv->getOperator()->getRawInput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) != conv2->getOperator()->getRawInput(0)); + REQUIRE(newConv->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0)); REQUIRE((newConv->getChildren()) == expectedNewConvChildren); REQUIRE((conv1->getChildren()) == expectedConv1Children); @@ -665,11 +665,11 @@ TEST_CASE("[core/graph] GraphView(insertParent)") { std::set<NodePtr> expectedConv1Children2 = {newConv}; std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3}; - REQUIRE(conv1->getOperator()->getOutput(0) != conv3->getOperator()->getInput(0)); - REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0)); - REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0)); - REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); - REQUIRE(newConv->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) != conv3->getOperator()->getRawInput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) == newConv->getOperator()->getRawInput(0)); + REQUIRE(conv1->getOperator()->getRawOutput(0) != conv2->getOperator()->getRawInput(0)); + REQUIRE(newConv->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0)); + REQUIRE(newConv->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0)); REQUIRE((newConv->getChildren()) == expectedNewConvChildren2); REQUIRE((conv1->getChildren()) == expectedConv1Children2); diff --git a/unit_tests/graph/Test_get.cpp b/unit_tests/graph/Test_get.cpp index afd1f42ee..7b396f22b 100644 --- a/unit_tests/graph/Test_get.cpp +++ b/unit_tests/graph/Test_get.cpp @@ -23,15 +23,15 @@ using namespace Aidge; TEST_CASE("get Delta") { - + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); - std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); - std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); - std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5"); - std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); - std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3"); + std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 0, 1, "c3.5"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 0, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 0, 1, "c5"); g1->add(conv); g1->addChild(conv1, "c"); diff --git a/unit_tests/graphRegex/Test_FsmMatch.cpp b/unit_tests/graphRegex/Test_FsmMatch.cpp index 4b0a009a4..008251fea 100644 --- a/unit_tests/graphRegex/Test_FsmMatch.cpp +++ b/unit_tests/graphRegex/Test_FsmMatch.cpp @@ -34,8 +34,8 @@ TEST_CASE("FsmMatch") { std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); g1->add(conv); g1->addChild(conv1, "c"); @@ -55,9 +55,9 @@ TEST_CASE("FsmMatch") { SECTION("2 branche graph"){ std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); - std::shared_ptr<Node> conv2 = GenericOperator("Fc", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Fc", 1, 0, 1, "c2"); g1->add(conv); g1->addChild(conv1,conv); diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index b30560ea3..ad8f257e4 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -16,10 +16,10 @@ TEST_CASE("GraphRegexUser") { std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> fc = GenericOperator("FC", 1, 1, 1, "c1"); - std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); - std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> fc = GenericOperator("FC", 1, 0, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); + std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 0, 1, "c3"); g1->add(conv); g1->addChild(fc, "c"); @@ -29,7 +29,7 @@ TEST_CASE("GraphRegexUser") { sut->setKeyFromGraph(g1); sut->addQuery(query); - + for (const auto& solution : sut->match(g1)) { REQUIRE(solution->getQuery() == query); @@ -43,21 +43,21 @@ TEST_CASE("GraphRegexUser") { } } //REQUIRE( sut->match(g1)[1]->getAll() == std::set<NodePtr>{conv,fc}); - + } SECTION("CC") { - + std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); - std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); - std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3"); g1->add(conv); g1->addChild(conv1, "c"); @@ -79,6 +79,6 @@ TEST_CASE("GraphRegexUser") { for (const auto& solution : sut->match(g1)) { REQUIRE(solution->getQuery() == query); } - + } } \ No newline at end of file diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 79f6979c9..ef0c4e7f7 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -36,15 +36,16 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(); myInput->resize({2,3,5,5}); - op->getOperator()->associateInput(0,myInput); - op->getOperator()->computeOutputDims(); + std::shared_ptr<OperatorTensor> opTensor = std::static_pointer_cast<OperatorTensor>(op->getOperator()); + opTensor->associateInput(0,myInput); + opTensor->computeOutputDims(); - REQUIRE(op->getOperator()->outputDimsForwarded()); - REQUIRE(op->getOperator()->getOutput(0)->dims() == std::vector<size_t>({2,3,5,5})); - REQUIRE(op->getOperator()->getInput(0) == myInput); + REQUIRE(opTensor->outputDimsForwarded()); + REQUIRE(std::static_pointer_cast<Tensor>(opTensor->getRawOutput(0))->dims() == std::vector<size_t>({2,3,5,5})); + REQUIRE(std::static_pointer_cast<Tensor>(opTensor->getRawInput(0)) == myInput); // Order not garanteed by the GraphView - //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->getInput(0) == myInput); - REQUIRE(op->getOperator()->getOutput(0) == (*microGraph->outputNodes().begin())->getOperator()->getOutput(0)); + //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->getRawInput(0) == myInput); + REQUIRE(opTensor->getRawOutput(0) == (*microGraph->outputNodes().begin())->getOperator()->getRawOutput(0)); //op->getOperator()->updateConsummerProducer(); // require implementation //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler(); diff --git a/unit_tests/recipies/Test_FuseBatchNorm.cpp b/unit_tests/recipies/Test_FuseBatchNorm.cpp index 13facefd2..5d9c02d55 100644 --- a/unit_tests/recipies/Test_FuseBatchNorm.cpp +++ b/unit_tests/recipies/Test_FuseBatchNorm.cpp @@ -42,7 +42,7 @@ namespace Aidge { BatchNorm<2>() }); - g1->setDatatype(DataType::Float32); + g1->setDataType(DataType::Float32); g1->setBackend("cpu"); g1->forwardDims(); @@ -59,12 +59,12 @@ namespace Aidge { SECTION("Check resulting nodes") { // REQUIRE(g1->getNodes().size() == 2); // REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); - // REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + // REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); // REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); - // REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + // REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); // REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); } } - + } */ \ No newline at end of file diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index 39a4f7949..0c65db989 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -25,9 +25,9 @@ namespace Aidge { TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // generate the original GraphView - auto matmul0 = MatMul(5, "matmul0"); + auto matmul0 = MatMul(5, 5, "matmul0"); auto add0 = Add(2, "add0"); - auto matmul1 = MatMul(5, "matmul1"); + auto matmul1 = MatMul(5, 5, "matmul1"); auto add1 = Add(2, "add1"); auto b0 = Producer({5}, "B0"); diff --git a/unit_tests/recipies/Test_LabelGraph.cpp b/unit_tests/recipies/Test_LabelGraph.cpp index 873ad68f3..e0ba9be6c 100644 --- a/unit_tests/recipies/Test_LabelGraph.cpp +++ b/unit_tests/recipies/Test_LabelGraph.cpp @@ -45,9 +45,9 @@ TEST_CASE("[LabelGraph] conv") { SECTION("Check resulting nodes") { REQUIRE(g2->getNodes().size() == 3); REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); - REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); - REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); } } @@ -56,7 +56,7 @@ TEST_CASE("[LabelGraph] deleted node") { auto g1 = Sequential({ Producer({16, 3, 224, 224}, "dataProvider"), Conv(3, 32, {3, 3}, "conv1"), - GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 0, 1), Conv(32, 64, {3, 3}, "conv2"), Conv(64, 10, {1, 1}, "conv3", {2, 2}) }); @@ -74,16 +74,16 @@ TEST_CASE("[LabelGraph] deleted node") { SECTION("Check resulting nodes") { REQUIRE(g2->getNodes().size() == 3); REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); - REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); - REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); } SECTION("Check dimensions") { - REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); - REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 220, 220})); - REQUIRE(g2->getNode("conv3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 110, 110})); + REQUIRE(std::static_pointer_cast<Tensor>(g2->getNode("conv1")->getOperator()->getRawOutput(0))->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(std::static_pointer_cast<Tensor>(g2->getNode("conv2")->getOperator()->getRawOutput(0))->dims() == std::vector<DimSize_t>({16, 1, 220, 220})); + REQUIRE(std::static_pointer_cast<Tensor>(g2->getNode("conv3")->getOperator()->getRawOutput(0))->dims() == std::vector<DimSize_t>({16, 1, 110, 110})); } } @@ -91,11 +91,11 @@ TEST_CASE("[LabelGraph] deleted nodes") { auto g1 = Sequential({ Producer({16, 3, 224, 224}, "dataProvider"), Conv(3, 32, {3, 3}, "conv1"), - GenericOperator("Dummy_to_be_removed", 1, 1, 1), - GenericOperator("Dummy_to_be_removed", 1, 1, 1), - GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 0, 1), + GenericOperator("Dummy_to_be_removed", 1, 0, 1), + GenericOperator("Dummy_to_be_removed", 1, 0, 1), Conv(32, 64, {3, 3}, "conv2"), - GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 0, 1), Conv(64, 10, {1, 1}, "conv3") }); @@ -112,9 +112,9 @@ TEST_CASE("[LabelGraph] deleted nodes") { SECTION("Check resulting nodes") { REQUIRE(g2->getNodes().size() == 3); REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); - REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); - REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); } } @@ -140,15 +140,15 @@ TEST_CASE("[LabelGraph] pooling") { SECTION("Check resulting nodes") { REQUIRE(g2->getNodes().size() == 3); REQUIRE(g2->getNode("pool1")->getOperator()->type() == "MaxPooling"); - REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0) == g2->getNode("pool2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool1")->getOperator()->getRawOutput(0) == g2->getNode("pool2")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("pool2")->getOperator()->type() == "MaxPooling"); - REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0) == g2->getNode("pool3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool2")->getOperator()->getRawOutput(0) == g2->getNode("pool3")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("pool3")->getOperator()->type() == "MaxPooling"); } SECTION("Check dimensions") { - REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 223, 223})); - REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); - REQUIRE(g2->getNode("pool3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 111, 111})); + REQUIRE(std::static_pointer_cast<Tensor>(g2->getNode("pool1")->getOperator()->getRawOutput(0))->dims() == std::vector<DimSize_t>({16, 1, 223, 223})); + REQUIRE(std::static_pointer_cast<Tensor>(g2->getNode("pool2")->getOperator()->getRawOutput(0))->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(std::static_pointer_cast<Tensor>(g2->getNode("pool3")->getOperator()->getRawOutput(0))->dims() == std::vector<DimSize_t>({16, 1, 111, 111})); } } -- GitLab