/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/graph/Node.hpp"

#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Producer.hpp"
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"

/**
 * @brief Build a Node wrapping @p op.
 *
 * Parent/child bookkeeping containers are sized from the operator's number of
 * inputs and outputs; every input starts unconnected (null parent,
 * gk_IODefaultIndex as parent output index).
 * @param op Operator executed by this Node (dereferenced here, so must be non-null).
 * @param name Name given to the Node.
 */
Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name)
    : mName(name),
      mOperator(op),
      mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)),
      mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()),
                                                               std::vector<std::weak_ptr<Node>>())),
      mIdInChildren(
              std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())),
      mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex))
{
    // ctor
}

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

/**
 * @brief Connect the given Connectors to this Node's inputs, in order.
 *
 * ctors[i] is bound to input i of this Node. A Connector that is not
 * associated with any Node leaves the corresponding input unconnected, but
 * still consumes its input slot.
 * @param ctors One Connector per data input (exactly nbData() elements).
 * @return A Connector wrapping this Node.
 */
Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) {
    assert((ctors.size() == nbData()) && "Wrong number of arguments.\n");
    for (const auto &inputPair : inputs()) {
        assert((gk_IODefaultIndex == inputPair.second) && "At least one input connection is not free.\n");
        (void) inputPair;  // avoid unused warning when assert is compiled out
    }
    // The input index must advance for every Connector, including empty ones;
    // otherwise an empty Connector would shift all following ones onto the
    // wrong input.
    IOIndex_t i = 0;
    for (const Connector &ctor : ctors) {
        if (ctor.node() != nullptr) {  // ctor must be associated with a node
            ctor.node()->addChild(shared_from_this(), ctor.index(), i);
        }
        ++i;
    }
    return Connector(shared_from_this());
}

///////////////////////////////////////////////////////
//        INNER
/////////////////////////////////////////////////////// void Aidge::Node::setName(const std::string &name) { mName = name; } /////////////////////////////////////////////////////// // OPERATORS /////////////////////////////////////////////////////// void Aidge::Node::forward() { assert((mOperator != nullptr) && "No Operator interface provided, can't run forward().\n"); mOperator->forward(); } void Aidge::Node::backward() { assert((mOperator != nullptr) && "No Operator interface provided, can't run backward().\n"); mOperator->backward(); } /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// bool Aidge::Node::valid() const { for (IOIndex_t i = 0; i < nbInputs(); ++i) { if (mIdOutParents[static_cast<std::size_t>(i)] == gk_IODefaultIndex) { return false; } } return true; } Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const { IOIndex_t nbFreeDataIn = 0; for (IOIndex_t i = 0; i < nbInputs(); ++i) { if (input(i).second == gk_IODefaultIndex) { ++nbFreeDataIn; } } return nbFreeDataIn; } std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData()); for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) { res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); } return res; } std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); for (std::size_t i = 0; i < nbInputs(); ++i) { res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); } return res; } // void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) { // assert(((idx != 
gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); // if (mParents[idx] != nullptr) { // mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]); // removeParent(idx); // } // std::shared_ptr<Node> newConstantNode = Producer(tensor); // newConstantNode->addChild(shared_from_this(), 0, idx); // for (auto& graphPtr : views()) { // graphPtr->add(newConstantNode); // } // } std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::Node::outputs() const { std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs = std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(mIdInChildren.size()); for (std::size_t i = 0; i < mIdInChildren.size(); ++i) { listOutputs[i] = output(static_cast<IOIndex_t>(i)); } return listOutputs; } std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::output(Aidge::IOIndex_t outId) const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size()); for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) { listOutputs[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]); } return listOutputs; } Aidge::IOIndex_t Aidge::Node::nbValidInputs() const { IOIndex_t counter = 0; for (IOIndex_t i = 0; i < nbInputs(); ++i) { if (mIdOutParents[static_cast<std::size_t>(i)] == gk_IODefaultIndex) ++counter; } return counter; } Aidge::IOIndex_t Aidge::Node::nbValidOutputs() const { IOIndex_t counter = 0; if (mIdInChildren.size() == 0) return 0; for (std::size_t i = 0; i < nbOutputs(); ++i) { if (mIdInChildren[i].size() > 0U) counter++; } return counter; } void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) { assert(inId != gk_IODefaultIndex && (inId < nbInputs()) && "Must be a valid index"); if (mIdOutParents[inId] != gk_IODefaultIndex) { 
std::printf("Warning: filling a Tensor already attributed\n"); auto originalParent = input(inId); // remove original parent reference to child // find the output ID for original Parent // find first occurence of child in the output's children originalParent.first->removeChild(shared_from_this(), originalParent.second); } mIdOutParents[inId] = newNodeoutId; } /////////////////////////////////////////////////////// // TOPOLOGY /////////////////////////////////////////////////////// void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) { assert((otherInId < otherNode->nbInputs()) && "Input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); if (otherNode->input(otherInId).second != gk_IODefaultIndex) { std::printf("Warning, the %d-th Parent of the child node already existed.\n", otherInId); } // manage tensors and potential previous parent otherNode->setInputId(otherInId, outId); otherNode->getOperator()->associateInput(otherInId, getOperator()->getRawOutput(outId)); // manage nodes mChildren[outId].push_back(std::weak_ptr<Node>(otherNode)); mIdInChildren[outId].push_back(otherInId); otherNode->addParent(shared_from_this(), otherInId); } void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId, std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes(); if (inNodes.size() == std::size_t(0)) { // no input Node printf("Cannot add GraphView to the Node. 
No input node detected.\n"); } else // inNodes.size() >= 1 { assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node addChildOp(otherInId.first, outId, otherInId.second); } } void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) { otherInId = (otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput(); addChildOp(otherNode, outId, otherInId); } void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId, std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { if (!otherInId.first) { assert((otherView->inputNodes().size() == 1U) && "Specify an input Node for the GraphView. More or less than one " "Node is not explicit."); otherInId.first = *(otherView->inputNodes().begin()); } otherInId.second = (otherInId.second != gk_IODefaultIndex) ? otherInId.second : otherInId.first->getFirstFreeDataInput(); addChildView(otherView, outId, otherInId); } void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) { if (getParent(inId) != nullptr) { printf("Warning, you're replacing a Parent.\n"); } assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Input index out of bound."); mParents[inId] = other_node; } std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getParents() const { return mParents; } std::shared_ptr<Aidge::Node> Aidge::Node::popParent(const IOIndex_t inId) { assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Input index out of bound."); std::shared_ptr<Node> val = mParents[inId]; removeParent(inId); return val; } bool Aidge::Node::removeParent(const IOIndex_t inId) { assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Parent index out of bound."); if (mParents[inId]) { mParents[inId] = nullptr; mIdOutParents[inId] = gk_IODefaultIndex; return true; } return false; } std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { 
std::set<std::shared_ptr<Node>> children; for (const auto &childrenOfOneOutput : mChildren) { for (const auto &oneChild : childrenOfOneOutput) { children.insert(oneChild.lock()); } } return children; } std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { std::vector<std::vector<std::shared_ptr<Node>>> children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); for (std::size_t outId = 0; outId < mChildren.size(); ++outId) { children[outId] = getChildren(outId); } return children; } std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const { assert((outId < nbOutputs()) && "Output index out of bound."); std::vector<std::shared_ptr<Node>> children = std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); for (std::size_t i = 0; i < mChildren[outId].size(); ++i) { children.push_back(mChildren[outId][i].lock()); } return children; } bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) { assert((outId < nbOutputs()) && "Child index out of bound."); bool removed = false; for (std::size_t j = 0; j < mChildren[outId].size(); ++j) { if (mChildren[outId][j].lock() == nodePtr) { mChildren[outId].erase(mChildren[outId].begin() + j); mIdInChildren[outId].erase(mIdInChildren[outId].begin() + j); removed = true; break; } } return removed; } void Aidge::Node::resetConnections(bool includeLearnableParam) { // remove every parents reference to it IOIndex_t nbRemovedInputs = includeLearnableParam ? nbInputs() : nbData(); for (IOIndex_t i = 0; i < nbRemovedInputs; ++i) { std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i); if (parent.first) { // number of children linked to the parent's output while(parent.first->removeChild(shared_from_this(), parent.second) == true) {} } // every reference to this object as child has been removed // removing reference to parents. 
mParents[i] = nullptr; mIdOutParents[i] = gk_IODefaultIndex; } for (IOIndex_t i = 0; i < nbOutputs(); ++i) { for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) { child.first->removeParent(child.second); } mChildren[i] = std::vector<std::weak_ptr<Node>>(); mIdInChildren[i] = std::vector<IOIndex_t>(); } // removing this Node from every GraphView it belongs to for (auto& graph : views()) { // if keeping connections with LEarnable Parameters, then also remove them from graph graph->remove(shared_from_this(), !includeLearnableParam); } } /////////////////////////////////////////////////////// // CLONE /////////////////////////////////////////////////////// Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { return std::make_shared<Node>(mOperator, mName); } Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type) ? mOperator : mOperator->clone(); return std::make_shared<Node>(op, mName); } Aidge::NodePtr Aidge::Node::clone() const { return std::make_shared<Node>(mOperator->clone(), mName); } std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){ std::set<Aidge::NodePtr> out; nodeSee.insert(shared_from_this()); if(delta == 0) { out.insert(shared_from_this()); }else if (delta > 0){ for (const NodePtr& node : getChildren()) { if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ out.insert(ch); } } } }else{ for (const NodePtr& node : getParents()) { if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ out.insert(pr); } } } } return out; } ///////////////////////////////////////////////////////////////////////////////////////////// // private /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// 
/////////////////////////////////////////////////////// // OPERATORS /////////////////////////////////////////////////////// /////////////////////////////////////////////////////// // TENSOR MANAGEMENT ///////////////////////////////////////////////////////