Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Node.cpp 13.56 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "graph/Node.hpp"

#include "graph/GraphView.hpp"
#include "operator/Producer.hpp"
#include <memory>
#include <vector>
#include "utils/Types.h"

/**
 * @brief Build a Node wrapping an Operator.
 * @param op   Operator run by this node. Its nbInputs()/nbOutputs() size the
 *             connection tables, so it must not be null (it is dereferenced here).
 * @param name Optional node name; a null pointer yields an empty name.
 */
Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name)
    : mName(name ? std::string(name) : std::string()),
      mOperator(op),
      mParents(static_cast<std::size_t>(op->nbInputs()), nullptr),
      mChildren(static_cast<std::size_t>(op->nbOutputs())),
      mIdInChildren(static_cast<std::size_t>(op->nbOutputs())),
      mIdOutParents(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex) {
    // Every connection starts empty: no parent per input, no child per output,
    // and each input's parent output index set to gk_IODefaultIndex.
}

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

/**
 * @brief Functional-style wiring: connect ctors[k] to data input k of this node.
 *
 * @param ctors One Connector per data input, in input order. A Connector with no
 *              associated node leaves the matching data input unconnected, but
 *              still occupies its position.
 * @return Connector built from this node (Connector(shared_from_this())).
 */
Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> ctors) {
    assert((ctors.size() == nbDataInputs()) && "Wrong number of arguments.\n");
    // Only data inputs are wired by this call, so only they are required to be
    // free. Other inputs (e.g. ones fed through setInput(), which attaches a
    // Producer) may legitimately be connected already.
    for (const std::pair<std::shared_ptr<Node>, IOIndex_t> &input : dataInputs()) {
        assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
        (void)input;  // silence unused-variable warnings when NDEBUG disables assert
    }
    IOIndex_t i = 0;
    for (const Connector &ctor : ctors) {
        if (ctor.node() != nullptr) {  // ctor must be associated with a node
            ctor.node()->addChild(shared_from_this(), ctor.index(), i);
        }
        // Advance for every Connector, empty or not, so that ctors[k] always
        // targets data input k instead of shifting after an empty Connector.
        ++i;
    }
    return Connector(shared_from_this());
}

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////

/// @brief Replace this node's name with a copy of `name`.
void Aidge::Node::setName(const std::string &name) {
    mName.assign(name);
}

///////////////////////////////////////////////////////
//        OPERATORS
///////////////////////////////////////////////////////

/// @brief Run the attached Operator's forward(); asserts (debug builds) that an Operator is set.
void Aidge::Node::forward() {
    assert((mOperator != nullptr) && "No Operator interface provided, can't run forward().\n");
    auto &op = mOperator;
    op->forward();
}

/// @brief Run the attached Operator's backward(); asserts (debug builds) that an Operator is set.
void Aidge::Node::backward() {
    assert((mOperator != nullptr) && "No Operator interface provided, can't run backward().\n");
    auto &op = mOperator;
    op->backward();
}

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////

bool Aidge::Node::valid() const {
    for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbInputs(); ++i) {
        if (mIdOutParents[static_cast<std::size_t>(i)] == gk_IODefaultIndex) {
            return false;
        }
    }
    return true;
}

/**
 * @brief Count inputs that currently have no parent connected.
 *
 * NOTE(review): despite the name, this scans every input reported by nbInputs(),
 * not only the first nbDataInputs() ones. Confirm with callers before narrowing.
 * @return Number of inputs whose parent output index is gk_IODefaultIndex.
 */
Aidge::IONb_t Aidge::Node::getNbFreeDataInputs() const {
    IONb_t nbFreeDataIn = 0;
    for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbInputs(); ++i) {
        // Compare against the same "unconnected" sentinel used by valid() and
        // removeParent(). A sign test (`< 0`) silently never matches if
        // IOIndex_t is unsigned.
        if (input(i).second == gk_IODefaultIndex) {
            ++nbFreeDataIn;
        }
    }
    return nbFreeDataIn;
}

/**
 * @brief Connections of the data inputs, in input order.
 * @return One (parent node, parent output index) pair per data input; an
 *         unconnected input yields the stored null parent and its stored index.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::Node::dataInputs() const {
    const std::size_t count = static_cast<std::size_t>(nbDataInputs());
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> connections;
    connections.reserve(count);
    for (std::size_t idx = 0; idx != count; ++idx) {
        connections.emplace_back(mParents[idx], mIdOutParents[idx]);
    }
    return connections;
}

/**
 * @brief Connections of every input (data and non-data), in input order.
 * @return One (parent node, parent output index) pair per input.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const {
    const std::size_t count = static_cast<std::size_t>(nbInputs());
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> connections;
    connections.reserve(count);
    for (std::size_t idx = 0; idx != count; ++idx) {
        connections.emplace_back(mParents[idx], mIdOutParents[idx]);
    }
    return connections;
}

/**
 * @brief Feed input `idx` with a constant tensor.
 *
 * Any parent currently attached to `idx` is disconnected first. A Producer node
 * wrapping `tensor` is then connected (its output 0) to input `idx` and added to
 * each GraphView returned by views().
 */
void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) {
    assert((idx != gk_IODefaultIndex) && (static_cast<IONb_t>(idx) < nbInputs()) && "Parent index out of bound.");
    const std::shared_ptr<Node> self = shared_from_this();

    const std::shared_ptr<Node> previousParent = mParents[idx];
    if (previousParent) {
        previousParent->removeChild(self, mIdOutParents[idx]);
        removeParent(idx);
    }

    const std::shared_ptr<Node> producer = Producer(tensor);
    producer->addChild(self, 0, idx);
    for (auto &view : views()) {
        view->add(producer);
    }
}

/**
 * @brief Connections of every output, indexed by output slot.
 * @return Entry k is output(k): the (child node, child input index) pairs of output k.
 */
std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::Node::outputs() const {
    std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> perOutput;
    perOutput.reserve(mIdInChildren.size());
    for (std::size_t outIdx = 0; outIdx < mIdInChildren.size(); ++outIdx) {
        perOutput.push_back(output(static_cast<IOIndex_t>(outIdx)));
    }
    return perOutput;
}

/**
 * @brief Connections of a single output, in the order they were added.
 * @param outID Output index; must be a valid output of this node.
 * @return (child node, child input index) pairs attached to output `outID`.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::Node::output(Aidge::IOIndex_t outID) const {
    // Same bound check as getChildren(outID): the indexing below is unchecked.
    assert((outID != gk_IODefaultIndex) && (static_cast<IONb_t>(outID) < nbOutputs()) && "Output index out of bound.");
    const auto &children = mChildren[outID];
    const auto &childInIds = mIdInChildren[outID];
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs;
    listOutputs.reserve(childInIds.size());
    for (std::size_t i = 0; i < childInIds.size(); ++i) {
        listOutputs.emplace_back(children[i], childInIds[i]);
    }
    return listOutputs;
}

/**
 * @brief Count the inputs of this node that are connected to a parent output.
 *
 * An input is valid when its parent output index differs from gk_IODefaultIndex.
 * This is the criterion valid() uses, and it mirrors nbValidOutputs(), which
 * counts outputs that have at least one child.
 */
Aidge::IONb_t Aidge::Node::nbValidInputs() const {
    IONb_t counter = 0;
    for (IONb_t i = 0; i < nbInputs(); ++i) {
        // Count connected inputs (not free ones), using the shared sentinel
        // rather than a sign test.
        if (mIdOutParents[static_cast<std::size_t>(i)] != gk_IODefaultIndex) ++counter;
    }
    return counter;
}

/**
 * @brief Count the outputs of this node that feed at least one child.
 * @return 0 when there is no per-output child table at all.
 */
Aidge::IONb_t Aidge::Node::nbValidOutputs() const {
    if (mIdInChildren.empty()) {
        return 0;
    }
    IONb_t connectedOutputs = 0;
    for (std::size_t outIdx = 0; outIdx < nbOutputs(); ++outIdx) {
        if (!mIdInChildren[outIdx].empty()) {
            ++connectedOutputs;
        }
    }
    return connectedOutputs;
}

/**
 * @brief Record that input `inId` is fed by output `newNodeOutID` of its parent.
 *
 * If the input was already bound, a warning is printed and this node is removed
 * from the previous parent's children list for that output. mParents is not
 * modified here.
 */
void Aidge::Node::setInputId(IOIndex_t inId, IOIndex_t newNodeOutID) {
    assert(inId != gk_IODefaultIndex && (static_cast<IONb_t>(inId) < nbInputs()) && "Must be a valid index");
    if (mIdOutParents[inId] != gk_IODefaultIndex) {
        std::printf("Warning: filling a Tensor already attributed\n");
        const auto originalParent = input(inId);
        // Detach this node from the previous parent's output. An output index
        // may be recorded without a parent node (e.g. if wiring stopped between
        // setInputId() and addParent()), so guard the dereference.
        if (originalParent.first) {
            originalParent.first->removeChild(shared_from_this(), originalParent.second);
        }
    }
    mIdOutParents[inId] = newNodeOutID;
}

///////////////////////////////////////////////////////
// TOPOLOGY
///////////////////////////////////////////////////////

void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) {
    assert((otherInId != gk_IODefaultIndex) && (static_cast<IONb_t>(otherInId) < otherNode->nbInputs()) &&
           "Input index out of bound.");
    assert((outId != gk_IODefaultIndex) && (static_cast<IONb_t>(outId) < nbOutputs()) && "Output index out of bound.");
    if (otherNode->input(otherInId).second >= 0) {
        std::printf("Warning, the %d-th Parent of the child node already existed.\n", otherInId);
    }
    // manage tensors and potential previous parent
    otherNode->setInputId(otherInId, outId);
    otherNode->getOperator()->associateInput(otherInId, getOperator()->getRawOutput(outId));
    // manage nodes
    mChildren[outId].push_back(otherNode);
    mIdInChildren[outId].push_back(otherInId);
    otherNode->addParent(shared_from_this(), otherInId);
}

void Aidge::Node::addChildView(std::shared_ptr<GraphView> other_graph, const IOIndex_t outID,
                              std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
    assert((otherInId.second != gk_IODefaultIndex) &&
           (static_cast<IONb_t>(otherInId.second) < otherInId.first->nbInputs()) &&
           "Other graph input index out of bound.");
    assert((outID != gk_IODefaultIndex) && (static_cast<IONb_t>(outID) < nbOutputs()) && "Output index out of bound.");
    std::set<std::shared_ptr<Node>> inNodes = other_graph->inputNodes();
    if (inNodes.size() == std::size_t(0)) {  // no input Node
        printf("Cannot add GraphView to the Node. No input node detected.\n");
    } else  // inNodes.size() >= 1
    {
        assert((inNodes.find(otherInId.first) != inNodes.end()));  // assert it really is an input node
        addChildOp(otherInId.first, outID, otherInId.second);
    }
}

void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) {
    otherInId = (otherInId >= 0) ? otherInId : otherNode->getFirstFreeDataInput();
    addChildOp(otherNode, outId, otherInId);
}

void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId,
                          std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
    if (!otherInId.first) {
        assert((otherView->inputNodes().size() == 1U) &&
               "Specify an input Node for the GraphView. More or less than one "
               "Node is not explicit.");
        otherInId.first = *(otherView->inputNodes().begin());
    }
    otherInId.second = (otherInId.second >= 0) ? otherInId.second : otherInId.first->getFirstFreeDataInput();
    addChildView(otherView, outId, otherInId);
}

/**
 * @brief Store `other_node` as the parent feeding input `inId`.
 * Prints a warning when an existing parent is replaced.
 */
void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) {
    // Validate the index before it is used to read the current parent slot.
    assert((inId != gk_IODefaultIndex) && (static_cast<IONb_t>(inId) < nbInputs()) && "Input index out of bound.");
    if (mParents[inId] != nullptr) {
        printf("Warning, you're replacing a Parent.\n");
    }
    mParents[inId] = other_node;
}

/// @brief Copy of the parent table: entry k is the node feeding input k (null when unconnected).
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getParents() const {
    std::vector<std::shared_ptr<Node>> parents(mParents.begin(), mParents.end());
    return parents;
}

/**
 * @brief Detach and return the parent feeding input `inId` (null if none).
 * Only this node's side is cleared; the parent's child lists are untouched.
 */
std::shared_ptr<Aidge::Node> Aidge::Node::popParent(const IOIndex_t inId) {
    assert((inId != gk_IODefaultIndex) && (static_cast<IONb_t>(inId) < nbInputs()) && "Input index out of bound.");
    // Take a reference before removeParent() clears the slot.
    std::shared_ptr<Node> detached(mParents[inId]);
    removeParent(inId);
    return detached;
}

/**
 * @brief Clear the parent of input `inId` on this node's side.
 * @return true if a parent was set (its slot and output index are reset),
 *         false if the slot was already empty (nothing is modified).
 */
bool Aidge::Node::removeParent(const IOIndex_t inId) {
    assert((inId != gk_IODefaultIndex) && (static_cast<IONb_t>(inId) < nbInputs()) && "Parent index out of bound.");
    if (!mParents[inId]) {
        return false;
    }
    mParents[inId].reset();
    mIdOutParents[inId] = gk_IODefaultIndex;
    return true;
}

/// @brief Set of all children across every output; a node linked several times appears once.
std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
    std::set<std::shared_ptr<Node>> uniqueChildren;
    for (std::size_t outIdx = 0; outIdx < mChildren.size(); ++outIdx) {
        for (const std::shared_ptr<Node> &child : mChildren[outIdx]) {
            uniqueChildren.insert(child);
        }
    }
    return uniqueChildren;
}

/// @brief Copy of the per-output children lists: outer index is the output, inner order is connection order.
std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const {
    return std::vector<std::vector<std::shared_ptr<Node>>>(mChildren);
}

/// @brief Copy of the children attached to output `outID`, in connection order.
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(IOIndex_t outID) const {
    assert((outID != gk_IODefaultIndex) && (static_cast<IONb_t>(outID) < nbOutputs()) && "Output index out of bound.");
    const std::vector<std::shared_ptr<Node>> &childrenAtOutput = mChildren[outID];
    return childrenAtOutput;
}

/**
 * @brief Remove `nodePtr` from the children of output `outId` (this node's side only).
 *
 * Only the first matching entry, in connection order, is erased, together with
 * its child input index. NOTE(review): if `nodePtr` is linked to this output
 * through several of its inputs, the erased entry may not be the input the
 * caller meant.
 * @return true if an entry was erased, false if `nodePtr` was not found.
 */
bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) {
    assert((outId != gk_IODefaultIndex) && (static_cast<IONb_t>(outId) < nbOutputs()) && "Child index out of bound.");
    auto &children = mChildren[outId];
    auto &childInputIds = mIdInChildren[outId];
    for (std::size_t pos = 0; pos < children.size(); ++pos) {
        if (children[pos] != nodePtr) {
            continue;
        }
        children.erase(children.begin() + pos);
        childInputIds.erase(childInputIds.begin() + pos);
        return true;
    }
    return false;
}

void Aidge::Node::resetConnections(bool includeLearnableParam) {
    // remove every parents reference to it
    IONb_t nbRemovedInputs = includeLearnableParam ? nbInputs() : nbDataInputs();
    for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbRemovedInputs; ++i) {
        std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i);
        if (parent.first) {
            // number of children linked to the parent's output
            while(parent.first->removeChild(shared_from_this(), parent.second) == true) {}
        }
        // every reference to this object as child has been removed
        // removing reference to parents.
        mParents[i] = nullptr;
        mIdOutParents[i] = gk_IODefaultIndex;
    }
    for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbOutputs(); ++i) {
        for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) {
            child.first->removeParent(child.second);
        }
        mChildren[i] = std::vector<std::shared_ptr<Node>>();
        mIdInChildren[i] = std::vector<IOIndex_t>();
    }
    // removing this Node from every GraphView it belongs to
    for (auto& graph : views()) {
        // if keeping connections with LEarnable Parameters, then also remove them from graph
        graph->remove(shared_from_this(), !includeLearnableParam);
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////
// private

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

///////////////////////////////////////////////////////
//        OPERATORS
///////////////////////////////////////////////////////

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////