diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 5a7b05e469daab10a4abd468177a3ad137096f63..6f0cc55159b1cc72b87bb34230376eb140b7ab8a 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -11,22 +11,25 @@ #include "aidge/graph/Node.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/operator/Producer.hpp" #include <memory> #include <vector> + +#include "aidge/graph/GraphView.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" #include "aidge/utils/Types.h" Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) : mName(name), mOperator(op), - mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), - mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), - std::vector<std::weak_ptr<Node>>())), - mIdInChildren( - std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), - mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { + mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), + nullptr)), + mChildren(std::vector<std::vector<std::weak_ptr<Node>>>( + static_cast<std::size_t>(op->nbOutputs()), std::vector<std::weak_ptr<Node>>())), + mIdInChildren(std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), + std::vector<IOIndex_t>())), + mIdOutParents( + std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { // ctor } @@ -34,14 +37,15 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// -Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { +Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) { assert((ctors.size() == nbData()) && "Wrong number of arguments.\n"); - for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) { - assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); - (void) input; // avoid unused warning + for (std::pair<std::shared_ptr<Node>, IOIndex_t>& input : inputs()) { + assert((gk_IODefaultIndex == input.second) && + "At least one input connection is not free.\n"); + (void)input; // avoid unused warning } IOIndex_t i = 0; - for (const Connector &ctor : ctors) { + for (const Connector& ctor : ctors) { if (ctor.node() != nullptr) { // ctor must be associated with a node ctor.node()->addChild(shared_from_this(), ctor.index(), i++); } @@ -53,7 +57,7 @@ Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { // INNER /////////////////////////////////////////////////////// -void Aidge::Node::setName(const std::string &name) { mName = name; } +void Aidge::Node::setName(const std::string& name) { mName = name; } /////////////////////////////////////////////////////// // OPERATORS @@ -92,8 +96,8 @@ Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const { return nbFreeDataIn; } -std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::dataInputs() const { +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs() + const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData()); for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) { @@ -104,15 +108,15 @@ Aidge::Node::dataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); for (std::size_t i = 0; i < nbInputs(); ++i) { - res[i] = - std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); + res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); } return res; } -// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) { +// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> +// tensor) { // assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); // if (mParents[idx] != nullptr) { // mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]); @@ -128,20 +132,21 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::Node::outputs() const { std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs = - std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(mIdInChildren.size()); + std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>( + mIdInChildren.size()); for (std::size_t i = 0; i < mIdInChildren.size(); ++i) { listOutputs[i] = output(static_cast<IOIndex_t>(i)); } return listOutputs; } -std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::output(Aidge::IOIndex_t outId) const { +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::output( + Aidge::IOIndex_t outId) const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size()); for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) { - listOutputs[i] = - std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]); + listOutputs[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), + mIdInChildren[outId][i]); } return listOutputs; } @@ -180,7 +185,8 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) // TOPOLOGY /////////////////////////////////////////////////////// -void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) { +void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, + const IOIndex_t otherInId) { assert((otherInId < otherNode->nbInputs()) && "Input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); if (otherNode->input(otherInId).second != gk_IODefaultIndex) { @@ -196,33 +202,41 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou } void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId, - std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { - assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound."); + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + assert((otherInId.second < otherInId.first->nbInputs()) && + "Other graph input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes(); if (inNodes.size() == std::size_t(0)) { // no input Node printf("Cannot add GraphView to the Node. No input node detected.\n"); } else // inNodes.size() >= 1 { - assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node + assert((inNodes.find(otherInId.first) != + inNodes.end())); // assert it really is an input node addChildOp(otherInId.first, outId, otherInId.second); } } -void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) { - otherInId = (otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput(); - addChildOp(otherNode, outId, otherInId); +void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, + IOIndex_t otherInId) { + if (otherNode) { + otherInId = + (otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput(); + addChildOp(otherNode, outId, otherInId); + } } void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId, - std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { if (!otherInId.first) { assert((otherView->inputNodes().size() == 1U) && "Specify an input Node for the GraphView. More or less than one " "Node is not explicit."); otherInId.first = *(otherView->inputNodes().begin()); } - otherInId.second = (otherInId.second != gk_IODefaultIndex) ? otherInId.second : otherInId.first->getFirstFreeDataInput(); + otherInId.second = (otherInId.second != gk_IODefaultIndex) + ? otherInId.second + : otherInId.first->getFirstFreeDataInput(); addChildView(otherView, outId, otherInId); } @@ -255,8 +269,8 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) { std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { std::set<std::shared_ptr<Node>> children; - for (const auto &childrenOfOneOutput : mChildren) { - for (const auto &oneChild : childrenOfOneOutput) { + for (const auto& childrenOfOneOutput : mChildren) { + for (const auto& oneChild : childrenOfOneOutput) { children.insert(oneChild.lock()); } } @@ -264,7 +278,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { } std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { - std::vector<std::vector<std::shared_ptr<Node>>> children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); + std::vector<std::vector<std::shared_ptr<Node>>> children = + std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); for (std::size_t outId = 0; outId < mChildren.size(); ++outId) { children[outId] = getChildren(outId); } @@ -273,14 +288,16 @@ std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedCh std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const { assert((outId < nbOutputs()) && "Output index out of bound."); - std::vector<std::shared_ptr<Node>> children = std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); + std::vector<std::shared_ptr<Node>> children = + std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); for (std::size_t i = 0; i < mChildren[outId].size(); ++i) { - children.push_back(mChildren[outId][i].lock()); - } + children.push_back(mChildren[outId][i].lock()); + } return children; } -bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) { +bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, + const Aidge::IOIndex_t outId) { assert((outId < nbOutputs()) && "Child index out of bound."); bool removed = false; for (std::size_t j = 0; j < mChildren[outId].size(); ++j) { @@ -301,7 +318,8 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i); if (parent.first) { // number of children linked to the parent's output - while(parent.first->removeChild(shared_from_this(), parent.second) == true) {} + while (parent.first->removeChild(shared_from_this(), parent.second) == true) { + } } // every reference to this object as child has been removed // removing reference to parents. @@ -316,24 +334,23 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { mIdInChildren[i] = std::vector<IOIndex_t>(); } // removing this Node from every GraphView it belongs to - for (auto& graph : views()) { - // if keeping connections with LEarnable Parameters, then also remove them from graph - graph->remove(shared_from_this(), !includeLearnableParam); - } + // for (auto& graph : views()) { + // // if keeping connections with LEarnable Parameters, then also remove them from graph + // graph->remove(shared_from_this(), !includeLearnableParam); + // } } - /////////////////////////////////////////////////////// - // CLONE - /////////////////////////////////////////////////////// +/////////////////////////////////////////////////////// +// CLONE +/////////////////////////////////////////////////////// Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { return std::make_shared<Node>(mOperator, mName); } Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { - std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type) - ? mOperator - : mOperator->clone(); + std::shared_ptr<Operator> op = + (mOperator->type() == Producer_Op::Type) ? mOperator : mOperator->clone(); return std::make_shared<Node>(op, mName); } @@ -342,27 +359,25 @@ Aidge::NodePtr Aidge::Node::clone() const { return std::make_shared<Node>(mOperator->clone(), mName); } - -std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){ - +std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::NodePtr> nodeSee) { std::set<Aidge::NodePtr> out; nodeSee.insert(shared_from_this()); - if(delta == 0) { + if (delta == 0) { out.insert(shared_from_this()); - }else if (delta > 0){ - for (const NodePtr& node : getChildren()) { - if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance - for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ + } else if (delta > 0) { + for (const NodePtr& node : getChildren()) { + if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance + for (const NodePtr& ch : node->getNodeDelta(delta - 1, nodeSee)) { out.insert(ch); } } } - }else{ - for (const NodePtr& node : getParents()) { - if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance - for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ + } else { + for (const NodePtr& node : getParents()) { + if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance + for (const NodePtr& pr : node->getNodeDelta(delta + 1, nodeSee)) { out.insert(pr); } }