Skip to content
Snippets Groups Projects
Commit d0097894 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Clean up, removed node ordering from MetaOperator

parent b5cc6cca
No related branches found
No related tags found
No related merge requests found
......@@ -15,7 +15,7 @@
#include <map>
#include <memory>
#include <unordered_set>
#include <set>
#include <string>
#include <utility>
#include <vector>
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_GRAPH_TESTING_H_
#define AIDGE_CORE_GRAPH_TESTING_H_
#include <vector>
#include <set>
#include <random>
#include <algorithm>
#include <utility>
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* Random DAG generator
*/
struct RandomDAG {
/// @brief Connection density (between 0 and 1)
float density = 0.5;
/// @brief Max number of inputs per node (regardless if they are connected or not)
size_t maxIn = 5;
/// @brief Average number of inputs per node (regardless if they are connected or not)
float avgIn = 1.5;
/// @brief Max number of outputs per node (regardless if they are connected or not)
size_t maxOut = 2;
/// @brief Average number of outputs per node (regardless if they are connected or not)
float avgOut = 1.1;
/// @brief List of node types that should be generated in the graph (as GenericOperator)
std::vector<std::string> types = {"Fictive"};
/// @brief Weights of each node type, used to compute the probability of generating this type
std::vector<float> typesWeights = {1.0};
/**
* Generate a DAG according to the parameters of the class.
* @param seed Random seed. For an identical seed, an identical topology is
* generated, but with a random node ordering in the return set of nodes.
* @param nbNodes Number of nodes to generate.
*/
std::pair<NodePtr, std::set<NodePtr>> gen(std::mt19937::result_type seed, size_t nbNodes) const;
};
std::string nodePtrToType(NodePtr node);
std::string nodePtrToName(NodePtr node);
std::set<std::string> nodePtrTo(const std::set<NodePtr>& nodes,
std::string(*nodeTo)(NodePtr) = nodePtrToType);
std::vector<std::pair<std::string, IOIndex_t>> nodePtrTo(
const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes,
std::string(*nodeTo)(NodePtr) = nodePtrToType);
}
#endif /* AIDGE_CORE_GRAPH_TESTING_H_ */
......@@ -27,16 +27,9 @@ public:
// Micro-graph handling:
std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph
std::shared_ptr<SequentialScheduler> mScheduler;
// Need to store an ordored list of input/output nodes for the micro-graph,
// because input/output nodes in a GraphView are unordered.
// TODO: refactor GraphView to handle ordered input/output?
std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes;
std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes;
public:
MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph,
std::vector<NodePtr> inputNodes = std::vector<NodePtr>(),
std::vector<NodePtr> outputNodes = std::vector<NodePtr>());
MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph);
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
......@@ -47,9 +40,6 @@ public:
mGraph(op.mGraph->clone())
{
// cpy-ctor
// TODO: FIXME: mInputNodes and mOutputNodes are not populated!
// Issue: how to map new (cloned) nodes with old nodes? getNodes() does
// not garantee any order! Check issue #52.
}
/**
......@@ -71,7 +61,7 @@ public:
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
const auto& inputOp = mInputNodes[inputIdx];
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->associateInput(inputOp.second, data);
// Associate inputs for custom implementation
......@@ -83,8 +73,8 @@ public:
mGraph->forwardDims();
// Associate outputs to micro-graph outputs for custom implementation
for (size_t outputIdx = 0; outputIdx < mOutputNodes.size(); ++outputIdx) {
const auto& outputOp = mOutputNodes[outputIdx];
for (size_t outputIdx = 0; outputIdx < mGraph->getOrderedOutputs().size(); ++outputIdx) {
const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx];
mOutputs[outputIdx] = outputOp.first->getOperator()->getOutput(outputOp.second);
}
}
......@@ -159,11 +149,9 @@ public:
inline std::shared_ptr<Node> MetaOperator(const char *type,
const std::shared_ptr<GraphView>& graph,
const std::string& name = "",
std::vector<NodePtr> inputNodes = std::vector<NodePtr>(),
std::vector<NodePtr> outputNodes = std::vector<NodePtr>())
const std::string& name = "")
{
return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph, inputNodes, outputNodes), name);
return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph), name);
}
} // namespace Aidge
......
......@@ -32,10 +32,8 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels,
// Construct micro-graph
auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "");
// Need to specify the ordered list of input operators
const std::vector<NodePtr> orderedInputNodes = {pad, conv};
auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name, orderedInputNodes);
auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name);
addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w");
addProducer(metaOp, 2, {out_channels}, "b");
return metaOp;
......@@ -65,10 +63,8 @@ inline std::shared_ptr<Node> PaddedConvDepthWise(const std::array<DimSize_t, DIM
// Construct micro-graph
auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "");
// Need to specify the ordered list of input operators
const std::vector<NodePtr> orderedInputNodes = {pad, conv};
auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name, orderedInputNodes);
auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name);
addProducer(metaOp, 1, std::array<DimSize_t,0>({}), "w");
addProducer(metaOp, 2, std::array<DimSize_t,0>({}), "b");
return metaOp;
......
......@@ -303,57 +303,7 @@ void Aidge::GraphView::setDatatype(const DataType &datatype) {
node->getOperator()->setDatatype(datatype);
}
}
/*
void Aidge::GraphView::updateOutputNodes() {
mOutputNodes.clear();
for (const std::shared_ptr<Node>& go_it : mNodes) {
if (go_it->nbOutputs() !=
go_it->nbValidOutputs()) { // an output linked to nothing
mOutputNodes.insert(go_it);
continue;
}
for (const std::shared_ptr<Node>& ch_ptr : go_it->getChildren()) {
if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
mOutputNodes.insert(go_it);
break;
}
}
}
}
void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) {
if (node->nbOutputs() !=
node->nbValidOutputs()) { // an output linked to nothing
mOutputNodes.insert(node);
} else { // don't enter if was already added to outputNodes
for (const std::shared_ptr<Node> &ch_ptr : node->getChildren()) {
if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
mOutputNodes.insert(node);
break;
}
}
}
// update other outputNodes
for (const std::shared_ptr<Node> &pa_ptr :
node->getParents()) { // check if any parent is in OutputNodes too
if ((pa_ptr != nullptr) &&
(mOutputNodes.find(pa_ptr) !=
mOutputNodes.end())) { // it's a match! Must check if the outputNode
// found is still an outputNode
bool remove = (pa_ptr->nbOutputs() == pa_ptr->nbValidOutputs());
for (const std::shared_ptr<Node>& ch_ptr : pa_ptr->getChildren()) {
if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
remove = false;
break;
}
}
if (remove) {
mOutputNodes.erase(pa_ptr);
}
}
}
}
*/
std::vector<
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs() const {
......@@ -841,35 +791,6 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
return true;
}
/*
void Aidge::GraphView::updateInputNodes() {
std::set<std::pair<NodePtr, IOIndex_t>> inputNodes;
for (const std::shared_ptr<Node>& go_ptr : mNodes) {
size_t inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
inputNodes.insert(std::make_pair(go_ptr, inputIdx));
}
++inputIdx;
}
}
// Remove inputs that are not input anymore (deleted node or input connected internally)
for (auto it = mInputNodes.begin(); it != mInputNodes.end(); ++it) {
if (inputNodes.find(*it) == inputNodes.end()) {
it = mInputNodes.erase(it);
}
}
// Add remaining new inputs
for (auto inputNode : inputNodes) {
mInputNodes.push_back(inputNode);
}
}
*/
void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
// Can be called several times with the same node, e.g. when addChild() is
// called on a node already part of the GraphView. In this case, inputs/outputs
......@@ -1111,71 +1032,6 @@ void Aidge::GraphView::updateInputsOutputsNodes() {
}
}
/*
void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) {
// add node_ptr to inputNode if it can
std::size_t filledWithKnownInputs = 0U;
bool wasAdded = mInputNodes.find(node) != mInputNodes.end();
for (const std::shared_ptr<Node>& pa_ptr : node->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
mInputNodes.insert(node);
wasAdded = true;
break;
}
++filledWithKnownInputs;
}
if (filledWithKnownInputs == node->nbInputs() && wasAdded) {
mInputNodes.erase(node);
}
// update other inputNodes
for (const std::shared_ptr<Node>& ch_ptr :
node->getChildren()) { // check if any child is in InputNodes too
if (mInputNodes.find(ch_ptr) !=
mInputNodes.end()) { // it's a match! Must check if the inputNode found
// is still an inputNode
// change here
bool remove = true;
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
if (pa_ptr == nullptr ||
mNodes.find(pa_ptr) ==
mNodes
.end()) { // Parent doesn't exist || Parent not in the graph
remove = false;
break;
}
}
if (remove) {
mInputNodes.erase(ch_ptr);
}
}
}
}
*/
/*
void Aidge::GraphView::removeInputNode(const std::string nodeName) {
std::map<std::string, std::shared_ptr<Node>>::iterator it =
mNodeRegistry.find(nodeName);
if (it != mNodeRegistry.end()) {
const std::shared_ptr<Node> val = (*it).second;
if (mInputNodes.find(val) != mInputNodes.end()) {
mInputNodes.erase(val);
}
}
}
void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
std::map<std::string, std::shared_ptr<Node>>::iterator it =
mNodeRegistry.find(nodeName);
if (it != mNodeRegistry.end()) {
const std::shared_ptr<Node> val = (*it).second;
if (mOutputNodes.find(val) != mOutputNodes.end()) {
mOutputNodes.erase(val);
}
}
}
*/
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const {
std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);
......@@ -1183,37 +1039,44 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone
std::map<NodePtr, NodePtr> oldToNewNodes;
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
oldToNewNodes[node_ptr] = cloneNode(node_ptr);
auto clonedNode = cloneNode(node_ptr);
if (clonedNode == nullptr) {
AIDGE_ASSERT(node_ptr->getChildren().size() <= 1, "deleted nodes in GraphView::clone() cannot have multiple children");
AIDGE_ASSERT(node_ptr->nbDataInputs() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents");
}
oldToNewNodes[node_ptr] = clonedNode;
}
// For each node, convert old node -> new node connections
for (auto &oldToNewNode : oldToNewNodes) {
if (oldToNewNode.second == nullptr)
if (oldToNewNode.second == nullptr) {
continue; // deleted node
}
// Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr
size_t parentId = 0;
for (auto parent : oldToNewNode.first->inputs()) {
while (oldToNewNodes[parent.first] == nullptr) {
// Find next valid parent in line, going backward in the graph
AIDGE_ASSERT(parent.first->getChildren().size() == 1, "deleted nodes in GraphView::clone() cannot have multiple children");
AIDGE_ASSERT(parent.first->nbDataInputs() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents");
const auto& parents = parent.first->dataInputs();
if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
&& oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
{
parent = parents[0];
}
else {
break;
if (parent.first != nullptr) {
while (oldToNewNodes[parent.first] == nullptr) {
// Find next valid parent in line, going backward in the graph
AIDGE_INTERNAL_ASSERT(parent.first->getChildren().size() == 1);
AIDGE_INTERNAL_ASSERT(parent.first->nbDataInputs() <= 1);
const auto& parents = parent.first->dataInputs();
if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
&& oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
{
parent = parents[0];
}
else {
break;
}
}
}
if (oldToNewNodes[parent.first]) {
AIDGE_ASSERT(oldToNewNodes[parent.first]->nbOutputs() == parent.first->nbOutputs(),
"next valid parent after deleted nodes in GraphView::clone() has wrong number of outputs");
oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId);
if (oldToNewNodes[parent.first]) {
AIDGE_INTERNAL_ASSERT(oldToNewNodes[parent.first]->nbOutputs() == parent.first->nbOutputs());
oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId);
}
}
++parentId;
......@@ -1224,6 +1087,11 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone
// This has to be done in a second step to ensure that new GraphView inputs/outputs
// are properly set (otherwise, some node's inputs/outputs may be wrongly registered as
// GraphView inputs/outputs because not yet connected to other nodes)
if (oldToNewNodes[mRootNode] != nullptr) {
// Add root node first if is still exists!
newGraph->add(oldToNewNodes[mRootNode], false);
}
for (auto &oldToNewNode : oldToNewNodes) {
if (oldToNewNode.second == nullptr)
continue; // deleted node
......@@ -1237,19 +1105,22 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone
// If input node was removed, find next valid input
while (oldToNewNodes[it->first] == nullptr) {
// Removed node should have only one connected output, otherwise cloning is invalid
AIDGE_INTERNAL_ASSERT(it->first->getChildren().size() == 1);
auto child = *it->first->getChildren().begin();
AIDGE_INTERNAL_ASSERT(it->first->getChildren().size() <= 1);
bool found = false;
std::size_t inputIdx = 0;
for (auto parent : child->getParents()) {
if (parent == it->first) {
it->first = child;
it->second = inputIdx;
found = true;
break;
if (it->first->getChildren().size() == 1) {
auto child = *it->first->getChildren().begin();
std::size_t inputIdx = 0;
for (auto parent : child->getParents()) {
if (parent == it->first) {
it->first = child;
it->second = inputIdx;
found = true;
break;
}
++inputIdx;
}
++inputIdx;
}
if (!found) {
......@@ -1275,7 +1146,9 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone
AIDGE_INTERNAL_ASSERT(it->first->nbDataInputs() <= 1);
auto parents = it->first->dataInputs();
if (!parents.empty()) {
if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
&& oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
{
*it = parents[0];
}
else {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graph/Testing.hpp"
#include "aidge/operator/GenericOperator.hpp"
std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomDAG::gen(std::mt19937::result_type seed, size_t nbNodes) const {
std::mt19937 gen(seed);
std::binomial_distribution<> dIn(maxIn - 1, avgIn/maxIn);
std::binomial_distribution<> dOut(maxOut - 1, avgOut/maxOut);
std::binomial_distribution<> dLink(1, density);
std::discrete_distribution<> dType(typesWeights.begin(), typesWeights.end());
std::vector<std::pair<int, int>> nbIOs;
for (size_t i = 0; i < nbNodes; ++i) {
const auto nbIn = 1 + dIn(gen);
nbIOs.push_back(std::make_pair(nbIn, 1 + dOut(gen)));
}
std::vector<int> nodesSeq(nbNodes);
std::iota(nodesSeq.begin(), nodesSeq.end(), 0);
// Don't use gen or seed here, must be different each time!
std::shuffle(nodesSeq.begin(), nodesSeq.end(), std::default_random_engine(std::random_device{}()));
std::vector<NodePtr> nodes(nbNodes, nullptr);
for (auto idx : nodesSeq) {
const std::string type = types[dType(gen)];
const std::string name = type + std::to_string(idx);
nodes[idx] = GenericOperator(type.c_str(), nbIOs[idx].first, nbIOs[idx].first, nbIOs[idx].second, name.c_str());
}
for (size_t i = 0; i < nbNodes; ++i) {
for (size_t j = i + 1; j < nbNodes; ++j) {
for (size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) {
for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) {
if (dLink(gen)) {
nodes[i]->addChild(nodes[j], outId, inId);
break;
}
}
}
}
}
return std::make_pair(nodes[0], std::set<NodePtr>(nodes.begin(), nodes.end()));
}
std::string Aidge::nodePtrToType(NodePtr node) {
return node->type();
}
std::string Aidge::nodePtrToName(NodePtr node) {
return node->name();
}
std::set<std::string> Aidge::nodePtrTo(const std::set<NodePtr>& nodes,
std::string(*nodeTo)(NodePtr))
{
std::set<std::string> nodesStr;
std::transform(nodes.begin(), nodes.end(), std::inserter(nodesStr, nodesStr.begin()), nodeTo);
return nodesStr;
}
std::vector<std::pair<std::string, Aidge::IOIndex_t>> Aidge::nodePtrTo(
const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes,
std::string(*nodeTo)(NodePtr))
{
std::vector<std::pair<std::string, IOIndex_t>> nodesStr;
std::transform(nodes.begin(), nodes.end(), std::back_inserter(nodesStr),
[nodeTo](const std::pair<NodePtr, IOIndex_t>& node) {
return std::make_pair(nodeTo(node.first), node.second);
});
return nodesStr;
}
......@@ -12,9 +12,7 @@
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp"
Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph,
std::vector<NodePtr> inputNodes,
std::vector<NodePtr> outputNodes)
Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph)
: Operator(type),
mGraph(graph)
{
......@@ -26,53 +24,6 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<
for (std::size_t i = 0; i < mOutputs.size(); ++i) {
mOutputs[i] = std::make_shared<Tensor>();
}
// Fill inputsNodes and outputsNodes when there is no ambiguity
if (inputNodes.empty()) {
AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "need to specify internal nodes input mapping");
inputNodes.push_back(*mGraph->inputNodes().begin());
}
if (outputNodes.empty()) {
AIDGE_ASSERT(mGraph->outputNodes().size() == 1, "need to specify internal nodes output mapping");
outputNodes.push_back(*mGraph->outputNodes().begin());
}
AIDGE_ASSERT(mGraph->inputNodes().size() == inputNodes.size(), "wrong number of specified input nodes");
AIDGE_ASSERT(mGraph->outputNodes().size() == outputNodes.size(), "wrong number of specified output nodes");
// Identify inputs that are outside the micro-graph
for (const auto& inputNode : inputNodes) {
AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph");
const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
inputNode->inputs();
int inputIdx = 0; // input idx relative to the current node
for (const auto& in : inputNodeinputs) {
if (in.first == nullptr || !mGraph->inView(in.first)) {
// The input is not connected inside the micro-graph
// (no connection to this input or connection outside the micro-graph)
// => it is therefore an input for the meta-operator
mInputNodes.push_back(std::make_pair(inputNode, inputIdx));
}
++inputIdx;
}
}
// The outputs of the output nodes are also the outputs of the meta-operator
for (const auto& outputNode : outputNodes) {
AIDGE_ASSERT(mGraph->inView(outputNode), "output node must be in the graph");
const std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> outputNodeoutputs =
outputNode->outputs();
for (size_t outputIdx = 0; outputIdx < outputNodeoutputs.size(); ++outputIdx) {
mOutputNodes.push_back(std::make_pair(outputNode, outputIdx));
}
}
AIDGE_INTERNAL_ASSERT(mInputNodes.size() == mGraph->inputs().size());
AIDGE_INTERNAL_ASSERT(mOutputNodes.size() == mGraph->outputs().size());
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
......@@ -80,7 +31,7 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI
return mImpl->getNbRequiredData(inputIdx);
}
else {
const auto& inputOp = mInputNodes[inputIdx];
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
return inputOp.first->getOperator()->getNbRequiredData(inputOp.second);
}
}
......@@ -90,7 +41,7 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co
return mImpl->getNbConsumedData(inputIdx);
}
else {
const auto& inputOp = mInputNodes[inputIdx];
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
return inputOp.first->getOperator()->getNbConsumedData(inputOp.second);
}
}
......@@ -100,7 +51,7 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c
return mImpl->getNbProducedData(outputIdx);
}
else {
const auto& outputOp = mOutputNodes[outputIdx];
const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx];
return outputOp.first->getOperator()->getNbProducedData(outputOp.second);
}
}
......
......@@ -16,6 +16,7 @@
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/graph/Testing.hpp"
using namespace Aidge;
......@@ -113,6 +114,8 @@ TEST_CASE("GraphGeneration from Connector", "[GraphView]") {
x = (*node10)({a, x});
std::shared_ptr<GraphView> gv = generateGraph({x});
gv->save("GraphGeneration");
REQUIRE(nodePtrTo(gv->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({}));
REQUIRE(nodePtrTo(gv->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_matmul1", 0}}));
}
TEST_CASE("Connector connection GraphView", "[Connector]") {
......@@ -131,6 +134,9 @@ TEST_CASE("Connector connection GraphView", "[Connector]") {
GenericOperator("g_conv3", 1, 1,1),
GenericOperator("g_matmul1", 2,2,1)
});
REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv1", 0}}));
REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_matmul1", 0}}));
x = (*prod)({});
x = (*g)({x});
std::shared_ptr<GraphView> g2 = generateGraph({x});
......@@ -151,9 +157,13 @@ TEST_CASE("Connector connection GraphView", "[Connector]") {
GenericOperator("g_concat", 3,3,1),
GenericOperator("g_conv3", 1, 1,1)
});
REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"ElemWise", 0}, {"ElemWise", 1}, {"ElemWise", 2}}));
REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv3", 0}}));
x = (*g)({x, y, z});
std::shared_ptr<GraphView> gv = generateGraph({x});
REQUIRE(nodePtrTo(gv->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({}));
REQUIRE(nodePtrTo(gv->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv3", 0}}));
gv->save("MultiInputSequentialConnector");
REQUIRE(gv->inputNodes().size() == 0U);
}
......@@ -169,6 +179,8 @@ TEST_CASE("Connector Mini-graph", "[Connector]") {
}
y = (*GenericOperator("ElemWise",2,2,1))({y, x});
std::shared_ptr<GraphView> g = generateGraph({y});
REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({}));
REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"ElemWise", 0}}));
g->save("TestGraph");
}
......
......@@ -14,79 +14,19 @@
#include <memory>
#include <set>
#include <string>
#include <random>
#include <algorithm>
#include <utility>
#include <catch2/catch_test_macros.hpp>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
using namespace Aidge;
std::pair<NodePtr, std::set<NodePtr>> genRandomDAG(std::mt19937::result_type seed, size_t nbNodes, float density = 0.5, size_t maxIn = 5, float avgIn = 1.5, size_t maxOut = 2, float avgOut = 1.1) {
std::mt19937 gen(seed);
std::binomial_distribution<> dIn(maxIn - 1, avgIn/maxIn);
std::binomial_distribution<> dOut(maxOut - 1, avgOut/maxOut);
std::binomial_distribution<> dLink(1, density);
std::vector<std::pair<int, int>> nbIOs;
for (size_t i = 0; i < nbNodes; ++i) {
const auto nbIn = 1 + dIn(gen);
nbIOs.push_back(std::make_pair(nbIn, 1 + dOut(gen)));
}
std::vector<int> nodesSeq(nbNodes);
std::iota(nodesSeq.begin(), nodesSeq.end(), 0);
// Don't use gen or seed here, must be different each time!
std::shuffle(nodesSeq.begin(), nodesSeq.end(), std::default_random_engine(std::random_device{}()));
std::vector<NodePtr> nodes(nbNodes, nullptr);
for (auto idx : nodesSeq) {
const std::string type = "Fictive";
const std::string name = type + std::to_string(idx);
nodes[idx] = GenericOperator(type.c_str(), nbIOs[idx].first, nbIOs[idx].first, nbIOs[idx].second, name.c_str());
}
for (size_t i = 0; i < nbNodes; ++i) {
for (size_t j = i + 1; j < nbNodes; ++j) {
for (size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) {
for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) {
if (dLink(gen)) {
nodes[i]->addChild(nodes[j], outId, inId);
break;
}
}
}
}
}
return std::make_pair(nodes[0], std::set<NodePtr>(nodes.begin(), nodes.end()));
}
std::set<std::string> nodePtrToName(const std::set<NodePtr>& nodes) {
std::set<std::string> nodesName;
std::transform(nodes.begin(), nodes.end(), std::inserter(nodesName, nodesName.begin()),
[](const NodePtr& node) {
return node->name();
});
return nodesName;
}
std::vector<std::pair<std::string, IOIndex_t>> nodePtrToName(const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes) {
std::vector<std::pair<std::string, IOIndex_t>> nodesName;
std::transform(nodes.begin(), nodes.end(), std::back_inserter(nodesName),
[](const std::pair<NodePtr, IOIndex_t>& node) {
return std::make_pair(node.first->name(), node.second);
});
return nodesName;
}
TEST_CASE("genRandomDAG") {
const size_t nbTests = 100;
size_t nbUnicity = 0;
......@@ -95,10 +35,11 @@ TEST_CASE("genRandomDAG") {
std::random_device rd;
const std::mt19937::result_type seed(rd());
RandomDAG randDAG;
const auto g1 = std::make_shared<GraphView>("g1");
const bool unicity1 = g1->add(genRandomDAG(seed, 10, 0.5));
const bool unicity1 = g1->add(randDAG.gen(seed, 10));
const auto g2 = std::make_shared<GraphView>("g2");
const bool unicity2 = g2->add(genRandomDAG(seed, 10, 0.5));
const bool unicity2 = g2->add(randDAG.gen(seed, 10));
g1->save("./genRandomDAG1");
g2->save("./genRandomDAG2");
......@@ -106,9 +47,9 @@ TEST_CASE("genRandomDAG") {
REQUIRE(unicity1 == unicity2);
if (unicity1) {
REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes()));
REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs()));
REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs()));
REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));
++nbUnicity;
}
}
......@@ -116,6 +57,68 @@ TEST_CASE("genRandomDAG") {
printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests);
}
TEST_CASE("clone") {
const size_t nbTests = 100;
for (int test = 0; test < nbTests; ++test) {
std::random_device rd;
const std::mt19937::result_type seed(rd());
RandomDAG randDAG;
const auto g1 = std::make_shared<GraphView>("g1");
g1->add(randDAG.gen(seed, 10));
const auto g2 = g1->clone();
REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));
}
}
NodePtr nodeDel(NodePtr node) {
if (node->type() == "DelFictive") {
return nullptr;
}
return node->clone();
}
TEST_CASE("clone_with_delete") {
const size_t nbTests = 100;
size_t nbClonedWithDelete = 0;
for (int test = 0; test < nbTests; ++test) {
std::random_device rd;
const std::mt19937::result_type seed(rd());
RandomDAG randDAG;
randDAG.types = {"Fictive", "DelFictive"};
randDAG.typesWeights = {0.9, 0.1};
const auto g1 = std::make_shared<GraphView>("g1");
g1->add(randDAG.gen(seed, 10));
g1->save("./clone_with_delete1");
try {
const auto g2 = g1->cloneCallback(&nodeDel);
if (g2->getNodes().size() < g1->getNodes().size()) {
g2->save("./clone_with_delete2");
// These tests are not necessarily true if the deleted node is an input/output node!
//REQUIRE(g1->getOrderedInputs().size() == g2->getOrderedInputs().size());
//REQUIRE(g1->getOrderedOutputs().size() == g2->getOrderedOutputs().size());
++nbClonedWithDelete;
}
}
catch (const std::runtime_error& error) {
// pass
}
}
printf("nbClonedWithDelete = %zu/%zu\n", nbClonedWithDelete, nbTests);
}
TEST_CASE("[core/graph] GraphView(Constructor)") {
std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1");
......@@ -138,6 +141,9 @@ TEST_CASE("[core/graph] GraphView(add)") {
g->add(GOp5);
std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 2, 1, "Gop6");
g->add(GOp6);
g->save("node_alone");
REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop3", 0}, {"Gop4", 0}, {"Gop5", 0}, {"Gop6", 0}, {"Gop6", 1}}));
REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop2", 0}, {"Gop5", 0}, {"Gop6", 0}}));
}
SECTION("Several Nodes") {
......@@ -148,10 +154,14 @@ TEST_CASE("[core/graph] GraphView(add)") {
GOp1parent->addChild(GOp1, 0, 0);
g->add(GOp1);
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent}));
REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({}));
REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}}));
// there should be no deplicates
g->add(GOp1);
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent}));
REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({}));
REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}}));
}
SECTION("Initializer list ofr Node") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment