-
Maxence Naud authored
- Remove setInput in Node - Change setDatatype to setDataType in GraphView and Tensor binding - Add namespace comment - Update Node includes - Run forwardDims() only if operators use Tensors
Maxence Naud authored- Remove setInput in Node - Change setDatatype to setDataType in GraphView and Tensor binding - Add namespace comment - Update Node includes - Run forwardDims() only if operators use Tensors
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Node.cpp 15.15 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Producer.hpp"
#include <memory>
#include <vector>
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
/**
 * @brief Build a Node wrapping the given Operator.
 *
 * All connection containers are sized from the Operator: one parent slot
 * (nullptr) and one parent-output index (gk_IODefaultIndex) per input, and one
 * empty list of children / child-input indices per output.
 *
 * @param op   Operator held by the Node. It is dereferenced here, so it must
 *             not be null.
 * @param name Name stored in the Node.
 */
Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name)
    : mName(name),
      mOperator(op),
      mParents(static_cast<std::size_t>(op->nbInputs()), nullptr),
      mChildren(static_cast<std::size_t>(op->nbOutputs()), std::vector<std::weak_ptr<Node>>()),
      mIdInChildren(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>()),
      mIdOutParents(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex) {
    // nothing else to do: every member is set up in the initializer list
}
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
/**
 * @brief Functional-style connection of this Node to a list of Connectors.
 *
 * The k-th Connector is plugged into input k of this Node: the Node behind the
 * Connector gets this Node added as a child, from the Connector's output index.
 * A Connector that is not associated with any Node leaves its input slot
 * unconnected but still consumes that slot, so the Connectors that follow keep
 * their positional input index.
 *
 * Asserts that the number of Connectors equals nbData() and that every input
 * of this Node is currently free.
 *
 * @param ctors Connectors to plug in, in input order.
 * @return A Connector built from this Node.
 */
Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) {
    assert((ctors.size() == nbData()) && "Wrong number of arguments.\n");
    for (const std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) {
        assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
        (void) input; // avoid unused warning when assertions are compiled out
    }

    IOIndex_t inIdx = 0;
    for (const Connector &ctor : ctors) {
        if (ctor.node() != nullptr) { // ctor must be associated with a node
            ctor.node()->addChild(shared_from_this(), ctor.index(), inIdx);
        }
        // Always advance: the input index is positional, so a Connector without
        // a Node must not shift the following Connectors onto an earlier input.
        ++inIdx;
    }
    return Connector(shared_from_this());
}
///////////////////////////////////////////////////////
// INNER
///////////////////////////////////////////////////////
/**
 * @brief Replace the name stored in the Node.
 * @param name New name.
 */
void Aidge::Node::setName(const std::string &name) {
    mName = name;
}
///////////////////////////////////////////////////////
// OPERATORS
///////////////////////////////////////////////////////
/**
 * @brief Delegate the forward pass to the held Operator.
 *
 * Asserts that the Node holds an Operator.
 */
void Aidge::Node::forward() {
    assert((mOperator != nullptr) && "No Operator interface provided, can't run forward().\n");

    mOperator->forward();
}
/**
 * @brief Delegate the backward pass to the held Operator.
 *
 * Asserts that the Node holds an Operator.
 */
void Aidge::Node::backward() {
    assert((mOperator != nullptr) && "No Operator interface provided, can't run backward().\n");

    mOperator->backward();
}
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
bool Aidge::Node::valid() const {
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
if (mIdOutParents[static_cast<std::size_t>(i)] == gk_IODefaultIndex) {
return false;
}
}
return true;
}
/**
 * @brief Count the inputs that are not linked to any parent output.
 *
 * Note that the scan covers all nbInputs() inputs, not only the first nbData().
 *
 * @return Number of inputs whose parent-output index is gk_IODefaultIndex.
 */
Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const {
    IOIndex_t freeCount = 0;
    IOIndex_t inIdx = 0;
    while (inIdx < nbInputs()) {
        const bool isFree = (input(inIdx).second == gk_IODefaultIndex);
        if (isFree) {
            ++freeCount;
        }
        ++inIdx;
    }
    return freeCount;
}
/**
 * @brief List the (parent Node, parent output index) pair of each data input.
 * @return One pair per data input, for the first nbData() inputs, in input order.
 *         An unconnected input yields whatever is stored for it in mParents /
 *         mIdOutParents.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::Node::dataInputs() const {
    using Link = std::pair<std::shared_ptr<Node>, IOIndex_t>;

    const std::size_t dataCount = static_cast<std::size_t>(nbData());
    std::vector<Link> links;
    links.reserve(dataCount);
    for (std::size_t inIdx = 0; inIdx < dataCount; ++inIdx) {
        links.push_back(Link(mParents[inIdx], mIdOutParents[inIdx]));
    }
    return links;
}
/**
 * @brief List the (parent Node, parent output index) pair of every input.
 * @return One pair per input, for all nbInputs() inputs, in input order.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const {
    using Link = std::pair<std::shared_ptr<Node>, IOIndex_t>;

    const std::size_t inputCount = static_cast<std::size_t>(nbInputs());
    std::vector<Link> links;
    links.reserve(inputCount);
    for (std::size_t inIdx = 0; inIdx < inputCount; ++inIdx) {
        links.push_back(Link(mParents[inIdx], mIdOutParents[inIdx]));
    }
    return links;
}
// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) {
// assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");
// if (mParents[idx] != nullptr) {
// mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]);
// removeParent(idx);
// }
// std::shared_ptr<Node> newConstantNode = Producer(tensor);
// newConstantNode->addChild(shared_from_this(), 0, idx);
// for (auto& graphPtr : views()) {
// graphPtr->add(newConstantNode);
// }
// }
/**
 * @brief List, for every output, the (child Node, child input index) pairs
 *        connected to it.
 * @return One list per output (as many as mIdInChildren has entries), each
 *         built by output().
 */
std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::Node::outputs() const {
    using Link = std::pair<std::shared_ptr<Node>, IOIndex_t>;

    const std::size_t outputCount = mIdInChildren.size();
    std::vector<std::vector<Link>> perOutput;
    perOutput.reserve(outputCount);
    for (std::size_t outIdx = 0; outIdx < outputCount; ++outIdx) {
        perOutput.push_back(output(static_cast<IOIndex_t>(outIdx)));
    }
    return perOutput;
}
/**
 * @brief List the (child Node, child input index) pairs connected to one output.
 *
 * Children are stored as weak pointers; each one is locked here, so a child
 * that no longer exists appears as a null pointer in the result.
 *
 * @param outId Index of the output; not bounds-checked here.
 * @return One pair per connection registered on that output.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::Node::output(Aidge::IOIndex_t outId) const {
    using Link = std::pair<std::shared_ptr<Node>, IOIndex_t>;

    const auto& childInputIds = mIdInChildren[outId];
    const auto& childNodes = mChildren[outId];

    std::vector<Link> links;
    links.reserve(childInputIds.size());
    for (std::size_t k = 0; k < childInputIds.size(); ++k) {
        links.push_back(Link(childNodes[k].lock(), childInputIds[k]));
    }
    return links;
}
/**
 * @brief Count the inputs that are linked to a parent output.
 *
 * An input is considered linked when its parent-output index differs from
 * gk_IODefaultIndex, the same criterion valid() uses. The previous
 * implementation tested for equality and therefore returned the number of
 * free inputs instead.
 *
 * @return Number of linked inputs among the nbInputs() inputs.
 */
Aidge::IOIndex_t Aidge::Node::nbValidInputs() const {
    IOIndex_t counter = 0;
    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
        if (mIdOutParents[static_cast<std::size_t>(i)] != gk_IODefaultIndex) {
            ++counter;
        }
    }
    return counter;
}
/**
 * @brief Count the outputs that have at least one registered child connection.
 * @return 0 when no output list exists, otherwise the number of outputs (among
 *         nbOutputs()) whose list of child input indices is not empty.
 */
Aidge::IOIndex_t Aidge::Node::nbValidOutputs() const {
    if (mIdInChildren.empty()) {
        return 0;
    }

    IOIndex_t connected = 0;
    for (std::size_t outIdx = 0; outIdx < nbOutputs(); ++outIdx) {
        if (!mIdInChildren[outIdx].empty()) {
            connected++;
        }
    }
    return connected;
}
/**
 * @brief Record which output of the parent feeds a given input.
 *
 * If the input already had a parent-output index, a warning is printed and
 * this Node is removed from the children of the previous parent on that
 * output before the index is overwritten.
 *
 * @param inId         Input index; asserted to be valid.
 * @param newNodeoutId Output index of the new parent to store for that input.
 */
void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) {
    assert(inId != gk_IODefaultIndex && (inId < nbInputs()) && "Must be a valid index");

    const bool alreadyLinked = (mIdOutParents[inId] != gk_IODefaultIndex);
    if (alreadyLinked) {
        std::printf("Warning: filling a Tensor already attributed\n");
        // Detach from the previous parent: drop the first occurrence of this
        // Node among the children of the output it was plugged into.
        const auto previous = input(inId);
        previous.first->removeChild(shared_from_this(), previous.second);
    }

    mIdOutParents[inId] = newNodeoutId;
}
///////////////////////////////////////////////////////
// TOPOLOGY
///////////////////////////////////////////////////////
/**
 * @brief Connect one output of this Node to one input of another Node.
 *
 * Steps, in order: record the output index on the child's input (which also
 * detaches a previous parent, see setInputId()), associate the child
 * Operator's input with this Operator's raw output, register the child on this
 * Node's output, and finally register this Node as the child's parent.
 *
 * @param otherNode Child Node.
 * @param outId     Output index of this Node; asserted to be in range.
 * @param otherInId Input index of the child; asserted to be in range.
 */
void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) {
    assert((otherInId < otherNode->nbInputs()) && "Input index out of bound.");
    assert((outId < nbOutputs()) && "Output index out of bound.");

    const bool inputAlreadyUsed = (otherNode->input(otherInId).second != gk_IODefaultIndex);
    if (inputAlreadyUsed) {
        std::printf("Warning, the %d-th Parent of the child node already existed.\n", otherInId);
    }

    // Tensor side: index bookkeeping first (handles a previous parent), then data link.
    otherNode->setInputId(otherInId, outId);
    otherNode->getOperator()->associateInput(otherInId, getOperator()->getRawOutput(outId));

    // Node side: child registered here, then this Node registered as parent.
    mChildren[outId].push_back(std::weak_ptr<Node>(otherNode));
    mIdInChildren[outId].push_back(otherInId);
    otherNode->addParent(shared_from_this(), otherInId);
}
void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId,
std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound.");
assert((outId < nbOutputs()) && "Output index out of bound.");
std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes();
if (inNodes.size() == std::size_t(0)) { // no input Node
printf("Cannot add GraphView to the Node. No input node detected.\n");
} else // inNodes.size() >= 1
{
assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node
addChildOp(otherInId.first, outId, otherInId.second);
}
}
/**
 * @brief Connect one output of this Node to an input of another Node.
 * @param otherNode Child Node.
 * @param outId     Output index of this Node.
 * @param otherInId Input index of the child; when it is gk_IODefaultIndex the
 *                  value returned by the child's getFirstFreeDataInput() is used.
 */
void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) {
    if (otherInId == gk_IODefaultIndex) {
        otherInId = otherNode->getFirstFreeDataInput();
    }
    addChildOp(otherNode, outId, otherInId);
}
/**
 * @brief Connect one output of this Node to an input Node of a GraphView.
 *
 * Missing parts of the target are filled in before delegating to
 * addChildView(): a null target Node is replaced by the GraphView's single
 * input Node (asserted to be unique), and an index equal to gk_IODefaultIndex
 * is replaced by the target Node's getFirstFreeDataInput().
 *
 * @param otherView GraphView to connect to.
 * @param outId     Output index of this Node.
 * @param otherInId Target (Node, input index) inside the GraphView.
 */
void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId,
                           std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
    if (!otherInId.first) {
        assert((otherView->inputNodes().size() == 1U) &&
               "Specify an input Node for the GraphView. More or less than one "
               "Node is not explicit.");
        otherInId.first = *(otherView->inputNodes().begin());
    }

    if (otherInId.second == gk_IODefaultIndex) {
        otherInId.second = otherInId.first->getFirstFreeDataInput();
    }

    addChildView(otherView, outId, otherInId);
}
/**
 * @brief Store a parent Node for a given input.
 *
 * The index is validated before it is used for anything; previously the
 * assertion ran only after getParent(inId) had already been called with it.
 * A warning is printed when the input already had a parent, which is then
 * overwritten.
 *
 * @param other_node Parent Node to store.
 * @param inId       Input index; asserted to be valid.
 */
void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) {
    assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Input index out of bound.");
    if (getParent(inId) != nullptr) {
        printf("Warning, you're replacing a Parent.\n");
    }
    mParents[inId] = other_node;
}
/**
 * @brief Get the parent Nodes, one entry per input.
 * @return A copy of the parent list; the entry of an input without a parent is null.
 */
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getParents() const {
    return mParents;
}
/**
 * @brief Detach and return the parent stored for a given input.
 * @param inId Input index; asserted to be valid.
 * @return The parent that was stored for that input (null if there was none).
 */
std::shared_ptr<Aidge::Node> Aidge::Node::popParent(const IOIndex_t inId) {
    assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Input index out of bound.");

    // Keep a reference before removeParent() clears the slot.
    std::shared_ptr<Node> formerParent = mParents[inId];
    removeParent(inId);
    return formerParent;
}
/**
 * @brief Clear the parent stored for a given input.
 *
 * Both the parent pointer and the parent-output index of the input are reset.
 * The parent's own list of children is not touched here.
 *
 * @param inId Input index; asserted to be valid.
 * @return true if a parent was stored and has been cleared, false otherwise.
 */
bool Aidge::Node::removeParent(const IOIndex_t inId) {
    assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Parent index out of bound.");

    if (!mParents[inId]) {
        return false;
    }

    mParents[inId] = nullptr;
    mIdOutParents[inId] = gk_IODefaultIndex;
    return true;
}
/**
 * @brief Get the set of children over all outputs, without duplicates.
 *
 * Each stored weak pointer is locked and inserted as is, so the set contains
 * a null pointer if any child no longer exists.
 *
 * @return Set of child Nodes.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
    std::set<std::shared_ptr<Node>> uniqueChildren;
    for (std::size_t outIdx = 0; outIdx < mChildren.size(); ++outIdx) {
        const auto& outputChildren = mChildren[outIdx];
        for (std::size_t k = 0; k < outputChildren.size(); ++k) {
            uniqueChildren.insert(outputChildren[k].lock());
        }
    }
    return uniqueChildren;
}
/**
 * @brief Get the children grouped by output.
 * @return One list per output, in output order, each obtained from
 *         getChildren(outId).
 */
std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const {
    const std::size_t outputCount = mChildren.size();

    std::vector<std::vector<std::shared_ptr<Node>>> grouped;
    grouped.reserve(outputCount);
    for (std::size_t outIdx = 0; outIdx < outputCount; ++outIdx) {
        grouped.push_back(getChildren(outIdx));
    }
    return grouped;
}
/**
 * @brief Get the children connected to one output, in connection order.
 *
 * The result holds exactly one entry per registered connection. The previous
 * implementation created the vector already sized to N and then appended N
 * more entries, returning 2N elements whose first half was null.
 *
 * Children are stored as weak pointers; an entry is null if the corresponding
 * child no longer exists.
 *
 * @param outId Output index; asserted to be in range.
 * @return Children of that output.
 */
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const {
    assert((outId < nbOutputs()) && "Output index out of bound.");
    std::vector<std::shared_ptr<Node>> children;
    children.reserve(mChildren[outId].size()); // capacity only: entries are appended below
    for (std::size_t i = 0; i < mChildren[outId].size(); ++i) {
        children.push_back(mChildren[outId][i].lock());
    }
    return children;
}
/**
 * @brief Remove the first connection to a given child on a given output.
 *
 * Only the first matching entry is erased, together with its child input
 * index; further connections to the same child on that output are kept.
 *
 * @param nodePtr Child Node to look for.
 * @param outId   Output index; asserted to be in range.
 * @return true if a connection was found and erased, false otherwise.
 */
bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) {
    assert((outId < nbOutputs()) && "Child index out of bound.");

    auto& outputChildren = mChildren[outId];
    auto& childInputIds = mIdInChildren[outId];
    for (std::size_t pos = 0; pos < outputChildren.size(); ++pos) {
        if (outputChildren[pos].lock() != nodePtr) {
            continue;
        }
        outputChildren.erase(outputChildren.begin() + pos);
        childInputIds.erase(childInputIds.begin() + pos);
        return true;
    }
    return false;
}
/**
 * @brief Disconnect this Node from its parents, its children and its GraphViews.
 *
 * Inputs: for the first nbInputs() inputs (or only the first nbData() when
 * includeLearnableParam is false), every reference to this Node is removed
 * from the parent's children, then the parent slot is cleared.
 * Outputs: every child has its matching parent slot cleared, then the child
 * lists are emptied. Children are held through weak pointers, so a child that
 * no longer exists is skipped instead of being dereferenced.
 * Finally the Node is removed from every GraphView returned by views().
 *
 * @param includeLearnableParam When true all inputs are reset, otherwise only
 *                              the data inputs. Its negation is forwarded as
 *                              the second argument of GraphView::remove().
 */
void Aidge::Node::resetConnections(bool includeLearnableParam) {
    // remove every parent's reference to this Node
    IOIndex_t nbRemovedInputs = includeLearnableParam ? nbInputs() : nbData();
    for (IOIndex_t i = 0; i < nbRemovedInputs; ++i) {
        std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i);
        if (parent.first) {
            // removeChild() erases one occurrence per call: loop until none is left
            while (parent.first->removeChild(shared_from_this(), parent.second) == true) {}
        }
        // every reference to this object as child has been removed,
        // now drop the reference to the parent
        mParents[i] = nullptr;
        mIdOutParents[i] = gk_IODefaultIndex;
    }

    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
        for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) {
            // output() locks weak pointers: an expired child comes back as null
            if (child.first) {
                child.first->removeParent(child.second);
            }
        }
        mChildren[i] = std::vector<std::weak_ptr<Node>>();
        mIdInChildren[i] = std::vector<IOIndex_t>();
    }

    // removing this Node from every GraphView it belongs to
    for (auto& graph : views()) {
        // if keeping connections with learnable parameters, then also remove them from the graph
        graph->remove(shared_from_this(), !includeLearnableParam);
    }
}
///////////////////////////////////////////////////////
// CLONE
///////////////////////////////////////////////////////
/**
 * @brief Create a new Node that shares this Node's Operator.
 * @return A Node constructed from the same Operator pointer and the same name.
 */
Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
    Aidge::NodePtr copy = std::make_shared<Node>(mOperator, mName);
    return copy;
}
/**
 * @brief Create a new Node, sharing the Operator only if it is a Producer.
 * @return A Node with the same name, holding this Node's Operator when its
 *         type equals Producer_Op::Type, and a clone of the Operator otherwise.
 */
Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
    std::shared_ptr<Operator> op;
    if (mOperator->type() == Producer_Op::Type) {
        op = mOperator;          // Producer: shared with the new Node
    } else {
        op = mOperator->clone(); // anything else: independent copy
    }
    return std::make_shared<Node>(op, mName);
}
/**
 * @brief Create a new Node holding a clone of this Node's Operator.
 * @return A Node constructed from mOperator->clone() and the same name.
 */
Aidge::NodePtr Aidge::Node::clone() const {
    Aidge::NodePtr copy = std::make_shared<Node>(mOperator->clone(), mName);
    return copy;
}
std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){
std::set<Aidge::NodePtr> out;
nodeSee.insert(shared_from_this());
if(delta == 0) {
out.insert(shared_from_this());
}else if (delta > 0){
for (const NodePtr& node : getChildren()) {
if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance
for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){
out.insert(ch);
}
}
}
}else{
for (const NodePtr& node : getParents()) {
if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance
for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){
out.insert(pr);
}
}
}
}
return out;
}
/////////////////////////////////////////////////////////////////////////////////////////////
// private
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
///////////////////////////////////////////////////////
// OPERATORS
///////////////////////////////////////////////////////
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////