Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Node.cpp 22.73 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/graph/Node.hpp"

#include <memory>
#include <vector>

#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/future_std/any.hpp"

/**
 * @brief Build a Node around an Operator and a shared attribute set.
 *
 * Parent bookkeeping is sized from op->nbInputs() (parents start as nullptr,
 * parent output indices as gk_IODefaultIndex) and child bookkeeping from
 * op->nbOutputs() (one empty list per output). One forward and one backward
 * callback delegating to the Operator are registered; both return true.
 *
 * @param op    Operator wrapped by the Node. It is dereferenced here, so it must not be null.
 * @param attrs Attributes kept by the Node and passed to op->setInheritedAttrs().
 */
Aidge::Node::Node(std::shared_ptr<Operator> op, std::shared_ptr<DynamicAttributes> attrs)
    : mAttrs(attrs),
      mOperator(op),
      mParents(static_cast<std::size_t>(op->nbInputs()), nullptr),
      mChildren(static_cast<std::size_t>(op->nbOutputs())),
      mIdInChildren(static_cast<std::size_t>(op->nbOutputs())),
      mIdOutParents(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)
{
    // Default callbacks: run the operator, then return true so that
    // Node::forward()/Node::backward() keep them registered.
    mForward.push_back([this]() {
        mOperator->forward();
        return true;
    });
    mBackward.push_back([this]() {
        mOperator->backward();
        return true;
    });
    op->setInheritedAttrs(attrs);
}

// Aidge::Node::Node(std::shared_ptr<Operator> op, const DynamicAttributes& attrs)
//     : Node(op, std::make_shared<DynamicAttributes>(attrs)) {}

/**
 * @brief Build a Node whose attribute set holds a single "name" entry.
 * @param op   Operator forwarded to the (op, attrs) constructor.
 * @param name Value stored under the "name" attribute key.
 */
Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name)
    : Node(op,
           std::make_shared<DynamicAttributes>(
               std::map<std::string, future_std::any>{
                   {std::string("name"), future_std::any(name)}}))
{
}

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

/**
 * @brief Connect a list of Connectors to the data inputs of this Node, in order.
 *
 * Each connector is matched with the next input whose category is Data or
 * OptionalData. A connector without an associated node still consumes its
 * input slot but creates no link.
 *
 * @param ctors Connectors to plug into the successive data inputs.
 * @return A Connector built from this Node.
 */
Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) {
    IOIndex_t idx = 0;
    for (const auto& connector : ctors) {
        // Move to the next input whose category is Data or OptionalData.
        while (idx < nbInputs()) {
            const auto category = inputCategory(idx);
            if (category == InputCategory::Data || category == InputCategory::OptionalData) {
                break;
            }
            ++idx;
        }

        AIDGE_ASSERT(idx < nbInputs(), "Too many input connectors ({}) vs available node inputs.", ctors.size());
        AIDGE_ASSERT(input(idx).second == gk_IODefaultIndex, "Data input#{} connection is not free.", idx);

        // Only a connector that carries a node produces an actual link.
        const auto source = connector.node();
        if (source != nullptr) {
            source->addChild(shared_from_this(), connector.index(), idx);
        }
        ++idx;
    }

    // After the last connector, no input of category Data may remain.
    while (idx < nbInputs() && inputCategory(idx) != InputCategory::Data) {
        ++idx;
    }
    AIDGE_ASSERT(idx == nbInputs(), "Missing an input connector for Data input#{}", idx);

    return Connector(shared_from_this());
}

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////

/**
 * @brief Rename the Node.
 *
 * Every GraphView returned by views() is notified through updateNodeName()
 * before the "name" attribute itself is overwritten.
 *
 * @param name New name of the Node.
 */
void Aidge::Node::setName(const std::string& name) {
    const auto owningViews = views();
    for (const auto& view : owningViews) {
        view->updateNodeName(shared_from_this(), name);
    }
    mAttrs->setAttr<std::string>("name", name);
}

/**
 * @brief Derive a name that no GraphView returned by views() already contains.
 *
 * The bare base name is tried first, then baseName + "_1", baseName + "_2", ...
 * until no view reports the candidate through inView().
 *
 * @param baseName Name to start from.
 * @return The first candidate that is not used in any view.
 */
std::string Aidge::Node::createUniqueName(std::string baseName)
{
    for (int index = 0;; ++index) {
        // Index 0 means "no suffix".
        std::string candidate = baseName;
        if (index != 0) {
            candidate += "_" + std::to_string(index);
        }

        bool taken = false;
        for (const auto& graphView : views()) {
            if (graphView->inView(candidate)) {
                Log::info("Node::createUniqueName(): name '{}' already used in graph '{}'", candidate, graphView->name());
                taken = true;
                break;
            }
        }

        if (!taken) {
            return candidate;
        }
    }
}

///////////////////////////////////////////////////////
//        OPERATORS
///////////////////////////////////////////////////////

/**
 * @brief Run every registered forward callback, in registration order.
 *
 * A callback that returns false is erased from mForward; one that returns
 * true stays registered.
 */
void Aidge::Node::forward() {
    auto callback = mForward.begin();
    while (callback != mForward.end()) {
        if ((*callback)()) {
            ++callback;
        }
        else {
            // erase() hands back the iterator following the removed callback
            callback = mForward.erase(callback);
        }
    }
}

/**
 * @brief Run every registered backward callback, in registration order.
 *
 * A callback that returns false is erased from mBackward; one that returns
 * true stays registered.
 */
void Aidge::Node::backward() {
    auto callback = mBackward.begin();
    while (callback != mBackward.end()) {
        if ((*callback)()) {
            ++callback;
        }
        else {
            // erase() hands back the iterator following the removed callback
            callback = mBackward.erase(callback);
        }
    }
}

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////

bool Aidge::Node::valid() const {
    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
        if (mIdOutParents[static_cast<std::size_t>(i)] == gk_IODefaultIndex) {
            return false;
        }
    }
    return true;
}

/**
 * @brief Count the inputs whose parent output index is still gk_IODefaultIndex.
 *
 * NOTE(review): despite the name, every input is inspected here, whatever its
 * category -- confirm this is what callers expect.
 *
 * @return Number of inputs without an assigned parent output index.
 */
Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const {
    IOIndex_t freeCount = 0;
    IOIndex_t in = 0;
    while (in < nbInputs()) {
        const auto link = input(in);
        if (link.second == gk_IODefaultIndex) {
            ++freeCount;
        }
        ++in;
    }
    return freeCount;
}

/**
 * @brief List the (parent, parent output index) pairs of the data inputs.
 *
 * Only inputs whose category is Data or OptionalData are reported, in input
 * order. An unconnected input yields (nullptr, gk_IODefaultIndex) as stored.
 *
 * @return One pair per Data/OptionalData input.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs()
        const {
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> dataLinks;
    const std::size_t inputCount = static_cast<std::size_t>(nbInputs());
    for (std::size_t in = 0; in < inputCount; ++in) {
        const auto category = inputCategory(in);
        const bool carriesData =
                (category == InputCategory::Data) || (category == InputCategory::OptionalData);
        if (!carriesData) {
            continue;
        }
        dataLinks.emplace_back(mParents[in], mIdOutParents[in]);
    }
    return dataLinks;
}

/**
 * @brief List the (parent, parent output index) pair of every input.
 * @return One pair per input, in input order.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const {
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> links;
    links.reserve(nbInputs());
    for (std::size_t in = 0; in < nbInputs(); ++in) {
        links.emplace_back(mParents[in], mIdOutParents[in]);
    }
    return links;
}

std::vector<std::string> Aidge::Node::inputsNames() const {
    std::vector<std::string> res = std::vector<std::string>(nbInputs());
    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
        res[i] = inputName(i);
        /*
        if (mInputNames.count(i)) {
            res[i] = mInputNames[i];
        } else if (mParents[i]) {
            res[i] = mParents[i]->name() + "_out" + std::to_string(mIdOutParents[i]);
        } else {
            res[i] = this->name() + "_in" + std::to_string(i);
        }*/
    }
    return res;
}
/**
 * @brief Resolve the name of one input.
 *
 * Resolution order:
 *  1. the name explicitly stored in mInputNames for this input;
 *  2. "<parent name>_out<parent output index>" when a parent is connected;
 *  3. "<this node name>_in<input index>" otherwise.
 * A warning is logged when a connected parent reports a different name for
 * the matching output.
 *
 * @param inID Input index; must be lower than nbInputs().
 * @return The resolved input name.
 */
std::string Aidge::Node::inputName(const Aidge::IOIndex_t inID) const {
    // nbInputs already < gk_IODefaultIndex
    AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound.");

    const auto& parent = mParents[inID];
    const auto custom = mInputNames.find(inID);

    std::string resolved;
    if (custom != mInputNames.end()) {
        resolved = custom->second;
    }
    else if (parent) {
        resolved = parent->name() + "_out" + std::to_string(mIdOutParents[inID]);
    }
    else {
        resolved = this->name() + "_in" + std::to_string(inID);
    }

    if (parent && parent->outputName(mIdOutParents[inID]) != resolved) {
        Log::warn("Problem, parent node don't have same output name as this input name.");
    }
    return resolved;
}

/**
 * @brief Set the name of one input and propagate it to the connected parent output.
 *
 * The name is stored first; the parent's output is renamed only when it
 * currently reports a different name.
 *
 * @param inID    Input index; must be lower than nbInputs().
 * @param newName Name to assign.
 * @return The name now stored for this input.
 */
std::string Aidge::Node::inputName(const Aidge::IOIndex_t inID, std::string newName) {
    // nbInputs already < gk_IODefaultIndex
    AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound.");
    mInputNames[inID] = newName;

    const auto& parent = mParents[inID];
    if (parent) {
        const IOIndex_t parentOut = mIdOutParents[inID];
        if (parent->outputName(parentOut) != newName) {
            parent->outputName(parentOut, newName);
        }
    }
    return mInputNames[inID];
}

/**
 * @brief Collect the name of every output, as resolved by outputName().
 * @return One name per entry of mIdInChildren, in output order.
 */
std::vector<std::string> Aidge::Node::outputsNames() const {
    const std::size_t outputCount = mIdInChildren.size();
    std::vector<std::string> names;
    names.reserve(outputCount);
    for (std::size_t out = 0; out < outputCount; ++out) {
        names.push_back(outputName(static_cast<IOIndex_t>(out)));
    }
    return names;
}

/**
 * @brief Resolve the name of one output.
 * @param outID Output index.
 * @return The name stored in mOutputNames for this output when there is one,
 *         "<this node name>_out<output index>" otherwise.
 */
std::string Aidge::Node::outputName(Aidge::IOIndex_t outID) const {
    const auto custom = mOutputNames.find(outID);
    if (custom == mOutputNames.end()) {
        // no explicit name: fall back to the generated one
        return this->name() + "_out" + std::to_string(outID);
    }
    return custom->second;
}

/**
 * @brief Set the name of one output and propagate it to the connected child inputs.
 *
 * The name is stored first; each live child connected to this output is then
 * renamed only when its matching input currently reports a different name.
 * Children that no longer exist are reported with a warning and skipped.
 *
 * @param outID   Output index; must be lower than nbOutputs().
 * @param newName Name to assign.
 * @return The name now stored for this output.
 */
std::string Aidge::Node::outputName(Aidge::IOIndex_t outID, std::string newName) {
    // Validate before touching mOutputNames / mIdInChildren, as the input-side setter does.
    AIDGE_ASSERT((outID < nbOutputs()), "Output index out of bound.");
    this->mOutputNames[outID] = newName;
    for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) {
        if (std::shared_ptr<Node> child = mChildren[outID][i].lock()) {
            if (child->inputName(mIdInChildren[outID][i]) != newName) {
                child->inputName(mIdInChildren[outID][i], newName);
            }
        }
        else {
            Log::warn("Node::outputName(): dangling connection at index #{} of output #{} for node {} (of type {})", i, outID, name(), type());
        }
    }
    return this->mOutputNames[outID];
}

/**
 * @brief Get the (parent, parent output index) pair connected to one input.
 * @param inID Input index; must be lower than nbInputs().
 * @return The stored parent pointer and parent output index for that input.
 */
std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t> Aidge::Node::input(const Aidge::IOIndex_t inID) const {
    // nbInputs already < gk_IODefaultIndex
    AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound.");
    return std::make_pair(mParents[inID], mIdOutParents[inID]);
}


// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor>
// tensor) {
//     assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");
//     if (mParents[idx] != nullptr) {
//         mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]);
//         removeParent(idx);
//     }
//     std::shared_ptr<Node> newConstantNode = Producer(tensor);
//     newConstantNode->addChild(shared_from_this(), 0, idx);
//     for (auto& graphPtr : views()) {
//         graphPtr->add(newConstantNode);
//     }
// }

/**
 * @brief List, for every output, the (child, child input index) pairs returned by output().
 * @return One list per entry of mIdInChildren, in output order.
 */
std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::Node::outputs() const {
    const std::size_t outputCount = mIdInChildren.size();
    std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> perOutput;
    perOutput.reserve(outputCount);
    for (std::size_t out = 0; out < outputCount; ++out) {
        perOutput.push_back(output(static_cast<IOIndex_t>(out)));
    }
    return perOutput;
}

/**
 * @brief List the (child, child input index) pairs connected to one output.
 *
 * Children are held through weak pointers; a child that no longer exists is
 * reported with a warning and left out of the result.
 *
 * @param outId Output index.
 * @return The live children of that output with the input index they use.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::output(
        Aidge::IOIndex_t outId) const {
    const auto& inputIds = mIdInChildren[outId];
    const auto& children = mChildren[outId];

    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> links;
    for (std::size_t i = 0; i < inputIds.size(); ++i) {
        const std::shared_ptr<Node> child = children[i].lock();
        if (!child) {
            Log::warn("Node::output(): dangling connection at index #{} of output #{} for node {} (of type {})", i, outId, name(), type());
            continue;
        }
        links.emplace_back(child, inputIds[i]);
    }
    return links;
}

/**
 * @brief Count the inputs that have a parent output index assigned.
 *
 * An input is counted when its entry in mIdOutParents differs from
 * gk_IODefaultIndex, which is the same criterion valid() uses per input.
 *
 * @return Number of connected inputs.
 */
Aidge::IOIndex_t Aidge::Node::nbValidInputs() const {
    IOIndex_t counter = 0;
    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
        // gk_IODefaultIndex marks an unconnected input, so it must NOT be counted
        if (mIdOutParents[static_cast<std::size_t>(i)] != gk_IODefaultIndex) ++counter;
    }
    return counter;
}

/**
 * @brief Count the outputs that have at least one child input index recorded.
 * @return Number of outputs whose entry in mIdInChildren is not empty.
 */
Aidge::IOIndex_t Aidge::Node::nbValidOutputs() const {
    if (mIdInChildren.empty()) {
        return 0;
    }
    IOIndex_t connected = 0;
    for (std::size_t out = 0; out < nbOutputs(); ++out) {
        if (!mIdInChildren[out].empty()) {
            ++connected;
        }
    }
    return connected;
}

/**
 * @brief Get the GraphViews referenced by mViews that are still alive.
 * @return Set of shared pointers obtained by locking each entry of mViews;
 *         entries that fail to lock are skipped.
 */
std::set<std::shared_ptr<Aidge::GraphView>> Aidge::Node::views() const noexcept {
    std::set<std::shared_ptr<GraphView>> alive;
    for (const auto& weakView : mViews) {
        const auto view = weakView.lock();
        if (view) {
            alive.insert(view);
        }
    }
    return alive;
}

/**
 * @brief Record which parent output feeds one input.
 *
 * When the input already has a parent output index, this Node is first
 * removed from the previous parent's child list for that output.
 *
 * @param inId         Input index; must differ from gk_IODefaultIndex and be lower than nbInputs().
 * @param newNodeoutId Output index of the new parent to store for this input.
 */
void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) {
    AIDGE_ASSERT(inId != gk_IODefaultIndex && inId < nbInputs(),
        "Input index ({}) is out of bound ({}) for node {} (of type {})",
        inId, nbInputs(), name(), type());

    const bool alreadyLinked = (mIdOutParents[inId] != gk_IODefaultIndex);
    if (alreadyLinked) {
        Log::notice("Filling a Tensor already attributed.");
        // Drop the first occurrence of this node among the children of the
        // previous parent's output.
        const auto previous = input(inId);
        previous.first->removeChild(shared_from_this(), previous.second);
    }
    mIdOutParents[inId] = newNodeoutId;
}

///////////////////////////////////////////////////////
// TOPOLOGY
///////////////////////////////////////////////////////

/**
 * @brief Connect output `outId` of this Node to input `otherInId` of `otherNode`.
 *
 * Steps, in order:
 *  1. both indices are validated against the respective node's input/output count;
 *  2. the child records the parent output index (setInputId(), which also
 *     detaches a previously connected parent);
 *  3. the child's operator input is associated with this operator's raw output;
 *  4. the child is appended (as a weak pointer) to this node's child lists and
 *     this node is registered as the child's parent.
 *
 * @param otherNode Child node; it is dereferenced, so it must not be null.
 * @param outId     Output index of this Node; must be lower than nbOutputs().
 * @param otherInId Input index of the child; must be lower than otherNode->nbInputs().
 */
void Aidge::Node::addChildOp(const std::shared_ptr<Node>& otherNode, const IOIndex_t outId,
                             const IOIndex_t otherInId) {
    AIDGE_ASSERT(otherInId < otherNode->nbInputs(),
        "Input index (#{}) of the node {} (of type {}) is out of bound (it has {} inputs), when trying to add it as a child of node {} (of type {})",
        otherInId, otherNode->name(), otherNode->type(), otherNode->nbInputs(), name(), type());
    AIDGE_ASSERT(outId < nbOutputs(),
        "Output index (#{}) of the node {} (of type {}) is out of bound (it has {} outputs), when trying to add the child node {} (of type {})",
        outId, name(), type(), nbOutputs(), otherNode->name(), otherNode->type());
    if (otherNode.use_count() == 1) {
        // Adjacent literals are concatenated as-is: each sentence carries its own trailing space.
        Log::debug("Node::addChild(): the node {} (of type {}) only holds a weak reference to the added child node {} (of type {}). "
            "If the child node goes out of scope, it will be destructed, leading to a dangling connection. "
            "To avoid this message, consider adding the child node to a GraphView first.", name(), type(), otherNode->name(), otherNode->type());
    }
    if (otherNode->input(otherInId).second != gk_IODefaultIndex) {
        Log::notice("the {}-th Parent of the child node {} (of type {}) already existed", otherInId, otherNode->name(), otherNode->type());
    }
    // manage tensors and potential previous parent
    otherNode->setInputId(otherInId, outId);
    otherNode->getOperator()->associateInput(otherInId, getOperator()->getRawOutput(outId));
    // manage nodes: the child is only referenced weakly from the parent side
    mChildren[outId].push_back(std::weak_ptr<Node>(otherNode));
    mIdInChildren[outId].push_back(otherInId);
    otherNode->addParent(shared_from_this(), otherInId);
}

void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId,
                               std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
    const auto inNodes = otherGraph->inputNodes();
    AIDGE_ASSERT(otherInId.first != nullptr && inNodes.find(otherInId.first) != inNodes.end(),
        "Node {} (of type {}) is not a valid input node of GraphView {}, when trying to add it as a child of node {} (of type {})",
        (otherInId.first) ? otherInId.first->name() : "#nullptr", (otherInId.first) ? otherInId.first->type() : "", otherGraph->name(), name(), type());
    addChildOp(otherInId.first, outId, otherInId.second);
}

/**
 * @brief Connect output `outId` of this Node to an input of another Node.
 *
 * Nothing happens when `otherNode` is null. When `otherInId` is
 * gk_IODefaultIndex, the child's getFirstFreeDataInput() is used instead.
 *
 * @param otherNode Child node (may be null).
 * @param outId     Output index of this Node.
 * @param otherInId Input index of the child, or gk_IODefaultIndex to pick one.
 */
void Aidge::Node::addChild(const std::shared_ptr<Node>& otherNode, const IOIndex_t outId,
                           IOIndex_t otherInId) {
    if (!otherNode) {
        return;
    }
    if (otherInId == gk_IODefaultIndex) {
        otherInId = otherNode->getFirstFreeDataInput();
    }
    addChildOp(otherNode, outId, otherInId);
}

/**
 * @brief Connect output `outId` of this Node to an input node of a GraphView.
 *
 * When no target node is given, the view must have exactly one input node,
 * which is then used. When the target input index is gk_IODefaultIndex, the
 * target's getFirstFreeDataInput() is used. The connection is delegated to
 * addChildView().
 *
 * @param otherView GraphView to connect to.
 * @param outId     Output index of this Node.
 * @param otherInId (target node, target input index); either part may be left unset.
 */
void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId,
                           std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
    if (otherInId.first == nullptr) {
        AIDGE_ASSERT(otherView->inputNodes().size() == 1U,
            "Input node of GraphView {} need to be specified, because it has more than one input ({} inputs), when trying to add it as a child of node {} (of type {})",
            otherView->name(), otherView->inputNodes().size(), name(), type());
        otherInId.first = *(otherView->inputNodes().begin());
    }
    if (otherInId.second == gk_IODefaultIndex) {
        otherInId.second = otherInId.first->getFirstFreeDataInput();
    }
    addChildView(otherView, outId, otherInId);
}

/**
 * @brief Register `other_node` as the parent connected to input `inId`.
 *
 * A notice is logged when the input already had a parent; the previous
 * parent pointer is simply overwritten.
 *
 * @param other_node Parent node to store.
 * @param inId       Input index; must differ from gk_IODefaultIndex and be lower than nbInputs().
 */
void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) {
    // Validate the index before it is used for any lookup.
    AIDGE_ASSERT(inId != gk_IODefaultIndex && inId < nbInputs(),
        "Input index ({}) is out of bound ({}) for node {} (of type {})",
        inId, nbInputs(), name(), type());
    if (getParent(inId) != nullptr) {
        Log::notice("You are replacing an existing parent for node {} (of type {}).", name(), type());
    }
    mParents[inId] = other_node;
}

/**
 * @brief Get a copy of the parent list.
 * @return One entry per input slot, in input order.
 */
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getParents() const {
    return mParents;
}

/**
 * @brief Detach and return the parent connected to one input.
 * @param inId Input index; must differ from gk_IODefaultIndex and be lower than nbInputs().
 * @return The parent that was stored for this input (possibly nullptr).
 */
std::shared_ptr<Aidge::Node> Aidge::Node::popParent(const IOIndex_t inId) {
    AIDGE_ASSERT(inId != gk_IODefaultIndex && inId < nbInputs(),
        "Input index ({}) is out of bound ({}) for node {} (of type {})",
        inId, nbInputs(), name(), type());
    // Keep the parent alive across removeParent(), which clears the slot.
    std::shared_ptr<Node> previous = mParents[inId];
    removeParent(inId);
    return previous;
}

/**
 * @brief Clear the parent stored for one input.
 *
 * Only this Node's bookkeeping is updated (parent pointer and parent output
 * index); the parent's child lists are not touched here.
 *
 * @param inId Input index; must differ from gk_IODefaultIndex and be lower than nbInputs().
 * @return true when a parent was stored and has been cleared, false otherwise.
 */
bool Aidge::Node::removeParent(const IOIndex_t inId) {
    AIDGE_ASSERT(inId != gk_IODefaultIndex && inId < nbInputs(),
        "Input index ({}) is out of bound ({}) for node {} (of type {})",
        inId, nbInputs(), name(), type());
    if (!mParents[inId]) {
        return false;
    }
    mParents[inId].reset();
    mIdOutParents[inId] = gk_IODefaultIndex;
    return true;
}

/**
 * @brief Get the set of live children over all outputs.
 *
 * Children are held through weak pointers; one that no longer exists is
 * reported with a warning and skipped.
 *
 * @return Every child that could be locked, without duplicates.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
    std::set<std::shared_ptr<Node>> uniqueChildren;
    for (std::size_t outId = 0; outId < mChildren.size(); ++outId) {
        for (const auto& weakChild : mChildren[outId]) {
            const std::shared_ptr<Node> child = weakChild.lock();
            if (child) {
                uniqueChildren.insert(child);
            }
            else {
                Log::warn("Node::getChildren(): dangling connection at output #{} for node {} (of type {})", outId, name(), type());
            }
        }
    }
    return uniqueChildren;
}

/**
 * @brief Get the live children grouped by output.
 * @return One list per entry of mChildren, each filled by getChildren(outId).
 */
std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const {
    const std::size_t outputCount = mChildren.size();
    std::vector<std::vector<std::shared_ptr<Node>>> ordered;
    ordered.reserve(outputCount);
    for (std::size_t outId = 0; outId < outputCount; ++outId) {
        ordered.push_back(getChildren(outId));
    }
    return ordered;
}

/**
 * @brief Get the live children connected to one output, in connection order.
 *
 * A child that no longer exists is reported with a warning and skipped.
 *
 * @param outId Output index; checked against nbOutputs() with assert().
 * @return The children of that output that could be locked.
 */
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const {
    assert((outId < nbOutputs()) && "Output index out of bound.");
    const auto& slots = mChildren[outId];

    std::vector<std::shared_ptr<Node>> alive;
    for (std::size_t i = 0; i < slots.size(); ++i) {
        const std::shared_ptr<Node> child = slots[i].lock();
        if (!child) {
            Log::warn("Node::getChildren(): dangling connection at index #{} of output #{} for node {} (of type {})", i, outId, name(), type());
            continue;
        }
        alive.push_back(child);
    }
    return alive;
}

/**
 * @brief Remove the first occurrence of a child from one output.
 *
 * The matching entries of mChildren and mIdInChildren are erased together so
 * that both lists stay aligned. Only one occurrence is removed per call.
 *
 * @param nodePtr Child to look for (compared against the locked weak pointers).
 * @param outId   Output index; checked against nbOutputs() with assert().
 * @return true when an entry was removed, false when none matched.
 */
bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr,
                              const Aidge::IOIndex_t outId) {
    assert((outId < nbOutputs()) && "Child index out of bound.");
    auto& children = mChildren[outId];
    auto& inputIds = mIdInChildren[outId];

    for (std::size_t pos = 0; pos < children.size(); ++pos) {
        if (children[pos].lock() != nodePtr) {
            continue;
        }
        children.erase(children.begin() + pos);
        inputIds.erase(inputIds.begin() + pos);
        return true;
    }
    return false;
}

/**
 * @brief Disconnect this Node from its parents and children.
 *
 * Input side: for every selected input, this Node is removed from the
 * parent's child list and the local parent pointer / parent output index are
 * reset. Output side: every live child forgets this Node as parent and the
 * local child lists are emptied.
 *
 * @param includeLearnableParam When true, every input is disconnected; when
 *        false, only inputs of category Data or OptionalData are.
 */
void Aidge::Node::resetConnections(bool includeLearnableParam) {
    // remove every parents reference to it
    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
        if (includeLearnableParam || inputCategory(i) == InputCategory::Data || inputCategory(i) == InputCategory::OptionalData) {
            std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i);
            if (parent.first) {
                // removeChild() erases one occurrence per call: loop until this
                // node no longer appears among the children of the parent's output
                while (parent.first->removeChild(shared_from_this(), parent.second) == true) {
                }
            }
            // every reference to this object as child has been removed
            // removing reference to parents.
            mParents[i] = nullptr;
            mIdOutParents[i] = gk_IODefaultIndex;
        }
    }
    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
        // output(i) only returns children that are still alive
        for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) {
            child.first->removeParent(child.second);
        }
        // drop every child entry of this output, dangling ones included
        mChildren[i] = std::vector<std::weak_ptr<Node>>();
        mIdInChildren[i] = std::vector<IOIndex_t>();
    }
    // removing this Node from every GraphView it belongs to (currently disabled)
    // for (auto& graph : views()) {
    //     // if keeping connections with Learnable Parameters, then also remove them from graph
    //     graph->remove(shared_from_this(), !includeLearnableParam);
    // }
}

///////////////////////////////////////////////////////
//        CLONE
///////////////////////////////////////////////////////

/**
 * @brief Create a new Node that shares this Node's Operator.
 * @return A Node built from the same Operator pointer and a copy of the attributes.
 */
Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
    const auto attrsCopy = std::make_shared<DynamicAttributes>(*mAttrs);
    return std::make_shared<Node>(mOperator, attrsCopy);
}

/**
 * @brief Create a new Node that shares the Operator only when it is a Producer.
 * @return A Node built from the same Operator pointer when its type equals
 *         Producer_Op::Type, from a clone of the Operator otherwise; the
 *         attributes are copied in both cases.
 */
Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
    std::shared_ptr<Operator> op;
    if (mOperator->type() == Producer_Op::Type) {
        op = mOperator;
    }
    else {
        op = mOperator->clone();
    }

    const auto attrsCopy = std::make_shared<DynamicAttributes>(*mAttrs);
    return std::make_shared<Node>(op, attrsCopy);
}

/**
 * @brief Create a new Node from a clone of the Operator.
 * @return A Node built from mOperator->clone() and a copy of the attributes.
 */
Aidge::NodePtr Aidge::Node::clone() const {
    const auto operatorCopy = mOperator->clone();
    const auto attrsCopy = std::make_shared<DynamicAttributes>(*mAttrs);
    return std::make_shared<Node>(operatorCopy, attrsCopy);
}

// Destructor defaulted out-of-line, in this translation unit.
Aidge::Node::~Node() = default;

// namespace Aidge {
// std::ostream& operator << (std::ostream& os, Aidge::Node& n) {
//     using namespace std;
//     os << "Node :\tName :\t\"" << n.name() << "\"\tType : \"" << n.getOperator()->type()<< "\"\tIN/OUTputs : "<< n.nbInputs() <<"/"<< n.nbOutputs() <<endl;
//     os << "\tParents :\t" ;
//     for (const auto & p : n.getParents())
//     {
//         os << "\"" <<p->name() << "\"\t";
//     }
//     os << endl;
//     os << "\tChildren :\t" ;
//     for (const auto & c : n.getChildren())
//     {
//         os << "\"" << c->name() << "\"\t";
//     }
//     os << endl;
//     return os;
// }
// }
/////////////////////////////////////////////////////////////////////////////////////////////
// private

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

///////////////////////////////////////////////////////
//        OPERATORS
///////////////////////////////////////////////////////

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////