Skip to content
Snippets Groups Projects
Commit 861dd77c authored by Maxence Naud's avatar Maxence Naud
Browse files

Remove the 'remove' fonction from 'resetConnection' in Node.cpp

parent da735299
No related branches found
No related tags found
1 merge request!53GraphView inputs/outputs ordering
...@@ -11,22 +11,25 @@ ...@@ -11,22 +11,25 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Producer.hpp"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name)
: mName(name), : mName(name),
mOperator(op), mOperator(op),
mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()),
mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), nullptr)),
std::vector<std::weak_ptr<Node>>())), mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(
mIdInChildren( static_cast<std::size_t>(op->nbOutputs()), std::vector<std::weak_ptr<Node>>())),
std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), mIdInChildren(std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()),
mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { std::vector<IOIndex_t>())),
mIdOutParents(
std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) {
// ctor // ctor
} }
...@@ -34,14 +37,15 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) ...@@ -34,14 +37,15 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name)
// FUNCTIONAL DESCRIPTION // FUNCTIONAL DESCRIPTION
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) {
assert((ctors.size() == nbData()) && "Wrong number of arguments.\n"); assert((ctors.size() == nbData()) && "Wrong number of arguments.\n");
for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) { for (std::pair<std::shared_ptr<Node>, IOIndex_t>& input : inputs()) {
assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); assert((gk_IODefaultIndex == input.second) &&
(void) input; // avoid unused warning "At least one input connection is not free.\n");
(void)input; // avoid unused warning
} }
IOIndex_t i = 0; IOIndex_t i = 0;
for (const Connector &ctor : ctors) { for (const Connector& ctor : ctors) {
if (ctor.node() != nullptr) { // ctor must be associated with a node if (ctor.node() != nullptr) { // ctor must be associated with a node
ctor.node()->addChild(shared_from_this(), ctor.index(), i++); ctor.node()->addChild(shared_from_this(), ctor.index(), i++);
} }
...@@ -53,7 +57,7 @@ Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { ...@@ -53,7 +57,7 @@ Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) {
// INNER // INNER
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
void Aidge::Node::setName(const std::string &name) { mName = name; } void Aidge::Node::setName(const std::string& name) { mName = name; }
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// OPERATORS // OPERATORS
...@@ -92,8 +96,8 @@ Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const { ...@@ -92,8 +96,8 @@ Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const {
return nbFreeDataIn; return nbFreeDataIn;
} }
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs()
Aidge::Node::dataInputs() const { const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData()); std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData());
for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) { for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) {
...@@ -104,15 +108,15 @@ Aidge::Node::dataInputs() const { ...@@ -104,15 +108,15 @@ Aidge::Node::dataInputs() const {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs());
for (std::size_t i = 0; i < nbInputs(); ++i) { for (std::size_t i = 0; i < nbInputs(); ++i) {
res[i] = res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]);
std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]);
} }
return res; return res;
} }
// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) { // void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor>
// tensor) {
// assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); // assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");
// if (mParents[idx] != nullptr) { // if (mParents[idx] != nullptr) {
// mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]); // mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]);
...@@ -128,20 +132,21 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No ...@@ -128,20 +132,21 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No
std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::Node::outputs() const { Aidge::Node::outputs() const {
std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs = std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs =
std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(mIdInChildren.size()); std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(
mIdInChildren.size());
for (std::size_t i = 0; i < mIdInChildren.size(); ++i) { for (std::size_t i = 0; i < mIdInChildren.size(); ++i) {
listOutputs[i] = output(static_cast<IOIndex_t>(i)); listOutputs[i] = output(static_cast<IOIndex_t>(i));
} }
return listOutputs; return listOutputs;
} }
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::output(
Aidge::Node::output(Aidge::IOIndex_t outId) const { Aidge::IOIndex_t outId) const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size()); std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size());
for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) { for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) {
listOutputs[i] = listOutputs[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(),
std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]); mIdInChildren[outId][i]);
} }
return listOutputs; return listOutputs;
} }
...@@ -180,7 +185,8 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) ...@@ -180,7 +185,8 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId)
// TOPOLOGY // TOPOLOGY
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) { void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId,
const IOIndex_t otherInId) {
assert((otherInId < otherNode->nbInputs()) && "Input index out of bound."); assert((otherInId < otherNode->nbInputs()) && "Input index out of bound.");
assert((outId < nbOutputs()) && "Output index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound.");
if (otherNode->input(otherInId).second != gk_IODefaultIndex) { if (otherNode->input(otherInId).second != gk_IODefaultIndex) {
...@@ -196,33 +202,41 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou ...@@ -196,33 +202,41 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou
} }
void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId, void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId,
std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound."); assert((otherInId.second < otherInId.first->nbInputs()) &&
"Other graph input index out of bound.");
assert((outId < nbOutputs()) && "Output index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound.");
std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes(); std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes();
if (inNodes.size() == std::size_t(0)) { // no input Node if (inNodes.size() == std::size_t(0)) { // no input Node
printf("Cannot add GraphView to the Node. No input node detected.\n"); printf("Cannot add GraphView to the Node. No input node detected.\n");
} else // inNodes.size() >= 1 } else // inNodes.size() >= 1
{ {
assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node assert((inNodes.find(otherInId.first) !=
inNodes.end())); // assert it really is an input node
addChildOp(otherInId.first, outId, otherInId.second); addChildOp(otherInId.first, outId, otherInId.second);
} }
} }
void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) { void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId,
otherInId = (otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput(); IOIndex_t otherInId) {
addChildOp(otherNode, outId, otherInId); if (otherNode) {
otherInId =
(otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput();
addChildOp(otherNode, outId, otherInId);
}
} }
void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId, void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId,
std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
if (!otherInId.first) { if (!otherInId.first) {
assert((otherView->inputNodes().size() == 1U) && assert((otherView->inputNodes().size() == 1U) &&
"Specify an input Node for the GraphView. More or less than one " "Specify an input Node for the GraphView. More or less than one "
"Node is not explicit."); "Node is not explicit.");
otherInId.first = *(otherView->inputNodes().begin()); otherInId.first = *(otherView->inputNodes().begin());
} }
otherInId.second = (otherInId.second != gk_IODefaultIndex) ? otherInId.second : otherInId.first->getFirstFreeDataInput(); otherInId.second = (otherInId.second != gk_IODefaultIndex)
? otherInId.second
: otherInId.first->getFirstFreeDataInput();
addChildView(otherView, outId, otherInId); addChildView(otherView, outId, otherInId);
} }
...@@ -255,8 +269,8 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) { ...@@ -255,8 +269,8 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) {
std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
std::set<std::shared_ptr<Node>> children; std::set<std::shared_ptr<Node>> children;
for (const auto &childrenOfOneOutput : mChildren) { for (const auto& childrenOfOneOutput : mChildren) {
for (const auto &oneChild : childrenOfOneOutput) { for (const auto& oneChild : childrenOfOneOutput) {
children.insert(oneChild.lock()); children.insert(oneChild.lock());
} }
} }
...@@ -264,7 +278,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { ...@@ -264,7 +278,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
} }
std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const {
std::vector<std::vector<std::shared_ptr<Node>>> children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); std::vector<std::vector<std::shared_ptr<Node>>> children =
std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size());
for (std::size_t outId = 0; outId < mChildren.size(); ++outId) { for (std::size_t outId = 0; outId < mChildren.size(); ++outId) {
children[outId] = getChildren(outId); children[outId] = getChildren(outId);
} }
...@@ -273,14 +288,16 @@ std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedCh ...@@ -273,14 +288,16 @@ std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedCh
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const { std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const {
assert((outId < nbOutputs()) && "Output index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound.");
std::vector<std::shared_ptr<Node>> children = std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); std::vector<std::shared_ptr<Node>> children =
std::vector<std::shared_ptr<Node>>(mChildren[outId].size());
for (std::size_t i = 0; i < mChildren[outId].size(); ++i) { for (std::size_t i = 0; i < mChildren[outId].size(); ++i) {
children.push_back(mChildren[outId][i].lock()); children.push_back(mChildren[outId][i].lock());
} }
return children; return children;
} }
bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) { bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr,
const Aidge::IOIndex_t outId) {
assert((outId < nbOutputs()) && "Child index out of bound."); assert((outId < nbOutputs()) && "Child index out of bound.");
bool removed = false; bool removed = false;
for (std::size_t j = 0; j < mChildren[outId].size(); ++j) { for (std::size_t j = 0; j < mChildren[outId].size(); ++j) {
...@@ -301,7 +318,8 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { ...@@ -301,7 +318,8 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) {
std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i); std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i);
if (parent.first) { if (parent.first) {
// number of children linked to the parent's output // number of children linked to the parent's output
while(parent.first->removeChild(shared_from_this(), parent.second) == true) {} while (parent.first->removeChild(shared_from_this(), parent.second) == true) {
}
} }
// every reference to this object as child has been removed // every reference to this object as child has been removed
// removing reference to parents. // removing reference to parents.
...@@ -316,24 +334,23 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { ...@@ -316,24 +334,23 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) {
mIdInChildren[i] = std::vector<IOIndex_t>(); mIdInChildren[i] = std::vector<IOIndex_t>();
} }
// removing this Node from every GraphView it belongs to // removing this Node from every GraphView it belongs to
for (auto& graph : views()) { // for (auto& graph : views()) {
// if keeping connections with LEarnable Parameters, then also remove them from graph // // if keeping connections with LEarnable Parameters, then also remove them from graph
graph->remove(shared_from_this(), !includeLearnableParam); // graph->remove(shared_from_this(), !includeLearnableParam);
} // }
} }
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// CLONE // CLONE
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
return std::make_shared<Node>(mOperator, mName); return std::make_shared<Node>(mOperator, mName);
} }
Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type) std::shared_ptr<Operator> op =
? mOperator (mOperator->type() == Producer_Op::Type) ? mOperator : mOperator->clone();
: mOperator->clone();
return std::make_shared<Node>(op, mName); return std::make_shared<Node>(op, mName);
} }
...@@ -342,27 +359,25 @@ Aidge::NodePtr Aidge::Node::clone() const { ...@@ -342,27 +359,25 @@ Aidge::NodePtr Aidge::Node::clone() const {
return std::make_shared<Node>(mOperator->clone(), mName); return std::make_shared<Node>(mOperator->clone(), mName);
} }
std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::NodePtr> nodeSee) {
std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){
std::set<Aidge::NodePtr> out; std::set<Aidge::NodePtr> out;
nodeSee.insert(shared_from_this()); nodeSee.insert(shared_from_this());
if(delta == 0) { if (delta == 0) {
out.insert(shared_from_this()); out.insert(shared_from_this());
}else if (delta > 0){ } else if (delta > 0) {
for (const NodePtr& node : getChildren()) { for (const NodePtr& node : getChildren()) {
if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance
for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ for (const NodePtr& ch : node->getNodeDelta(delta - 1, nodeSee)) {
out.insert(ch); out.insert(ch);
} }
} }
} }
}else{ } else {
for (const NodePtr& node : getParents()) { for (const NodePtr& node : getParents()) {
if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance
for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ for (const NodePtr& pr : node->getNodeDelta(delta + 1, nodeSee)) {
out.insert(pr); out.insert(pr);
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment