Skip to content
Snippets Groups Projects
Commit 861dd77c authored by Maxence Naud's avatar Maxence Naud
Browse files

Remove the 'remove' fonction from 'resetConnection' in Node.cpp

parent da735299
No related branches found
No related tags found
No related merge requests found
......@@ -11,22 +11,25 @@
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Producer.hpp"
#include <memory>
#include <vector>
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Types.h"
Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name)
: mName(name),
mOperator(op),
mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)),
mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()),
std::vector<std::weak_ptr<Node>>())),
mIdInChildren(
std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())),
mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) {
mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()),
nullptr)),
mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(
static_cast<std::size_t>(op->nbOutputs()), std::vector<std::weak_ptr<Node>>())),
mIdInChildren(std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()),
std::vector<IOIndex_t>())),
mIdOutParents(
std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) {
// ctor
}
......@@ -34,14 +37,15 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name)
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) {
Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) {
assert((ctors.size() == nbData()) && "Wrong number of arguments.\n");
for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) {
assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
(void) input; // avoid unused warning
for (std::pair<std::shared_ptr<Node>, IOIndex_t>& input : inputs()) {
assert((gk_IODefaultIndex == input.second) &&
"At least one input connection is not free.\n");
(void)input; // avoid unused warning
}
IOIndex_t i = 0;
for (const Connector &ctor : ctors) {
for (const Connector& ctor : ctors) {
if (ctor.node() != nullptr) { // ctor must be associated with a node
ctor.node()->addChild(shared_from_this(), ctor.index(), i++);
}
......@@ -53,7 +57,7 @@ Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) {
// INNER
///////////////////////////////////////////////////////
void Aidge::Node::setName(const std::string &name) { mName = name; }
void Aidge::Node::setName(const std::string& name) { mName = name; }
///////////////////////////////////////////////////////
// OPERATORS
......@@ -92,8 +96,8 @@ Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const {
return nbFreeDataIn;
}
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::Node::dataInputs() const {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs()
const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData());
for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) {
......@@ -104,15 +108,15 @@ Aidge::Node::dataInputs() const {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs());
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs());
for (std::size_t i = 0; i < nbInputs(); ++i) {
res[i] =
std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]);
res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]);
}
return res;
}
// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) {
// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor>
// tensor) {
// assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");
// if (mParents[idx] != nullptr) {
// mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]);
......@@ -128,20 +132,21 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No
std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::Node::outputs() const {
std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs =
std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(mIdInChildren.size());
std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(
mIdInChildren.size());
for (std::size_t i = 0; i < mIdInChildren.size(); ++i) {
listOutputs[i] = output(static_cast<IOIndex_t>(i));
}
return listOutputs;
}
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::Node::output(Aidge::IOIndex_t outId) const {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::output(
Aidge::IOIndex_t outId) const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size());
for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) {
listOutputs[i] =
std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]);
listOutputs[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(),
mIdInChildren[outId][i]);
}
return listOutputs;
}
......@@ -180,7 +185,8 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId)
// TOPOLOGY
///////////////////////////////////////////////////////
void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) {
void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId,
const IOIndex_t otherInId) {
assert((otherInId < otherNode->nbInputs()) && "Input index out of bound.");
assert((outId < nbOutputs()) && "Output index out of bound.");
if (otherNode->input(otherInId).second != gk_IODefaultIndex) {
......@@ -196,33 +202,41 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou
}
void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId,
std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound.");
std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
assert((otherInId.second < otherInId.first->nbInputs()) &&
"Other graph input index out of bound.");
assert((outId < nbOutputs()) && "Output index out of bound.");
std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes();
if (inNodes.size() == std::size_t(0)) { // no input Node
printf("Cannot add GraphView to the Node. No input node detected.\n");
} else // inNodes.size() >= 1
{
assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node
assert((inNodes.find(otherInId.first) !=
inNodes.end())); // assert it really is an input node
addChildOp(otherInId.first, outId, otherInId.second);
}
}
void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) {
otherInId = (otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput();
addChildOp(otherNode, outId, otherInId);
void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId,
IOIndex_t otherInId) {
if (otherNode) {
otherInId =
(otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput();
addChildOp(otherNode, outId, otherInId);
}
}
void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId,
std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
if (!otherInId.first) {
assert((otherView->inputNodes().size() == 1U) &&
"Specify an input Node for the GraphView. More or less than one "
"Node is not explicit.");
otherInId.first = *(otherView->inputNodes().begin());
}
otherInId.second = (otherInId.second != gk_IODefaultIndex) ? otherInId.second : otherInId.first->getFirstFreeDataInput();
otherInId.second = (otherInId.second != gk_IODefaultIndex)
? otherInId.second
: otherInId.first->getFirstFreeDataInput();
addChildView(otherView, outId, otherInId);
}
......@@ -255,8 +269,8 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) {
std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
std::set<std::shared_ptr<Node>> children;
for (const auto &childrenOfOneOutput : mChildren) {
for (const auto &oneChild : childrenOfOneOutput) {
for (const auto& childrenOfOneOutput : mChildren) {
for (const auto& oneChild : childrenOfOneOutput) {
children.insert(oneChild.lock());
}
}
......@@ -264,7 +278,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
}
std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const {
std::vector<std::vector<std::shared_ptr<Node>>> children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size());
std::vector<std::vector<std::shared_ptr<Node>>> children =
std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size());
for (std::size_t outId = 0; outId < mChildren.size(); ++outId) {
children[outId] = getChildren(outId);
}
......@@ -273,14 +288,16 @@ std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedCh
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const {
assert((outId < nbOutputs()) && "Output index out of bound.");
std::vector<std::shared_ptr<Node>> children = std::vector<std::shared_ptr<Node>>(mChildren[outId].size());
std::vector<std::shared_ptr<Node>> children =
std::vector<std::shared_ptr<Node>>(mChildren[outId].size());
for (std::size_t i = 0; i < mChildren[outId].size(); ++i) {
children.push_back(mChildren[outId][i].lock());
}
children.push_back(mChildren[outId][i].lock());
}
return children;
}
bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) {
bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr,
const Aidge::IOIndex_t outId) {
assert((outId < nbOutputs()) && "Child index out of bound.");
bool removed = false;
for (std::size_t j = 0; j < mChildren[outId].size(); ++j) {
......@@ -301,7 +318,8 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) {
std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i);
if (parent.first) {
// number of children linked to the parent's output
while(parent.first->removeChild(shared_from_this(), parent.second) == true) {}
while (parent.first->removeChild(shared_from_this(), parent.second) == true) {
}
}
// every reference to this object as child has been removed
// removing reference to parents.
......@@ -316,24 +334,23 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) {
mIdInChildren[i] = std::vector<IOIndex_t>();
}
// removing this Node from every GraphView it belongs to
for (auto& graph : views()) {
// if keeping connections with LEarnable Parameters, then also remove them from graph
graph->remove(shared_from_this(), !includeLearnableParam);
}
// for (auto& graph : views()) {
// // if keeping connections with LEarnable Parameters, then also remove them from graph
// graph->remove(shared_from_this(), !includeLearnableParam);
// }
}
///////////////////////////////////////////////////////
// CLONE
///////////////////////////////////////////////////////
///////////////////////////////////////////////////////
// CLONE
///////////////////////////////////////////////////////
Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
return std::make_shared<Node>(mOperator, mName);
}
Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type)
? mOperator
: mOperator->clone();
std::shared_ptr<Operator> op =
(mOperator->type() == Producer_Op::Type) ? mOperator : mOperator->clone();
return std::make_shared<Node>(op, mName);
}
......@@ -342,27 +359,25 @@ Aidge::NodePtr Aidge::Node::clone() const {
return std::make_shared<Node>(mOperator->clone(), mName);
}
std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){
std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::NodePtr> nodeSee) {
std::set<Aidge::NodePtr> out;
nodeSee.insert(shared_from_this());
if(delta == 0) {
if (delta == 0) {
out.insert(shared_from_this());
}else if (delta > 0){
for (const NodePtr& node : getChildren()) {
if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance
for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){
} else if (delta > 0) {
for (const NodePtr& node : getChildren()) {
if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance
for (const NodePtr& ch : node->getNodeDelta(delta - 1, nodeSee)) {
out.insert(ch);
}
}
}
}else{
for (const NodePtr& node : getParents()) {
if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance
for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){
} else {
for (const NodePtr& node : getParents()) {
if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance
for (const NodePtr& pr : node->getNodeDelta(delta + 1, nodeSee)) {
out.insert(pr);
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment