/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "graph/Node.hpp"

#include "graph/GraphView.hpp"
#include "operator/Producer.hpp"
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "utils/Types.h"

/**
 * @brief Build a Node wrapping an Operator.
 *
 * The parent-side containers are sized from op->nbInputs() and the child-side
 * containers from op->nbOutputs(). Every input starts unconnected
 * (null parent, gk_IODefaultIndex as parent output index).
 *
 * @param op   Operator held by the Node. It is dereferenced here, so it must
 *             not be null.
 * @param name Optional C-string name; nullptr yields an empty name.
 */
Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name)
    : mName((name == nullptr) ? std::string() : std::string(name)),
      mOperator(op),
      mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)),
      mChildren(std::vector<std::vector<std::shared_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()),
                                                                std::vector<std::shared_ptr<Node>>())),
      mIdInChildren(
              std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())),
      mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) {
    // ctor
}

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

/**
 * @brief Connect this Node to the Nodes referenced by a list of Connectors.
 *
 * Asserts that there is one Connector per data input and that every input of
 * this Node is still free. Each Connector that references a Node is then linked
 * to the next input index of this Node.
 *
 * @param ctors Connectors to plug into this Node.
 * @return A Connector built from this Node.
 */
Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> ctors) {
    assert((ctors.size() == nbDataInputs()) && "Wrong number of arguments.\n");
    for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) {
        assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
    }
    // The input index only advances when a Connector actually carries a Node,
    // so Connectors without a Node do not consume an input slot.
    IOIndex_t i = 0;
    for (const Connector &ctor : ctors) {
        if (ctor.node() != nullptr) {  // ctor must be associated with a node
            ctor.node()->addChild(shared_from_this(), ctor.index(), i++);
        }
    }
    return Connector(shared_from_this());
}

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////

/// @brief Replace the name of this Node.
void Aidge::Node::setName(const std::string &name) { mName = name; }

///////////////////////////////////////////////////////
//        OPERATORS
///////////////////////////////////////////////////////

/// @brief Run forward() on the held Operator (asserts that one is set).
void Aidge::Node::forward() {
    assert((mOperator != nullptr) && "No Operator interface provided, can't run forward().\n");
    mOperator->forward();
}

/// @brief Run backward() on the held Operator (asserts that one is set).
void Aidge::Node::backward() {
    assert((mOperator != nullptr) && "No Operator interface provided, can't run backward().\n");
    mOperator->backward();
}

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////

/**
 * @brief Tell whether every input of this Node is connected.
 * @return false as soon as one input still holds gk_IODefaultIndex, true otherwise.
 */
bool Aidge::Node::valid() const {
    for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbInputs(); ++i) {
        if (mIdOutParents[static_cast<std::size_t>(i)] == gk_IODefaultIndex) {
            return false;
        }
    }
    return true;
}

/**
 * @brief Count the inputs whose parent output index is negative (not connected).
 *
 * NOTE(review): despite the name, the loop covers nbInputs() rather than
 * nbDataInputs() - confirm this is intended before relying on it.
 */
Aidge::IONb_t Aidge::Node::getNbFreeDataInputs() const {
    IONb_t nbFreeDataIn = 0;
    for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbInputs(); ++i) {
        if (input(i).second < 0) {
            ++nbFreeDataIn;
        }
    }
    return nbFreeDataIn;
}

/**
 * @brief List the (parent Node, parent output index) pairs of the first
 *        nbDataInputs() inputs.
 * @return One pair per data input, in input order. Unconnected inputs hold the
 *         stored null parent / gk_IODefaultIndex values.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs() const {
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
            std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataInputs());
    for (std::size_t i = 0; i < static_cast<std::size_t>(nbDataInputs()); ++i) {
        res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]);
    }
    return res;
}

/**
 * @brief List the (parent Node, parent output index) pairs of every input.
 * @return One pair per input, in input order.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const {
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
            std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs());
    for (std::size_t i = 0; i < nbInputs(); ++i) {
        res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]);
    }
    return res;
}

/**
 * @brief Feed input idx from a new Producer Node built around a Tensor.
 *
 * Any existing parent on that input is disconnected first. The new Producer
 * Node is then linked (its output 0 to input idx) and added to every GraphView
 * returned by views().
 *
 * @param idx    Input index; asserted to be in [0, nbInputs()).
 * @param tensor Tensor handed to Producer().
 */
void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) {
    assert((idx != gk_IODefaultIndex) && (static_cast<IONb_t>(idx) < nbInputs()) && "Parent index out of bound.");
    if (mParents[idx] != nullptr) {
        mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]);
        removeParent(idx);
    }
    std::shared_ptr<Node> newConstantNode = Producer(tensor);
    newConstantNode->addChild(shared_from_this(), 0, idx);
    for (auto &graphPtr : views()) {
        graphPtr->add(newConstantNode);
    }
}

/**
 * @brief List, for every output, the (child Node, child input index) pairs.
 * @return One inner vector per output, in output order.
 */
std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::Node::outputs() const {
    std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs =
            std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(mIdInChildren.size());
    for (std::size_t i = 0; i < mIdInChildren.size(); ++i) {
        listOutputs[i] = output(static_cast<IOIndex_t>(i));
    }
    return listOutputs;
}

/**
 * @brief List the (child Node, child input index) pairs attached to one output.
 * @param outID Output index; asserted to be a valid index into the children lists.
 * @return One pair per child connection, in insertion order.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::Node::output(Aidge::IOIndex_t outID) const {
    // A negative index converts to a huge std::size_t, so this single comparison
    // rejects both negative and too-large values before the containers are indexed.
    assert((static_cast<std::size_t>(outID) < mIdInChildren.size()) && "Output index out of bound.");
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs =
            std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outID].size());
    for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) {
        listOutputs[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outID][i], mIdInChildren[outID][i]);
    }
    return listOutputs;
}

/**
 * @brief Count the inputs that are connected to a parent output.
 *
 * An input is considered connected when its stored parent output index differs
 * from gk_IODefaultIndex, which is the same criterion valid() uses.
 * (The previous implementation counted the unconnected inputs instead.)
 */
Aidge::IONb_t Aidge::Node::nbValidInputs() const {
    IONb_t counter = 0;
    for (IONb_t i = 0; i < nbInputs(); ++i) {
        if (mIdOutParents[static_cast<std::size_t>(i)] != gk_IODefaultIndex) ++counter;
    }
    return counter;
}

/// @brief Count the outputs that have at least one child attached.
Aidge::IONb_t Aidge::Node::nbValidOutputs() const {
    IONb_t counter = 0;
    if (mIdInChildren.size() == 0) return 0;
    for (std::size_t i = 0; i < nbOutputs(); ++i) {
        if (mIdInChildren[i].size() > 0U) counter++;
    }
    return counter;
}

/**
 * @brief Record which parent output feeds input inId.
 *
 * If the input was already attributed, a warning is printed and this Node is
 * removed from the children of the previous parent's output.
 *
 * @param inId         Input index; asserted to be in [0, nbInputs()).
 * @param newNodeOutID Output index of the new parent to store.
 */
void Aidge::Node::setInputId(IOIndex_t inId, IOIndex_t newNodeOutID) {
    assert(inId != gk_IODefaultIndex && (static_cast<IONb_t>(inId) < nbInputs()) && "Must be a valid index");
    if (mIdOutParents[inId] != gk_IODefaultIndex) {
        std::printf("Warning: filling a Tensor already attributed\n");
        auto originalParent = input(inId);
        // remove original parent reference to child
        // find the output ID for original Parent
        // find first occurence of child in the output's children
        // Guard against an index recorded without a matching parent pointer,
        // which would otherwise be a null dereference.
        if (originalParent.first != nullptr) {
            originalParent.first->removeChild(shared_from_this(), originalParent.second);
        }
    }
    mIdOutParents[inId] = newNodeOutID;
}

///////////////////////////////////////////////////////
//        TOPOLOGY
///////////////////////////////////////////////////////

/**
 * @brief Link output outId of this Node to input otherInId of another Node.
 *
 * Updates the child's stored parent output index, associates the child's
 * operator input with this operator's raw output, registers the child on this
 * Node, and registers this Node as the child's parent.
 *
 * @param otherNode Child Node.
 * @param outId     Output index of this Node; asserted in [0, nbOutputs()).
 * @param otherInId Input index of the child; asserted in [0, otherNode->nbInputs()).
 */
void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) {
    assert((otherInId != gk_IODefaultIndex) && (static_cast<IONb_t>(otherInId) < otherNode->nbInputs()) &&
           "Input index out of bound.");
    assert((outId != gk_IODefaultIndex) && (static_cast<IONb_t>(outId) < nbOutputs()) &&
           "Output index out of bound.");
    if (otherNode->input(otherInId).second >= 0) {
        // Explicit cast keeps the argument in line with the %d conversion
        // whatever the underlying width of IOIndex_t is.
        std::printf("Warning, the %d-th Parent of the child node already existed.\n", static_cast<int>(otherInId));
    }
    // manage tensors and potential previous parent
    otherNode->setInputId(otherInId, outId);
    otherNode->getOperator()->associateInput(otherInId, getOperator()->getRawOutput(outId));
    // manage nodes
    mChildren[outId].push_back(otherNode);
    mIdInChildren[outId].push_back(otherInId);
    otherNode->addParent(shared_from_this(), otherInId);
}

/**
 * @brief Link output outID of this Node to an input Node of a GraphView.
 *
 * Prints a message and does nothing when the GraphView reports no input Node.
 * Otherwise asserts that otherInId.first is one of its input Nodes and
 * delegates to addChildOp().
 *
 * @param other_graph GraphView providing the input Nodes.
 * @param outID       Output index of this Node; asserted in [0, nbOutputs()).
 * @param otherInId   (input Node, input index) pair inside the GraphView.
 */
void Aidge::Node::addChildView(std::shared_ptr<GraphView> other_graph, const IOIndex_t outID,
                               std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
    assert((otherInId.second != gk_IODefaultIndex) &&
           (static_cast<IONb_t>(otherInId.second) < otherInId.first->nbInputs()) &&
           "Other graph input index out of bound.");
    assert((outID != gk_IODefaultIndex) && (static_cast<IONb_t>(outID) < nbOutputs()) &&
           "Output index out of bound.");
    std::set<std::shared_ptr<Node>> inNodes = other_graph->inputNodes();
    if (inNodes.size() == std::size_t(0)) {  // no input Node
        std::printf("Cannot add GraphView to the Node. No input node detected.\n");
    } else  // inNodes.size() >= 1
    {
        assert((inNodes.find(otherInId.first) != inNodes.end()));  // assert it really is an input node
        addChildOp(otherInId.first, outID, otherInId.second);
    }
}

/**
 * @brief Add a child Node on output outId.
 * @param otherNode Child Node.
 * @param outId     Output index of this Node.
 * @param otherInId Input index of the child; a negative value selects
 *                  otherNode->getFirstFreeDataInput().
 */
void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) {
    otherInId = (otherInId >= 0) ? otherInId : otherNode->getFirstFreeDataInput();
    addChildOp(otherNode, outId, otherInId);
}

/**
 * @brief Add a GraphView as child on output outId.
 *
 * When no input Node is given, the GraphView must expose exactly one input
 * Node (asserted), which is then used. A negative input index selects that
 * Node's getFirstFreeDataInput().
 *
 * @param otherView GraphView to connect.
 * @param outId     Output index of this Node.
 * @param otherInId (input Node, input index) pair; either part may be left unset.
 */
void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId,
                           std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
    if (!otherInId.first) {
        assert((otherView->inputNodes().size() == 1U) &&
               "Specify an input Node for the GraphView. More or less than one "
               "Node is not explicit.");
        otherInId.first = *(otherView->inputNodes().begin());
    }
    otherInId.second = (otherInId.second >= 0) ? otherInId.second : otherInId.first->getFirstFreeDataInput();
    addChildView(otherView, outId, otherInId);
}

/**
 * @brief Store other_node as the parent of input inId.
 *
 * Prints a warning when a parent is already registered on that input.
 *
 * @param other_node Parent Node to store.
 * @param inId       Input index; asserted in [0, nbInputs()).
 */
void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) {
    // Validate the index before it is used to look up the current parent.
    assert((inId != gk_IODefaultIndex) && (static_cast<IONb_t>(inId) < nbInputs()) && "Input index out of bound.");
    if (getParents(inId) != nullptr) {
        std::printf("Warning, you're replacing a Parent.\n");
    }
    mParents[inId] = other_node;
}

/// @brief Return a copy of the parent list, indexed by input (null when unconnected).
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getParents() const { return mParents; }

/**
 * @brief Detach and return the parent of input inId.
 * @param inId Input index; asserted in [0, nbInputs()).
 * @return The parent that was stored on that input (may be null).
 */
std::shared_ptr<Aidge::Node> Aidge::Node::popParent(const IOIndex_t inId) {
    assert((inId != gk_IODefaultIndex) && (static_cast<IONb_t>(inId) < nbInputs()) && "Input index out of bound.");
    std::shared_ptr<Node> val = mParents[inId];
    removeParent(inId);
    return val;
}

/**
 * @brief Clear the parent stored on input inId.
 * @param inId Input index; asserted in [0, nbInputs()).
 * @return true if a parent was stored and has been cleared, false otherwise.
 */
bool Aidge::Node::removeParent(const IOIndex_t inId) {
    assert((inId != gk_IODefaultIndex) && (static_cast<IONb_t>(inId) < nbInputs()) && "Parent index out of bound.");
    if (mParents[inId]) {
        mParents[inId] = nullptr;
        mIdOutParents[inId] = gk_IODefaultIndex;
        return true;
    }
    return false;
}

/// @brief Return the set of all children, merged across every output.
std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
    std::set<std::shared_ptr<Node>> children;
    for (const std::vector<std::shared_ptr<Node>> &childrenOfOneOutput : mChildren) {
        children.insert(childrenOfOneOutput.begin(), childrenOfOneOutput.end());
    }
    return children;
}

/// @brief Return a copy of the children lists, one inner vector per output.
std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { return mChildren; }

/**
 * @brief Return a copy of the children attached to one output.
 * @param outID Output index; asserted in [0, nbOutputs()).
 */
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(IOIndex_t outID) const {
    assert((outID != gk_IODefaultIndex) && (static_cast<IONb_t>(outID) < nbOutputs()) &&
           "Output index out of bound.");
    return mChildren[outID];
}

/**
 * @brief Remove the first occurrence of nodePtr among the children of output outId.
 *
 * The matching entry is erased from both the child list and the parallel list
 * of child input indices.
 *
 * @param nodePtr Child to remove.
 * @param outId   Output index; asserted in [0, nbOutputs()).
 * @return true if an entry was removed, false if nodePtr was not found.
 */
bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) {
    assert((outId != gk_IODefaultIndex) && (static_cast<IONb_t>(outId) < nbOutputs()) && "Child index out of bound.");
    bool removed = false;
    for (std::size_t j = 0; j < mChildren[outId].size(); ++j) {
        if (mChildren[outId][j] == nodePtr) {
            mChildren[outId].erase(mChildren[outId].begin() + j);
            mIdInChildren[outId].erase(mIdInChildren[outId].begin() + j);
            removed = true;
            break;
        }
    }
    return removed;
}

/**
 * @brief Disconnect this Node from its parents and children and remove it from
 *        every GraphView returned by views().
 *
 * @param includeLearnableParam When true, all nbInputs() inputs are
 *        disconnected; otherwise only the first nbDataInputs() inputs are. The
 *        negated flag is forwarded as second argument of GraphView::remove().
 */
void Aidge::Node::resetConnections(bool includeLearnableParam) {
    // remove every parents reference to it
    IONb_t nbRemovedInputs = includeLearnableParam ? nbInputs() : nbDataInputs();
    for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbRemovedInputs; ++i) {
        std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i);
        if (parent.first) {
            // removeChild() erases one occurrence per call: loop until this Node
            // no longer appears among the children of the parent's output
            while (parent.first->removeChild(shared_from_this(), parent.second) == true) {
            }
        }
        // every reference to this object as child has been removed
        // removing reference to parents.
        mParents[i] = nullptr;
        mIdOutParents[i] = gk_IODefaultIndex;
    }
    for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbOutputs(); ++i) {
        for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) {
            child.first->removeParent(child.second);
        }
        mChildren[i] = std::vector<std::shared_ptr<Node>>();
        mIdInChildren[i] = std::vector<IOIndex_t>();
    }
    // removing this Node from every GraphView it belongs to
    for (auto &graph : views()) {
        // if keeping connections with Learnable Parameters, then also remove them from graph
        graph->remove(shared_from_this(), !includeLearnableParam);
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////
//        private

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

///////////////////////////////////////////////////////
//        OPERATORS
///////////////////////////////////////////////////////

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////