Skip to content
Snippets Groups Projects
Commit ca2b839b authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed many issues, still not working

parent d64f9665
No related branches found
No related tags found
1 merge request!53GraphView inputs/outputs ordering
Pipeline #34854 failed
......@@ -120,7 +120,7 @@ public:
///////////////////////////////////////////////////////
public:
/** @brief Get reference to the set of input Nodes. */
inline const std::set<NodePtr>& inputNodes() const noexcept {
inline std::set<NodePtr> inputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mInputNodes) {
nodes.insert(node.first);
......@@ -128,7 +128,7 @@ public:
return nodes;
}
/** @brief Get reference to the set of output Nodes. */
inline const std::set<NodePtr>& outputNodes() const noexcept {
inline std::set<NodePtr> outputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mOutputNodes) {
nodes.insert(node.first);
......
......@@ -47,7 +47,7 @@ public:
mGraph(op.mGraph->clone())
{
// cpy-ctor
// TODO: FIXME: mInputOps and mOutputOps are not populated!
// TODO: FIXME: mInputNodes and mOutputNodes are not populated!
// Issue: how to map new (cloned) nodes with old nodes? getNodes() does
// not garantee any order! Check issue #52.
}
......@@ -71,8 +71,8 @@ public:
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
const auto& inputOp = mInputOps[inputIdx];
inputOp.first->associateInput(inputOp.second, data);
const auto& inputOp = mInputNodes[inputIdx];
inputOp.first->getOperator()->associateInput(inputOp.second, data);
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
......@@ -83,9 +83,9 @@ public:
mGraph->forwardDims();
// Associate outputs to micro-graph outputs for custom implementation
for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) {
const auto& outputOp = mOutputOps[outputIdx];
mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second);
for (size_t outputIdx = 0; outputIdx < mOutputNodes.size(); ++outputIdx) {
const auto& outputOp = mOutputNodes[outputIdx];
mOutputs[outputIdx] = outputOp.first->getOperator()->getOutput(outputOp.second);
}
}
......
......@@ -26,7 +26,15 @@ Aidge::Connector::Connector(std::shared_ptr<Aidge::Node> node) {
Aidge::IOIndex_t Aidge::Connector::size() const { return mNode->nbOutputs(); }
std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ctors) {
std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ctors) {
std::set<NodePtr> nodesToAdd;
for (const Connector& ctor : ctors) {
nodesToAdd.insert(ctor.node());
}
return std::make_shared<GraphView>(nodesToAdd, ctors.back().node());
// TODO: FIXME: don't understand the following code!
/*
std::shared_ptr<GraphView> graph = std::make_shared<GraphView>();
std::vector<std::shared_ptr<Node>> nodesToAdd = std::vector<std::shared_ptr<Node>>();
for (const Connector& ctor : ctors) {
......@@ -51,4 +59,5 @@ std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ct
buffer = {};
}
return graph;
*/
}
\ No newline at end of file
......@@ -385,7 +385,7 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara
void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
// List only the nodes that are not already present in current graph
std::set<NodePtr> nodesToAdd;
std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::back_inserter(nodesToAdd));
std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin()));
do {
std::set<NodePtr> nextNodesToAdd;
......@@ -394,14 +394,17 @@ void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl
// such that the obtained GraphView inputs list will be the same, regardless
// of the evaluation order of those nodes
// (i.e. one of their child is in current GraphView)
for (const std::shared_ptr<Node> &node_ptr : nodesToAdd) {
for (auto child : node_ptr->getChildren()) {
for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) {
for (auto child : (*it)->getChildren()) {
if (mNodes.find(child) != mNodes.end()) {
nextNodesToAdd.insert(node_ptr);
nodesToAdd.erase(node_ptr);
nextNodesToAdd.insert(*it);
it = nodesToAdd.erase(it);
break;
}
}
if (it == nodesToAdd.end()) {
break;
}
}
// If there is no more parent, find nodes that are direct children of current GraphView,
......@@ -412,14 +415,17 @@ void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl
// the empty() condition, but there might be edge cases that may change
// the resulting inputs/outputs order depending on evaluation order (???)
if (nextNodesToAdd.empty()) {
for (const std::shared_ptr<Node> &node_ptr : nodesToAdd) {
for (auto parent : node_ptr->getParents()) {
for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) {
for (auto parent : (*it)->getParents()) {
if (mNodes.find(parent) != mNodes.end()) {
nextNodesToAdd.insert(node_ptr);
nodesToAdd.erase(node_ptr);
nextNodesToAdd.insert(*it);
it = nodesToAdd.erase(it);
break;
}
}
if (it == nodesToAdd.end()) {
break;
}
}
}
......@@ -783,6 +789,9 @@ void Aidge::GraphView::updateInputNodes() {
*/
void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
// Can be called several times with the same node, e.g. when addChild() is
// called on a node already part of the GraphView. In this case, inputs/outputs
// need to be updated!
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end();
// Remove inputs that are not input anymore because connected to newNode
......@@ -790,19 +799,22 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
for (auto ch_ptr : orderedChilds) {
// Check that newNode child is in current GraphView
if (mNodes.find(ch_ptr) != mNodes.end()) {
std::size_t inputIdx = 0;
IOIndex_t inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
// If newNode is connected to it
if (pa_ptr == newNode) {
const auto val = std::make_pair(ch_ptr, inputIdx);
const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val);
// The first old (removed) input becomes the insertion point for newNode GraphView inputs
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
// Check that it was not already the case (if node UPDATE)
if (iter != mInputNodes.end()) {
// The first old (removed) input becomes the insertion point for newNode GraphView inputs
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
}
}
}
++inputIdx;
......@@ -814,16 +826,30 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
// Check if node inputs are inputs for the GraphView and add them to the input list if so
// Inputs addition order follows node inputs order
// Inputs are inserted at the position of the first input removed
std::size_t inputIdx = 0U;
IOIndex_t inputIdx = 0U;
for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
const auto val = std::make_pair(newNode, inputIdx);
// Make sure to not add this input twice, as updateInputsNew() may be
// called several times for the same node.
if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
AIDGE_INTERNAL_ASSERT(std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end());
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
}
++inputIdx;
}
// (if node UPDATE)
// newNode may already exists in the graph and may have been updated
// Check and remove inputs that are not inputs anymore
inputIdx = 0U;
for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) {
if ((pa_ptr != nullptr) &&
(mNodes.find(pa_ptr) !=
mNodes.end())) {
const auto val = std::make_pair(newNode, inputIdx);
auto it = std::find(mInputNodes.begin(), mInputNodes.end(), val);
if (it != mInputNodes.end()) {
mInputNodes.erase(it);
}
}
++inputIdx;
......@@ -832,33 +858,35 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end();
// Remove outputs that are not output anymore because connected to newNode
std::size_t outputIdx = 0;
for (const std::shared_ptr<Node>& parent : newNode->getParents()) {
// Check that newNode parent is in current GraphView
if (mNodes.find(parent) != mNodes.end()) {
for (auto orderedChilds : parent->getOrderedChildren()) {
IOIndex_t outputIdx = 0;
for (auto ch_ptr : orderedChilds) {
// If newNode is connected to it
if (ch_ptr == newNode) {
const auto val = std::make_pair(parent, outputIdx);
const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val);
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
if (iter != mOutputNodes.end()) {
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
}
}
}
}
++outputIdx;
}
}
++outputIdx;
}
// Check if node outputs are outputs for the GraphView and add them to the output list if so
outputIdx = 0;
IOIndex_t outputIdx = 0;
for (auto orderedChilds : newNode->getOrderedChildren()) {
bool noInsideConnection = true;
for (auto ch_ptr : orderedChilds) {
......@@ -870,6 +898,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
if (noInsideConnection) {
const auto val = std::make_pair(newNode, outputIdx);
// Output may be already be present (see addChild() with a node already in GraphView)
if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
}
......@@ -882,19 +911,19 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end();
// Check if node inputs were inputs for the GraphView and remove them from the list if so
std::size_t inputIdx = 0;
for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) {
for (IOIndex_t inputIdx = 0; inputIdx < deletedNode->getParents().size(); ++inputIdx) {
const auto val = std::make_pair(deletedNode, inputIdx);
const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val);
// The first old (removed) input becomes the insertion point for newNode GraphView inputs
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
if (iter != mInputNodes.end()) {
// The first old (removed) input becomes the insertion point for newNode GraphView inputs
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
}
}
++inputIdx;
}
// Add child node inputs that become GraphView input following the removal of the node
......@@ -903,16 +932,13 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
for (auto ch_ptr : orderedChilds) {
// Check that deletedNode child is in current GraphView
if (mNodes.find(ch_ptr) != mNodes.end()) {
inputIdx = 0;
IOIndex_t inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
// If newNode was connected to it
if (pa_ptr == deletedNode) {
const auto val = std::make_pair(ch_ptr, inputIdx);
// Make sure to not add this input twice, as updateInputsNew() may be
// called several times for the same node.
if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
}
AIDGE_INTERNAL_ASSERT(std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end());
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
}
++inputIdx;
}
......@@ -923,25 +949,29 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end();
// Check if node outputs were outputs for the GraphView and remove them from the list if so
std::size_t outputIdx = 0;
for (auto orderedChilds : deletedNode->getOrderedChildren()) {
for (IOIndex_t outputIdx = 0; outputIdx < deletedNode->getOrderedChildren().size(); ++outputIdx) {
const auto val = std::make_pair(deletedNode, outputIdx);
const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val);
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
if (iter != mOutputNodes.end()) {
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
}
}
++outputIdx;
}
// Add parent node outputs that become GraphView output following the removal of the node
// Outputs addition order follows deletedNode inputs order
for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) {
std::size_t outputIdx = 0;
if (parent == nullptr) {
continue;
}
IOIndex_t outputIdx = 0;
for (auto orderedChilds : parent->getOrderedChildren()) {
bool noInsideConnection = true;
for (auto ch_ptr : orderedChilds) {
......@@ -953,10 +983,42 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
if (noInsideConnection) {
const auto val = std::make_pair(parent, outputIdx);
if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
AIDGE_INTERNAL_ASSERT(std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end());
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
}
++outputIdx;
}
}
}
void Aidge::GraphView::updateInputsOutputsNodes() {
mInputNodes.clear();
for (const std::shared_ptr<Node>& go_ptr : mNodes) {
IOIndex_t inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
mInputNodes.push_back(std::make_pair(go_ptr, inputIdx));
}
++inputIdx;
}
}
mOutputNodes.clear();
for (const std::shared_ptr<Node>& go_ptr : mNodes) {
IOIndex_t outputIdx = 0;
for (auto orderedChilds : go_ptr->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx));
}
}
if (orderedChilds.empty()) {
// an output linked to nothing
mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx));
}
++outputIdx;
}
}
......
......@@ -14,6 +14,7 @@
#include <memory>
#include <set>
#include <string>
#include <random>
#include <catch2/catch_test_macros.hpp>
......@@ -26,6 +27,39 @@
using namespace Aidge;
std::set<NodePtr> genRandomDAG(size_t nbNodes, float density = 0.5, size_t maxIn = 5, float avgIn = 1.5, size_t maxOut = 2, float avgOut = 1.1) {
std::random_device rd;
std::mt19937 gen(rd());
std::binomial_distribution<> dIn(maxIn, avgIn/maxIn);
std::binomial_distribution<> dOut(maxOut, avgOut/maxOut);
std::binomial_distribution<> dLink(1, density);
std::vector<NodePtr> nodes;
for (size_t i = 0; i < nbNodes; ++i) {
nodes.push_back(GenericOperator("Fictive", dIn(gen), dIn(gen), dOut(gen)));
}
for (size_t i = 0; i < nbNodes; ++i) {
for (size_t j = i + 1; j < nbNodes; ++j) {
for (size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) {
for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) {
if (dLink(gen)) {
nodes[i]->addChild(nodes[j], outId, inId);
}
}
}
}
}
return std::set<NodePtr>(nodes.begin(), nodes.end());
}
TEST_CASE("genRandomDAG") {
auto g = std::make_shared<GraphView>(genRandomDAG(10));
REQUIRE(g->getNodes().size() == 10);
g->save("./genRandomDAG");
}
TEST_CASE("[core/graph] GraphView(Constructor)") {
std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment