diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 6b23cda0d86a77487af7d63b3e7a0dfeae57bb37..dda3d8ee459e9f089f817f7222d717bf75ede0f5 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -35,17 +35,20 @@ private: /// @brief Name of the graphview std::string mName; + /// @brief GraphView root node + NodePtr mRootNode; + /// @brief Set of nodes included in the GraphView std::set<NodePtr> mNodes; /// @brief Set of nodes included in the graphview with names std::map<std::string, NodePtr> mNodeRegistry; - /// @brief Nodes without input link - std::set<NodePtr> mInputNodes; + /// @brief GraphView inputs + std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes; - /// @brief Nodes without output link - std::set<NodePtr> mOutputNodes; + /// @brief GraphView outputs + std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes; public: GraphView(std::string name="") @@ -54,12 +57,6 @@ public: // ctor } - // GraphView(std::set<NodePtr> nodes, std::string name="") - // : mName(name) - // { - // add(nodes); - // } - bool operator==(const GraphView &gv) const { return mNodes == gv.mNodes; @@ -105,57 +102,88 @@ public: return mNodes.find(nodePtr) != mNodes.end(); } + NodePtr getRootNode() { + return mRootNode; + } + /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// public: /** @brief Get reference to the set of input Nodes. */ - inline const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; } + inline std::set<NodePtr> inputNodes() const noexcept { + std::set<NodePtr> nodes; + for (auto node : mInputNodes) { + nodes.insert(node.first); + } + return nodes; + } /** @brief Get reference to the set of output Nodes. */ - inline const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; } - + inline std::set<NodePtr> outputNodes() const noexcept { + std::set<NodePtr> nodes; + for (auto node : mOutputNodes) { + nodes.insert(node.first); + } + return nodes; + } /** @brief Assess if the given Node is an input Node of the GraphView object. */ inline bool isInputNode(NodePtr nodePtr) const { - return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false; + const auto nodes = inputNodes(); + return (nodes.find(nodePtr) != nodes.end()) ? true : false; } /** @brief Assess if the given Node is an output Node of the GraphView object. */ inline bool isOutputNode(NodePtr nodePtr) const { - return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false; + const auto nodes = outputNodes(); + return (nodes.find(nodePtr) != nodes.end()) ? true : false; } + void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs); + void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs); + + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; }; + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; }; + /** - * @brief List outside dataInput connections of the GraphView object's inputNodes. + * @brief List outside data input connections of the GraphView. + * Data inputs exclude inputs expecting parameters (weights or bias). + * The vector size is garanteed to match the number of outside data inputs of the GraphView. If there is + * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; /** - * @brief List dataInput connections of the GraphView object's inputNodes. + * @brief List all dataInput connections (within and outside) of the specified GraphView node named "name". + * Data inputs exclude inputs expecting parameters (weights or bias). * @param name Name of the Node. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } /** - * @brief List outside input connections of the GraphView object's inputNodes. + * @brief List outside input connections of the GraphView. The vector + * size is garanteed to match the number of outside inputs of the GraphView. If there is + * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; /** - * @brief List input connections of the specified GraphView object's inputNode. + * @brief List all input connections (within and outside) of the specified GraphView node named "name". * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const; /** - * @brief List output connections of the GraphView object's outputNodes. + * @brief List outside output connections of the GraphView. The vector + * size is garanteed to match the number of outputs of the GraphView. If there is + * no connection to a given output, the corresponding sub-vector will be empty. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; /** - * @brief Specific i-th output connection of the GraphView object. + * @brief List all output connections (within and outside) of the specified GraphView node named "name". * @param nodeName Name of the Node of which to show the output. * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> */ @@ -252,20 +280,34 @@ public: * in the GraphView automatically. Default: true. */ void add(NodePtr otherNode, bool includeLearnableParam = true); + + /** + * @brief Include a set of Nodes to the current GraphView object. + * @param otherNodes + * @param includeLearnableParam + * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). + */ + bool add(std::set<NodePtr> otherNodes, + bool includeLearnableParam = true); + /** * @brief Include a set of Nodes to the current GraphView object. + * The first element of the otherNodes pair is the start node and + * the second is the remaining nodes to add. * @param otherNodes * @param includeLearnableParam + * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). */ - void add(std::set<NodePtr> otherNodes, + bool add(std::pair<NodePtr, std::set<NodePtr>> otherNodes, bool includeLearnableParam = true); /** * @brief Include every Node inside another GraphView to the current * GraphView. * @param other_graph GraphView containing the Nodes to include. + * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). */ - void add(std::shared_ptr<GraphView> otherGraph); + bool add(std::shared_ptr<GraphView> otherGraph); /** * @brief Include a Node in the current GraphView and link it to another @@ -350,26 +392,27 @@ public: IOIndex_t newParentInputTensorIdx, IOIndex_t newParentOutputTensorIdx); - /** * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible. * Both sets should include all the necessary Producers. - * @details Replaced Nodes are removed from any GraphView pointing at them all. - * The oldNodes set should have only one input/output - * Tensor for automatic connections of newNodes set. - * @param oldNodes actual set of shared_ptr<Node> to replace. - * @param newNodes new set of shared_ptr<Node>. - * @return true - * @return false + * @details There are 3 cases of replacement: + * Case 1: same number of input/output connections for oldNodes and newNodes sets. + * - input/output connections are replacated according to their IDs. + * Case 2: different number of input/output connections for oldNodes and newNodes sets. + * - only a single parent/child node for the newNodes set, every input/output is + * connected to it. + * - several parents/children nodes for newNodes set => impossible to know, return false + * Case 3: newNodes set is empty + * - same number of input/output connections in oldNodes, parents and children are linked according + * to these connections IDs + * - different number of input/output connections in oldNodes => return false + * @param oldNodes + * @param newNodes + * @return true replacement has been performed + * @return false no replacement has been performed */ static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes); - void updateInputNodes(); - /** - * @brief Process from zero the set of output Nodes. - */ - void updateOutputNodes(); - /** * @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones. * @return std::shared_ptr<GraphView> @@ -403,6 +446,7 @@ public: /** * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object. + * Data inputs exclude inputs expecting parameters (weights or bias). * @return IOIndex_t */ IOIndex_t getNbFreeDataInputs() const; @@ -413,33 +457,34 @@ private: /////////////////////////////////////////////////////// /** - * @brief Get the sum of the number of dataInput Nodes for all inputNodes of the GraphView object. + * @brief Get the number of dataInput that are outside the GraphView. + * Data inputs exclude inputs expecting parameters (weights or bias). + * This number matches the size of the vector returned by GraphView::dataInputs(). * @return IOIndex_t */ IOIndex_t getNbDataInputs() const; /** - * @brief Update the set of inputNodes with a new Node, checking if it can be - * added and removing any Node not part of mInputNode anymore. + * @brief Automatically update GraphView inputs/outputs with a new Node, checking if + * it this Node becomes an input/output for the graph and if previous inputs are still + * inputs/outputs after adding this node. * @param nodePtr */ - void updateInputNodes(NodePtr node); + void updateInputsOutputsNew(NodePtr newNode); /** - * @brief Update the set of outputNodes with a new Node, checking if it can be - * added and removing any Node not part of mOutputNode anymore. + * @brief Automatically update GraphView inputs/outputs with a Node removed, checking if + * it this Node was an input/output for the graph and if this node childs become new inputs/outputs + * for the graph. * @param nodePtr */ - void updateOutputNodes(NodePtr node); + void updateInputsOutputsDelete(NodePtr deletedNode); /////////////////////////////////////////////////////// // TOPOLOGY /////////////////////////////////////////////////////// void _forwardDims(std::set<NodePtr> listNodes); - - void removeInputNode(const std::string nodeName); - void removeOutputNode(const std::string nodeName); }; } // namespace Aidge diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 118d925e1e5b7c4fcd0c353236998ff831f7e42d..5ae4eb5d893244fa842e6bb0435c0a8ab3bc0ac5 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -140,7 +140,8 @@ public: /** * @brief List of pair <Parent, ID of the data intput>. When an input is not - * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * Data inputs exclude inputs expecting parameters (weights or bias). * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; @@ -167,6 +168,7 @@ public: /** * @brief Get the lowest index in the InputData Parent list equal to the * nullptr. + * Data inputs exclude inputs expecting parameters (weights or bias). * @return std::size_t */ inline IOIndex_t getFirstFreeDataInput() const { @@ -180,7 +182,9 @@ public: IOIndex_t getNbFreeDataInputs() const; /** - * @brief List input ids of children linked to outputs of the node + * @brief List input ids of children linked to outputs of the node. The vector + * size is garanteed to match the number of outputs of the node. If there is + * no connection to a given output, the corresponding sub-vector will be empty. * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * IOIndex_t>>> */ @@ -203,7 +207,8 @@ public: inline IOIndex_t nbInputs() const noexcept { return getOperator()->nbInputs(); } /** - * @brief Number of input specifically for data + * @brief Number of input specifically for data. + * Data inputs exclude inputs expecting parameters (weights or bias). * @details [data, data, weight, bias] => 2 * @return IOIndex_t */ diff --git a/include/aidge/graph/Testing.hpp b/include/aidge/graph/Testing.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ecacdf66298cb83c919ad447c82463206836a3e9 --- /dev/null +++ b/include/aidge/graph/Testing.hpp @@ -0,0 +1,67 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_GRAPH_TESTING_H_ +#define AIDGE_CORE_GRAPH_TESTING_H_ + +#include <cstddef> +#include <vector> +#include <set> +#include <random> // std::mt19937::result_type +#include <utility> // std::pair + +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +/** + * Random (directed) graph generator +*/ +struct RandomGraph { + /// @brief If true, the generated graph is a DAG (no cycle) + bool acyclic = false; + /// @brief Connection density (between 0 and 1) + float density = 0.5; + /// @brief Max number of inputs per node (regardless if they are connected or not) + std::size_t maxIn = 5; + /// @brief Average number of inputs per node (regardless if they are connected or not) + float avgIn = 1.5; + /// @brief Max number of outputs per node (regardless if they are connected or not) + std::size_t maxOut = 2; + /// @brief Average number of outputs per node (regardless if they are connected or not) + float avgOut = 1.1; + /// @brief List of node types that should be generated in the graph (as GenericOperator) + std::vector<std::string> types = {"Fictive"}; + /// @brief Weights of each node type, used to compute the probability of generating this type + std::vector<float> typesWeights = {1.0}; + /// @brief Type of node that should be omitted from the generated topology + std::string omitType; + + /** + * Generate a DAG according to the parameters of the class. + * @param seed Random seed. For an identical seed, an identical topology is + * generated, but with a random node ordering in the return set of nodes. + * @param nbNodes Number of nodes to generate. + */ + std::pair<NodePtr, std::set<NodePtr>> gen(std::mt19937::result_type seed, std::size_t nbNodes) const; +}; + +std::string nodePtrToType(NodePtr node); +std::string nodePtrToName(NodePtr node); +std::set<std::string> nodePtrTo(const std::set<NodePtr>& nodes, + std::string(*nodeTo)(NodePtr) = nodePtrToType); +std::vector<std::pair<std::string, IOIndex_t>> nodePtrTo( + const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes, + std::string(*nodeTo)(NodePtr) = nodePtrToType); + +} // namespace Aidge + +#endif /* AIDGE_CORE_GRAPH_TESTING_H_ */ diff --git a/include/aidge/graphRegex/GraphStrInterpreter.hpp b/include/aidge/graphRegex/GraphStrInterpreter.hpp index 98dca0e9f84de0be2614aed0e47c9d86ae552674..38e89b3733e1a07062661fa520485f92fbd7f026 100644 --- a/include/aidge/graphRegex/GraphStrInterpreter.hpp +++ b/include/aidge/graphRegex/GraphStrInterpreter.hpp @@ -1,7 +1,6 @@ #ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ #define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ -#include <iostream> #include <sstream> #include <memory> #include <algorithm> diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index f5521a1d12728a7957cb67c09861ee673e21cbae..28f89cf09f41ff6225c8c9e7248d106f8a0c1428 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -30,7 +30,7 @@ namespace Aidge { class Add_Op : public OperatorTensor, public Registrable<Add_Op, std::string, std::unique_ptr<OperatorImpl>(const Add_Op&)> { public: - static constexpr const char* Type = "Add"; + static const std::string Type; Add_Op(const IOIndex_t nbIn) : OperatorTensor(Type, nbIn, 0, 1) @@ -86,10 +86,10 @@ public: } } - static const std::vector<std::string> getInputsName(){ + static const std::vector<std::string> getInputsName() { return {"data_input_0", "data_input_n"}; } - static const std::vector<std::string> getOutputsName(){ + static const std::vector<std::string> getOutputsName() { return {"data_output"}; } }; diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index 5fb1d5b16c55f7f5b6cea4db02d3aa955831e08b..469d8485afe39692847ad88726ebca5926708c84 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -36,7 +36,7 @@ class AvgPooling_Op : public OperatorTensor, std::array<DimSize_t, DIM>> { public: - static constexpr const char *Type = "AvgPooling"; + static const std::string Type; AvgPooling_Op() = delete; @@ -150,6 +150,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string AvgPooling_Op<DIM>::Type = "AvgPooling"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index be850d377e5a1781b2cb04b5040c257ecc30cd92..31dbdd4df2340953e408d0ff5744cb4ff8ce3e9d 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -33,7 +33,7 @@ class BatchNorm_Op : public OperatorTensor, public Registrable<BatchNorm_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const BatchNorm_Op<DIM> &)>, public StaticAttributes<BatchNormAttr, float, float> { public: - static constexpr const char *Type = "BatchNorm"; + static const std::string Type; BatchNorm_Op() = delete; @@ -82,9 +82,9 @@ public: associated &= !(getInput(i)->empty()); } if (associated) { - const DimSize_t nbChannels = getInput(0)->dims()[1]; + const DimSize_t nbFeatures = getInput(0)->dims()[1]; for (std::size_t i = nbData(); i < nbInputs(); ++i) { - if(getInput(i)->size() != nbChannels) { + if(getInput(i)->size() != nbFeatures) { // /!\ Input size should be handled BEFORE calling this function // This should raise an error getInput(i)->resize(std::array<DimSize_t, 1>({getInput(0)->dims()[1]})); @@ -113,16 +113,20 @@ public: } }; +template <DimIdx_t DIM> +const std::string BatchNorm_Op<DIM>::Type = "BatchNorm"; + template <DimSize_t DIM> -inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F, +inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures, + const float epsilon = 1.0e-5F, const float momentum = 0.1F, const std::string& name = "") { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported"); auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name); - addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale"); - addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift"); - addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean"); - addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance"); + addProducer(batchNorm, 1, std::array<DimSize_t,1>({nbFeatures}), "scale"); + addProducer(batchNorm, 2, std::array<DimSize_t,1>({nbFeatures}), "shift"); + addProducer(batchNorm, 3, std::array<DimSize_t,1>({nbFeatures}), "batch_mean"); + addProducer(batchNorm, 4, std::array<DimSize_t,1>({nbFeatures}), "batch_variance"); return batchNorm; } } // namespace Aidge diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 78e21f85250c361053857e27c582e1487aeec64e..080f763cb176c463f3e03a672de4a13cf05a497b 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -32,7 +32,7 @@ class Concat_Op : public OperatorTensor, public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>, public StaticAttributes<ConcatAttr, DimSize_t> { public: - static constexpr const char* Type = "Concat"; + static const std::string Type; using Attributes_ = StaticAttributes<ConcatAttr, DimSize_t>; template <ConcatAttr e> diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index b62d393bc37859f24c4f54f8ce1ba4458bf11ab4..194ac313dd7f9b22c55fdbe7e0e30d37d816bcb8 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -36,7 +36,7 @@ class Conv_Op : public OperatorTensor, DimSize_t, std::array<DimSize_t, DIM>> { public: - static constexpr const char *Type = "Conv"; + static const std::string Type; Conv_Op() = delete; @@ -188,6 +188,9 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel } }; +template <DimIdx_t DIM> +const std::string Conv_Op<DIM>::Type = "Conv"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> Conv(DimSize_t inChannels, DimSize_t outChannels, diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index c95315f6d63e817354fc82dded4e3cfb4ed1b704..6f1f3f7ffbaf8dd750f374f2b391ccc90fad8254 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -37,7 +37,7 @@ class ConvDepthWise_Op : public OperatorTensor, DimSize_t, std::array<DimSize_t, DIM>> { public: - static constexpr const char *Type = "ConvDepthWise"; + static const std::string Type; ConvDepthWise_Op() = delete; @@ -182,6 +182,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string ConvDepthWise_Op<DIM>::Type = "ConvDepthWise"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> ConvDepthWise(const DimSize_t nbChannels, const std::array<DimSize_t, DIM> &kernelDims, diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp index fcdb03a6be36bc9e1be7d69d01005f92b535d00c..84de3308efcc07fa14bb3663ee7b66fde3f22123 100644 --- a/include/aidge/operator/Div.hpp +++ b/include/aidge/operator/Div.hpp @@ -29,7 +29,7 @@ class Div_Op : public OperatorTensor, public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> { public: - static constexpr const char* Type = "Div"; + static const std::string Type; Div_Op() : OperatorTensor(Type, 2, 0, 1) {} diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 8dea38335dd052f2dbf7d0aa7fc4f7fe84741a06..ecd2b97ea8a524736e6dc3a44819df29bbf4e3d8 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -35,7 +35,7 @@ class FC_Op : public OperatorTensor, std::unique_ptr<OperatorImpl>(const FC_Op &)>, public StaticAttributes<FCAttr, DimSize_t, bool> { public: - static constexpr const char* Type = "FC"; + static const std::string Type; FC_Op() = delete; diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 505c5344990453c8f4ab84fa3893e75b216d7a54..0d0008c6307cd98fd5bba3d3480e7d225c70aa01 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -36,7 +36,7 @@ private: ComputeDimsFunc mComputeOutputDims; public: - GenericOperator_Op(const char *type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut) + GenericOperator_Op(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut) : OperatorTensor(type, nbData, nbParam, nbOut) {} @@ -125,7 +125,7 @@ public: * @param name (optional) name of the Operator. * @return std::shared_ptr<Node> Node associated with the Generic Operator. */ -inline std::shared_ptr<Node> GenericOperator(const char *type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut, +inline std::shared_ptr<Node> GenericOperator(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut, const std::string& name = "") { return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbData, nbParam, nbOut), name); } diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp index c5cd9bb62e0097c9a0e646caaf14cddd73bf512d..c0be78646f7dccf732c49e4aea45bf139b49ad9e 100644 --- a/include/aidge/operator/Identity.hpp +++ b/include/aidge/operator/Identity.hpp @@ -37,7 +37,7 @@ namespace Aidge { class Identity_Op : public OperatorTensor, public Registrable<Identity_Op, std::string, std::unique_ptr<OperatorImpl>(const Identity_Op&)> { public: - static constexpr const char* Type = "Identity"; + static const std::string Type; Identity_Op() : OperatorTensor(Type, 1, 0, 0) @@ -105,9 +105,11 @@ public: } void setBackend(const std::string& name) override final { // setBackend do nothing, Identity node has no backend it just pass the same Tensor + (void) name; } void setDataType(const DataType& dataType) const override final { // setDatatype do nothing, Identity node has no backend it just pass the same Tensor + (void) dataType; } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index 2474e2e5af4139b77cace03b27b603fb66b7699a..f9bbef46283ba8b9b480c1eba0a11c6caf954897 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -33,7 +33,7 @@ class LeakyReLU_Op : public OperatorTensor, public Registrable<LeakyReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const LeakyReLU_Op&)>, public StaticAttributes<LeakyReLUAttr, float> { public: - static constexpr const char* Type = "LeakyReLU"; + static const std::string Type; LeakyReLU_Op() = delete; diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index 90930dd22a36f84a7479e245eb09d9c28dfd031d..1014488a77a6ffe5b6048cfc23da669416710c92 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -35,7 +35,7 @@ class MatMul_Op : public OperatorTensor, std::unique_ptr<OperatorImpl>(const MatMul_Op &)>, public StaticAttributes<MatMulAttr, DimSize_t> { public: - static constexpr const char* Type = "MatMul"; + static const std::string Type; MatMul_Op() = delete; diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index c46ddb3797e2303ee27814c96ef060156bdc9108..0a292449385807a4deb8b7d0458720c9d9a8e99f 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -36,7 +36,7 @@ class MaxPooling_Op : public OperatorTensor, std::array<DimSize_t, DIM>, bool> { public: - static constexpr const char *Type = "MaxPooling"; + static const std::string Type; MaxPooling_Op() = delete; @@ -120,6 +120,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string MaxPooling_Op<DIM>::Type = "MaxPooling"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> MaxPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 4c8feb46c3e3db33bd380302e3e0683f1b8734f5..652cecd44537963c3ee4743729d2e98c569e7de6 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -25,16 +25,9 @@ public: // Micro-graph handling: std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph std::shared_ptr<SequentialScheduler> mScheduler; - // Need to store an ordored list of input/output operators for the micro-graph, - // because input/output nodes in a GraphView are unordered. - // TODO: refactor GraphView to handle ordered input/output? - std::vector<std::pair<std::shared_ptr<OperatorTensor>, IOIndex_t>> mInputOps; - std::vector<std::pair<std::shared_ptr<OperatorTensor>, IOIndex_t>> mOutputOps; public: - MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, - std::vector<NodePtr> inputNodes = std::vector<NodePtr>(), - std::vector<NodePtr> outputNodes = std::vector<NodePtr>()); + MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph); /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -47,7 +40,7 @@ public: /** * @brief Clone the operator using its copy-constructor. - * @see Operator::MatMul_Op + * @see Operator::MetaOperator_Op */ std::shared_ptr<Operator> clone() const override { return std::make_shared<MetaOperator_Op>(*this); @@ -64,8 +57,8 @@ public: void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final { assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); - const auto& inputOp = mInputOps[inputIdx]; - inputOp.first->associateInput(inputOp.second, data); + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + inputOp.first->getOperator()->associateInput(inputOp.second, data); // Associate inputs for custom implementation mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); @@ -110,11 +103,9 @@ public: inline std::shared_ptr<Node> MetaOperator(const char *type, const std::shared_ptr<GraphView>& graph, - const std::string& name = "", - std::vector<NodePtr> inputNodes = std::vector<NodePtr>(), - std::vector<NodePtr> outputNodes = std::vector<NodePtr>()) + const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph, inputNodes, outputNodes), name); + return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph), name); } } // namespace Aidge diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp index 9ec6cdb928cdfa433b04ea23c69344133a3c7064..615b8960403270efa1fe97235dbfeeb129338d5b 100644 --- a/include/aidge/operator/MetaOperatorDefs.hpp +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -32,10 +32,8 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels, // Construct micro-graph auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0); auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); - // Need to specify the ordered list of input operators - const std::vector<NodePtr> orderedInputNodes = {pad, conv}; - auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name, orderedInputNodes); + auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name); addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); addProducer(metaOp, 2, {out_channels}, "b"); return metaOp; @@ -66,10 +64,8 @@ inline std::shared_ptr<Node> PaddedConvDepthWise(const DimSize_t nb_channels, // Construct micro-graph auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0); auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(nb_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); - // Need to specify the ordered list of input operators - const std::vector<NodePtr> orderedInputNodes = {pad, conv}; - auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name, orderedInputNodes); + auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name); addProducer(metaOp, 1, std::array<DimSize_t,0>({}), "w"); addProducer(metaOp, 2, std::array<DimSize_t,0>({}), "b"); return metaOp; diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp index 337fe6e65cc040e67ee033516731a7ba8de86d2d..47da898829f9581d4907ddad97bf847c6746a536 100644 --- a/include/aidge/operator/Mul.hpp +++ b/include/aidge/operator/Mul.hpp @@ -31,7 +31,7 @@ namespace Aidge { class Mul_Op : public OperatorTensor, public Registrable<Mul_Op, std::string, std::unique_ptr<OperatorImpl>(const Mul_Op&)> { public: - static constexpr const char* Type = "Mul"; + static const std::string Type; Mul_Op() : OperatorTensor(Type, 2, 0, 1) {} diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 1f4cdd23f9a765924305ebeb43e3e6ee1ad73496..5cd35be72aa4ecf880818aaf10dddbb11735e53e 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -44,7 +44,7 @@ private: public: Operator() = delete; - Operator(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data) + Operator(const std::string& type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data) : mType(type), mOperatorType(operatorType), mNbData(nbData), diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index 126e5d467d0f341a8c5b8c5d16d188ebe92135d0..b956da474311b5863690f5a5e40329e443f1345a 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -40,7 +40,7 @@ protected: public: OperatorTensor() = delete; - OperatorTensor(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam, + OperatorTensor(const std::string& type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut) : Operator(type, nbData, nbParam, nbOut, OperatorType::Tensor), mInputs(std::vector<std::shared_ptr<Tensor>>(nbData + nbParam, nullptr)), diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp index 279b8b3d2c173d18c65c17e50385954a88fde77e..38829bab613981565bc20b88e299fa1e197f1c08 100644 --- a/include/aidge/operator/Pad.hpp +++ b/include/aidge/operator/Pad.hpp @@ -37,7 +37,7 @@ class Pad_Op : public OperatorTensor, PadBorderType, double> { public: - static constexpr const char *Type = "Pad"; + static const std::string Type; Pad_Op() = delete; @@ -113,6 +113,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string Pad_Op<DIM>::Type = "Pad"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, 2*DIM> &beginEndTuples, const std::string& name = "", diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index a5cd3a9b047f9a32665cc2de1ead4f2221fed4aa..ee22bd9aec908a66d2ca6cbac0b9a8dcd5dec409 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -29,7 +29,7 @@ namespace Aidge { class Pow_Op : public OperatorTensor, public Registrable<Pow_Op, std::string, std::unique_ptr<OperatorImpl>(const Pow_Op&)> { public: - static constexpr const char* Type = "Pow"; + static const std::string Type; Pow_Op() : OperatorTensor(Type, 2, 0, 1) {} diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index a3f6e085ce3849c1b057f0fdb043093b338b48a1..1440a939f13da54dcae2cebedb0d4d807d8244d7 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -29,7 +29,7 @@ class Producer_Op public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( const Producer_Op &)> { public: - static constexpr const char* Type = "Producer"; + static const std::string Type; template <std::size_t DIM> Producer_Op(const std::array<DimSize_t, DIM>& dims) diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 15dec9be8516f71f5f4dfd0aec6a2985671da53d..e72db011795639c5231e6afe5fbd24bbbc71b8c5 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -28,7 +28,7 @@ namespace Aidge { class ReLU_Op : public OperatorTensor, public Registrable<ReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const ReLU_Op&)> { public: - static constexpr const char* Type = "ReLU"; + static const std::string Type; ReLU_Op() : OperatorTensor(Type, 1, 0, 1) {} diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 98e082ac27f7cdf90d5d0464d811f116ae9f59ae..b64c9f9b9513a97295ca5aa75db3f6e2979b2eef 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -32,7 +32,7 @@ class Scaling_Op : public OperatorTensor, public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, public StaticAttributes<ScalingAttr, float, size_t, bool> { public: - static constexpr const char* Type = "Scaling"; + static const std::string Type; Scaling_Op() = delete; diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index b92c1818d49b53d4a2eda9a8d2704a06ca2980ca..5968fdeb40ba864802fbdc5a164f4e8837ee788b 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -31,7 +31,7 @@ class Slice_Op public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, public StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>> { public: - static constexpr const char *Type = "Slice"; + static const std::string Type; Slice_Op() = delete; diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index d5c91945e83469dc9c6fef2b5adef026790b568d..bcf9a5a66147b821a062cd6b93087cb1c45bca00 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -29,7 +29,7 @@ namespace Aidge { class Softmax_Op : public OperatorTensor, public Registrable<Softmax_Op, std::string, std::unique_ptr<OperatorImpl>(const Softmax_Op&)> { public: - static constexpr const char* Type = "Softmax"; + static const std::string Type; Softmax_Op() : OperatorTensor(Type, 1, 0, 1) {} diff --git a/include/aidge/operator/Sqrt.hpp b/include/aidge/operator/Sqrt.hpp index 1fe609fc2913afcda735ba2859126188aad4de5f..5eb4d89308d9684811876588917ab53efd1bd069 100644 --- a/include/aidge/operator/Sqrt.hpp +++ b/include/aidge/operator/Sqrt.hpp @@ -34,7 +34,7 @@ public: const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: - static constexpr const char* Type = "Sqrt"; + static const std::string Type; Sqrt_Op() : OperatorTensor(Type, 1, 0, 1) {} diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp index d141ad42015838e89e6d59c22bcefe56e795170c..fad65e00e973c6b0352de2bdf5e43a79b4f3d4e4 100644 --- a/include/aidge/operator/Sub.hpp +++ b/include/aidge/operator/Sub.hpp @@ -34,7 +34,7 @@ public: const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: - static constexpr const char* Type = "Sub"; + static const std::string Type; Sub_Op() : OperatorTensor(Type, 2, 0, 1) {} diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp index a17ead8f8f5fa5106c375050ef5b82e6f149535a..5ad08a6582aa886604d0068f75cab9fe1631b05e 100644 --- a/include/aidge/recipies/Recipies.hpp +++ b/include/aidge/recipies/Recipies.hpp @@ -38,7 +38,7 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * - * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + * @param graphView Graph view to use graph matching on, in order to apply transformations. */ void fuseMulAdd(std::shared_ptr<GraphView> graphView); @@ -58,7 +58,7 @@ void removeFlatten(std::shared_ptr<MatchSolution> solution); /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * - * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + * @param graphView Graph view to use graph matching on, in order to apply transformations. */ void removeFlatten(std::shared_ptr<GraphView> graphView); @@ -80,7 +80,7 @@ void fuseBatchNorm(std::shared_ptr<MatchSolution> solution); * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ * - * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + * @param graphView Graph view to use graph matching on, in order to apply transformations. */ void fuseBatchNorm(std::shared_ptr<GraphView> graphView); diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 12a007ab911585df29c814d9a5904013449ec5bf..b5b9ed37d877ecd4e22fb975e4606069f5e36037 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -56,7 +56,7 @@ void init_GraphView(py::module& m) { :type include_learnable_parameters: bool, optional )mydelimiter") - .def("add", (void (GraphView::*)(std::shared_ptr<GraphView>)) & GraphView::add, + .def("add", (bool (GraphView::*)(std::shared_ptr<GraphView>)) & GraphView::add, py::arg("other_graph"), R"mydelimiter( Include a GraphView to the current GraphView object. diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index ff0b9e0dfcb0d1c5e5567a938b1ca74faf242bed..411a2e1b6ae78065a79b92f25c23dac13e341997 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -25,7 +25,7 @@ void declare_BatchNormOp(py::module& m) { .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); - m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); + m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nbFeatures"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); } void init_BatchNorm(py::module &m) { diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp index 71231b8218ac6af28c97ec29039301bc25b2d195..2200cd3fec1450011d6e0b5197f8b99b4dfeb4c3 100644 --- a/python_binding/operator/pybind_Conv.cpp +++ b/python_binding/operator/pybind_Conv.cpp @@ -11,7 +11,6 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> -#include <iostream> #include <string> #include <vector> #include <array> diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index d1eff7b387f9b339e6641a8049e020a7e8a4f021..f5c5145e0a86d939b96e6d2a579dfa2579f8b3a5 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -128,9 +128,7 @@ void init_MetaOperatorDefs(py::module &m) { m.def("meta_operator", &MetaOperator, py::arg("type"), py::arg("graph"), - py::arg("name") = "", - py::arg("input_nodes") = std::vector<NodePtr>(), - py::arg("output_nodes") = std::vector<NodePtr>() + py::arg("name") = "" ); } diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index ce956d115e282c43751619070dd8a10ac5c9cfae..554c535f229af0ab5b59fa6f57607c7bacd872fa 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -71,33 +71,74 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { typeCounter[currentType] = 0; ++typeCounter[currentType]; - const std::string givenName = + std::string givenName = (node_ptr->name().empty()) - ? currentType + std::to_string(typeCounter[currentType]) - : node_ptr->name(); + ? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>" + : "\"" + node_ptr->name() + "\\n<sub><em>( " + currentType + "#" + std::to_string(typeCounter[currentType]) + " )</em></sub>\""; namePtrTable[node_ptr] = (currentType + "_" + std::to_string(typeCounter[currentType])); - std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), - givenName.c_str()); + + if (node_ptr == mRootNode) { + std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(), + givenName.c_str()); + } + else { + std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), + givenName.c_str()); + } } + // Write every link - std::size_t emptyInputCounter = 0; for (const std::shared_ptr<Node> &node_ptr : mNodes) { - for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) { - if ((pa_ptr == nullptr) || !inView(pa_ptr)) { - std::fprintf(fp, "input%zu((in - %zu))-->%s\n", emptyInputCounter, - emptyInputCounter, namePtrTable[node_ptr].c_str()); - ++emptyInputCounter; - } else { - std::fprintf(fp, "%s-->%s\n", namePtrTable[pa_ptr].c_str(), - namePtrTable[node_ptr].c_str()); - } + IOIndex_t outputIdx = 0; + for (auto childs : node_ptr->getOrderedChildren()) { + for (auto child : childs) { + if (child != nullptr) { + IOIndex_t inputIdx = 0; + for (auto parent : child->inputs()) { + if (parent.first == node_ptr && parent.second == outputIdx) { + if (mNodes.find(child) != mNodes.end()) { + std::fprintf(fp, "%s-->|%u→%u|%s\n", namePtrTable[node_ptr].c_str(), + outputIdx, inputIdx, namePtrTable[child].c_str()); + } + else if (verbose) { + std::fprintf(fp, "%s-->|%u→%u|%p:::externalCls\n", namePtrTable[node_ptr].c_str(), + outputIdx, inputIdx, static_cast<void*>(child.get())); + } + break; + } + ++inputIdx; + } + } } + ++outputIdx; + } + } + + size_t inputIdx = 0; + for (auto input : mInputNodes) { + std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|→%u|%s\n", inputIdx, inputIdx, + input.second, namePtrTable[input.first].c_str()); + ++inputIdx; + } + + size_t outputIdx = 0; + for (auto output : mOutputNodes) { + std::fprintf(fp, "%s--->|%u→|output%lu((out#%lu)):::outputCls\n", + namePtrTable[output.first].c_str(), output.second, + outputIdx, outputIdx); + ++outputIdx; } + + std::fprintf(fp, "classDef inputCls fill:#afa\n"); + std::fprintf(fp, "classDef outputCls fill:#ffa\n"); + std::fprintf(fp, "classDef externalCls fill:#ccc\n"); + std::fprintf(fp, "classDef rootCls stroke:#f00\n"); + if (verbose) { - for (const auto &c : typeCounter) { + for (const auto &c : typeCounter) { std::printf("%s - %zu\n", c.first.c_str(), c.second); - } + } } std::fprintf(fp, "\n"); @@ -108,20 +149,60 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { // TENSOR MANAGEMENT /////////////////////////////////////////////////////// +void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) { + AIDGE_ASSERT(inputs.size() <= mInputNodes.size(), "too many specified number of inputs"); + + std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes); + for (auto input : inputs) { + auto it = std::find(ignoredInputs.begin(), ignoredInputs.end(), input); + AIDGE_ASSERT(it != ignoredInputs.end(), "unknown or duplicate input"); + ignoredInputs.erase(it); + } + + mInputNodes = inputs; + mInputNodes.insert(mInputNodes.end(), ignoredInputs.begin(), ignoredInputs.end()); +} + +void Aidge::GraphView::setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs) { + AIDGE_ASSERT(outputs.size() <= mOutputNodes.size(), "too many specified number of outputs"); + + std::vector<std::pair<NodePtr, IOIndex_t>> ignoredOutputs(mOutputNodes); + for (auto output : outputs) { + auto it = std::find(ignoredOutputs.begin(), ignoredOutputs.end(), output); + AIDGE_ASSERT(it != ignoredOutputs.end(), "unknown or duplicate output"); + ignoredOutputs.erase(it); + } + + mOutputNodes = outputs; + mOutputNodes.insert(mOutputNodes.end(), ignoredOutputs.begin(), ignoredOutputs.end()); +} + Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const { - return std::accumulate(mInputNodes.cbegin(), mInputNodes.cend(), 0, - [](IOIndex_t sumData, const std::shared_ptr<Node> inNode) { - return sumData + inNode->nbData(); - } - ); + IOIndex_t nbDataInput = 0; + for (const std::shared_ptr<Node> &inNode : inputNodes()) { + // We cannot simply add inNode->nbDataInputs(), as input nodes may already + // have some inputs connected within the GraphView, which would therefore not + // constitue inputs (from outside) for the GraphView! + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + inNode->dataInputs(); + + for (const auto& input : inputNodeinputs) { + if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { + ++nbDataInput; + } + } + } + return nbDataInput; } Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { - return std::accumulate(mInputNodes.cbegin(), mInputNodes.cend(), 0, - [](IOIndex_t sumData, const std::shared_ptr<Node> inNode) { - return sumData + inNode->getNbFreeDataInputs(); - } - ); + IOIndex_t nbIn = 0; + // Free inputs within the GraphView are logically also free inputs from outside + // the GraphView. + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { + nbIn += inputNode->getNbFreeDataInputs(); + } + return nbIn; } @@ -129,12 +210,12 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::dataInputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->dataInputs(); for (const auto& input : inputNodeinputs) { - if (mNodes.find(input.first) == mNodes.end()) { + if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { res.push_back(input); } } @@ -147,12 +228,12 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); for (const auto& input : inputNodeinputs) { - if (mNodes.find(input.first) == mNodes.end()) { + if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { res.push_back(input); } } @@ -250,68 +331,28 @@ void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) { } } -void Aidge::GraphView::updateOutputNodes() { - mOutputNodes.clear(); - for (const std::shared_ptr<Node>& go_it : mNodes) { - if (go_it->nbOutputs() != - go_it->nbValidOutputs()) { // an output linked to nothing - mOutputNodes.insert(go_it); - continue; - } - for (const std::shared_ptr<Node>& ch_ptr : go_it->getChildren()) { - if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph - mOutputNodes.insert(go_it); - break; - } - } - } -} - -void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) { - if (node->nbOutputs() != - node->nbValidOutputs()) { // an output linked to nothing - mOutputNodes.insert(node); - } else { // don't enter if was already added to outputNodes - for (const std::shared_ptr<Node> &ch_ptr : node->getChildren()) { - if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph - mOutputNodes.insert(node); - break; - } - } - } - // update other outputNodes - for (const std::shared_ptr<Node> &pa_ptr : - node->getParents()) { // check if any parent is in OutputNodes too - if ((pa_ptr != nullptr) && - (mOutputNodes.find(pa_ptr) != - mOutputNodes.end())) { // it's a match! Must check if the outputNode - // found is still an outputNode - bool remove = (pa_ptr->nbOutputs() == pa_ptr->nbValidOutputs()); - for (const std::shared_ptr<Node>& ch_ptr : pa_ptr->getChildren()) { - if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph - remove = false; - break; - } - } - if (remove) { - mOutputNodes.erase(pa_ptr); - } - } - } -} - std::vector< std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::GraphView::outputs() const { std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> - outputTensors; - for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { - std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> - tmpOutputs = (outputNode->outputs()); - outputTensors.insert(outputTensors.end(), tmpOutputs.begin(), - tmpOutputs.end()); + outsideOutputs; + for (const std::shared_ptr<Node>& outputNode : outputNodes()) { + const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> + outputNodeOutputs = outputNode->outputs(); + + for (const auto& outputPos : outputNodeOutputs) { + // Keep only the nodes connected at this output position that are outside the GraphView + std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>> outsideOutputPos; + for (const auto& output : outputPos) { + if (mNodes.find(output.first) == mNodes.end()) { + outsideOutputPos.push_back(output); + } + } + + outsideOutputs.push_back(outsideOutputPos); + } } - return outputTensors; + return outsideOutputs; } std::vector< @@ -326,11 +367,20 @@ void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/, } void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) { + // first node to be added to the graph is the root node by default + if (mRootNode == nullptr) { + mRootNode = node; + } + // add to the GraphView nodes node->addView(shared_from_this()); mNodes.insert(node); if (!(node->name()).empty()) mNodeRegistry.insert(std::make_pair(node->name(), node)); + + // check if the node is an input/output node + updateInputsOutputsNew(node); + // add learnable parameters to the graph if (includeLearnableParam) { for (IOIndex_t i = node->nbData(); i < node->nbInputs(); ++i) { @@ -340,33 +390,124 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara mNodes.insert(parentNode); if (!(parentNode->name()).empty()) mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode)); - // check if the Node is an input node - updateInputNodes(parentNode); + // check if the parentNode is an input/output node + updateInputsOutputsNew(parentNode); + } + } + } +} + +bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { + if (otherNodes.empty()) { + return true; + } + + bool orderUnicity = true; + + // List only the nodes that are not already present in current graph + std::set<NodePtr> nodesToAdd; + std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin())); + + // List the nodes to rank, initially all the nodes in the GraphView + std::set<NodePtr> nodesToRank(mNodes); + nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end()); + std::vector<NodePtr> rankedNodesToAdd; + + if (mRootNode == nullptr) { + std::set<NodePtr> noParentNodes; + + // If no root node is defined, check nodes without parents + for (auto node : nodesToRank) { + bool noParent = true; + for (auto parent : node->getParents()) { + if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) { + noParent = false; + break; + } + } + + if (noParent) { + noParentNodes.insert(node); } } + + // Take the first one found (this is an arbitrary choice) + mRootNode = *noParentNodes.begin(); + + if (noParentNodes.size() > 1) { + // If there is more than one, order unicity cannot be garanteed! + orderUnicity = false; + } + + rankedNodesToAdd.push_back(mRootNode); } - // check if the Node is an input node - updateInputNodes(node); - // check if the Node is an input node - updateOutputNodes(node); + + nodesToRank.erase(mRootNode); + std::vector<NodePtr> rankedNodes; + rankedNodes.push_back(mRootNode); + + for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) { + NodePtr curNode = rankedNodes[curNodeIdx]; + + for (auto childs : curNode->getOrderedChildren()) { + for (auto child : childs) { + if (nodesToRank.find(child) != nodesToRank.end()) { + rankedNodes.push_back(child); + nodesToRank.erase(child); + + if (nodesToAdd.find(child) != nodesToAdd.end()) { + rankedNodesToAdd.push_back(child); + nodesToAdd.erase(child); + } + } + } + } + + for (auto parent : curNode->getParents()) { + if (nodesToRank.find(parent) != nodesToRank.end()) { + rankedNodes.push_back(parent); + nodesToRank.erase(parent); + + if (nodesToAdd.find(parent) != nodesToAdd.end()) { + rankedNodesToAdd.push_back(parent); + nodesToAdd.erase(parent); + } + } + } + } + + if (!nodesToAdd.empty()) { + // There are remaining nodes without path to the root node + orderUnicity = false; + + while (!nodesToAdd.empty()) { + const auto it = nodesToAdd.begin(); + rankedNodesToAdd.push_back(*it); + nodesToAdd.erase(it); + } + } + + for (auto node_ptr : rankedNodesToAdd) { + add(node_ptr, includeLearnableParam); + } + + return orderUnicity; } -void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { - for (auto& nodePtr : otherNodes) { add(nodePtr, includeLearnableParam); } +bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) { + if (nodes.first != nullptr) { + mRootNode = nodes.first; + add(nodes.first, includeLearnableParam); + } + return add(nodes.second, includeLearnableParam); } -void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { - for (const std::shared_ptr<Node> &node_ptr : graph->getNodes()) { - node_ptr->addView(shared_from_this()); - mNodes.insert(node_ptr); - if (!(node_ptr->name()).empty()) - mNodeRegistry.insert(std::make_pair(node_ptr->name(), node_ptr)); - // if node_ptr is part of graph inputNodes or outputNodes - // if (graph->isInputNode(node_ptr) || graph->isOutputNode(node_ptr)) { - // Update OutputNodes/inputNodes - updateInputNodes(); - updateOutputNodes(); +bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { + if (mRootNode == nullptr) { + mRootNode = graph->getRootNode(); } + + return add(graph->getNodes(), false); } void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode, @@ -414,7 +555,7 @@ void Aidge::GraphView::addChild( std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const { // TODO: choose if we return a set or a vector std::set<std::shared_ptr<Node>> parents; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { parents.insert(inputNode->getParents().begin(), inputNode->getParents().end()); } @@ -433,7 +574,7 @@ std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::GraphView::getOrderedParents() const { std::vector<std::vector<std::shared_ptr<Node>>> parents; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { parents.push_back(inputNode->getParents()); } return parents; @@ -441,7 +582,7 @@ Aidge::GraphView::getOrderedParents() const { std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const { std::set<std::shared_ptr<Node>> children; - for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { + for (const std::shared_ptr<Node>& outputNode : outputNodes()) { children.insert((outputNode->getChildren()).begin(), (outputNode->getChildren()).end()); } @@ -485,38 +626,44 @@ Aidge::GraphView::getNode(const std::string& nodeName) const { void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) { - if (mNodes.find(nodePtr) != mNodes.end()) { - mNodes.erase(nodePtr); - nodePtr->removeView(shared_from_this()); - } - if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); } - // same for learnable params - + // remove learnable params if (includeLearnableParam) { for (IOIndex_t i = nodePtr->nbData(); i < nodePtr->nbInputs(); ++i) { auto inputI = nodePtr->input(i); - bool removeNode = true; - for (const auto& parentOutput : inputI.first->outputs()) { - for (const auto& childOfParentOutput : parentOutput) { - // only remove the learnable parameter if not related to any other Node in the GraphView - if (childOfParentOutput.first != nodePtr) { - removeNode = false; - break; + if (inputI.first != nullptr) { + bool removeNode = true; + for (const auto& parentOutput : inputI.first->outputs()) { + for (const auto& childOfParentOutput : parentOutput) { + // only remove the learnable parameter if not related to any other Node in the GraphView + if (childOfParentOutput.first != nodePtr) { + removeNode = false; + break; + } } } - } - if (removeNode) { - // assert Learnable Parameter in the GraphView scope - if (mNodes.find(inputI.first) != mNodes.end()) { - mNodes.erase(inputI.first); - inputI.first->removeView(shared_from_this()); + if (removeNode) { + // assert Learnable Parameter in the GraphView scope + if (mNodes.find(inputI.first) != mNodes.end()) { + mNodes.erase(inputI.first); + inputI.first->removeView(shared_from_this()); + } + if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } + + // check if the node was an input/output node + updateInputsOutputsDelete(inputI.first); } - if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } } } } - updateInputNodes(); - updateOutputNodes(); + + if (mNodes.find(nodePtr) != mNodes.end()) { + mNodes.erase(nodePtr); + nodePtr->removeView(shared_from_this()); + + // check if the nodePtr was an input/output node + updateInputsOutputsDelete(nodePtr); + } + if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); } } @@ -547,211 +694,354 @@ void Aidge::GraphView::insertParent(NodePtr childNode, add(newParentNode); } - bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) { - // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // How to distinguish it from data input? // TODO: Parameter Tensors could be identified with their dimensions // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. // It also avoids specifying each producer since they are automatically included + // (1) create GraphViews from both sets of Nodes auto oldG = std::make_shared<GraphView>("oldG"); oldG->add(oldNodes, false); auto newG = std::make_shared<GraphView>("newG"); newG->add(newNodes, false); - if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) { - return false; + const auto oldOI = oldG->getOrderedInputs(); + const auto oldOO = oldG->getOrderedOutputs(); + const auto newOI = newG->getOrderedInputs(); + const auto newOO = newG->getOrderedOutputs(); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOI.size()); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOO.size()); + + // keep in memory every parent + for (std::size_t i = 0; i < oldOI.size(); ++i) { + auto inputParent = oldOI[i].first -> input(oldOI[i].second); + inputParents[i]= inputParent; + // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); } - if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) || - (newG->outputNodes().size() != 1))) { - return false; + for (std::size_t i = 0; i < oldOO.size();) { + auto outputChildList = oldOO[i].first -> output(oldOO[i].second); + if (outputChildList.empty()) { + outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); + ++i; + } + else { + for (const auto& child : outputChildList) { + if (oldNodes.find(child.first) == oldNodes.cend()) { + outputChildren[i] = child; + ++i; + } + } + } } - // there is at least one inputNode in the old/new GraphView - std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin()); - std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin()); + // only keep common views to each node for the new set + // set of common GraphView for oldNodes' Nodes + std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); + for (const auto& nodePtr : oldNodes) { + const auto nodeView = nodePtr->views(); + std::set<std::shared_ptr<GraphView>> intersection; + std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), + nodeView.begin(), nodeView.end(), + std::inserter(intersection, intersection.begin())); + commonGraphViews = intersection; + } + commonGraphViews.erase(oldG); + commonGraphViews.erase(newG); - // find Node to link to new input Node - //compute number of input for firstPreviousInputNode not in oldNodes set - std::size_t nbExternalInputs = 0; - std::shared_ptr<Node> externalInput = nullptr; - IOIndex_t externalInputId = gk_IODefaultIndex; - for (const auto& input : firstPreviousInputNode->inputs()) { - if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG - nbExternalInputs++; - externalInput = input.first; - externalInputId = input.second; + if ((newNodes.size() > 0) && (oldOI.size() != newOI.size()) && (oldOO.size() != newOO.size())) { + for (const auto& nodePtr : oldNodes) { + nodePtr->removeView(oldG); } + for (const auto& nodePtr : newNodes) { + nodePtr->removeView(newG); + } + return false; } - if (nbExternalInputs > 1) { - AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set"); + + for (const auto& nodePtr : oldNodes) { + for (const auto& g : commonGraphViews) { + g -> remove(nodePtr, false); + g -> updateInputsOutputsDelete(nodePtr); + } + nodePtr -> resetConnections(true); } - if (oldG->inputNodes().size() > 1){ - // one or no input has been identified. Checking every input points to the same source - for (const auto& previousInputNode : oldG->inputNodes()) { - for (const auto& input : previousInputNode->inputs()) { - if (oldNodes.find(input.first) == oldNodes.end()) { - if ( (externalInput != input.first) || (externalInputId != input.second) ) { - return false; // an inputNode points to an external Node different from the registered one - } + if ((oldOI.size() == newOI.size()) && + (oldOO.size() == newOO.size())) { + // Case 1 + for (std::size_t i = 0; i < oldOI.size(); ++i) { + if (inputParents[i].first) { + inputParents[i].first -> addChild(newOI[i].first, inputParents[i].second, newOI[i].second); + } + } + for (std::size_t o = 0; o < oldOO.size(); ++o) { + if (outputChildren[o].first) { + newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); + } + } + } + else { + // get the number of Parents for oldG->inputNodes() + // get the number of Children for oldg->outputNodes() + if (newNodes.size() == 0) { + // Case 3 + if (oldOI.size() == oldOO.size()) { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + if (inputParents[i].first) + inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); + } + } + else if (oldOI.size() == 1) { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second); + } + } + } + else if ( // for tiling-like cases. The number of inputNodes changes but not outputNodes + ((oldOI.size() == 1) || (newOI.size() == 1)) && // (oldOI.size() == newOI.size()) already handled in Case 1 + ((oldOO.size() == newOO.size())) + ) { + // Case 2 + if ((oldOI.size() == 1)) { + for (std::size_t i = 0; i < newOI.size(); ++i) { + inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second); + } + } else { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second); + } + } + for (std::size_t o = 0; o < oldOO.size(); ++o) { + if (outputChildren[o].first) { + newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); } } } + else { + for (const auto& nodePtr : oldNodes) { + nodePtr->removeView(oldG); + } + for (const auto& nodePtr : newNodes) { + nodePtr->removeView(newG); + } + return false; + } } - - if (firstPreviousOutputNode->nbOutputs() != 1) { - return false; + for (const auto& nodePtr : newNodes) { + for (const auto& g : commonGraphViews) { + g -> add(nodePtr); + } } - - // find Node to replicate output connections - std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin()); - - auto copyOutputs = firstPreviousOutputNode->outputs(); - // manage Views for newNodes - // only keep common views to each node for the new set - std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); for (const auto& nodePtr : oldNodes) { - const auto nodeView = nodePtr->views(); - std::set<std::shared_ptr<GraphView>> intersection; - std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), - nodeView.begin(), nodeView.end(), - std::inserter(intersection, intersection.begin())); - commonGraphViews = intersection; + nodePtr -> removeView(oldG); } - commonGraphViews.erase(oldG); - commonGraphViews.erase(newG); + for (const auto& nodePtr : newNodes) { + nodePtr -> removeView(newG); + } + return true; +} - // clean Nodes to replace - // Do not include common Nodes to avoid cleaning Producers linked to newNodes - std::set<std::shared_ptr<Node>> nodesToClean; - std::set_difference(oldNodes.begin(), oldNodes.end(), - newNodes.begin(), newNodes.end(), - std::inserter(nodesToClean, nodesToClean.begin())); - for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } - - // copy output connections - if (newOutputNode) { - for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) { - auto outputPairs = copyOutputs[o]; - for (const auto& onePair : outputPairs) { - newOutputNode->addChild(onePair.first, o, onePair.second); +void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { + // Can be called several times with the same node, e.g. when addChild() is + // called on a node already part of the GraphView. In this case, inputs/outputs + // need to be updated! + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend(); + + // Remove inputs that are not input anymore because connected to newNode + for (auto orderedChilds : newNode->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + // Check that newNode child is in current GraphView + if (mNodes.find(ch_ptr) != mNodes.cend()) { + IOIndex_t inputIdx = 0; + for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { + // If newNode is connected to it + if (pa_ptr == newNode) { + const auto val = std::make_pair(ch_ptr, inputIdx); + const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); + + // Check that it was not already the case (if node UPDATE) + if (iter != mInputNodes.cend()) { // newNode is linked to an actual inputNode to an input connection + // The first old (removed) input becomes the insertion point for newNode GraphView inputs + if (std::distance(newInputsInsertionPoint, iter) <= 0) { + newInputsInsertionPoint = mInputNodes.erase(iter); + } + else { + mInputNodes.erase(iter); + } } + } + ++inputIdx; } + } } + } - // copy input connections - if (!newNodes.empty() && externalInput) { - for (const auto& newInputNode : newG->inputNodes()) { - IOIndex_t inputId = 0; - for (const auto& input : newInputNode->inputs()) { - if (newNodes.find(input.first) == newNodes.end()) { - externalInput->addChild(newInputNode, externalInputId, inputId); - } - inputId++; + // Manage newNode parents + // Check if any input connection is an input for the GraphView + IOIndex_t inputIdx = 0U; + for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { + const auto val = std::make_pair(newNode, inputIdx); + const auto it = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); + if ((pa_ptr == nullptr) || + (mNodes.find(pa_ptr) == mNodes.cend())) { + // Parent doesn't exist || Parent not in the graph + if (it == mInputNodes.cend()) { + // If node's inputs are inputs for the GraphView: add them to the input list + // Addition rule: + // - Inputs addition order follows node inputs order + // - Inputs are inserted at the position of the first input removed + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + newInputsInsertionPoint = std::next(newInputsInsertionPoint); } + } else if (it != mInputNodes.cend()) { + // Parent already in the graph SO edge is not an input anymore for the graph + mInputNodes.erase(it); } + ++inputIdx; } - // insert new Nodes in the right GraphViews - for (const auto& graphPtr : commonGraphViews) { - graphPtr->add(newNodes, false); - if (newNodes.empty()) { - graphPtr->updateInputNodes(); - graphPtr->updateOutputNodes(); + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend(); + + // Remove outputs that are not output anymore because connected to newNode + for (const std::shared_ptr<Node>& parent : newNode->getParents()) { + // Check that newNode parent is in current GraphView + if (mNodes.find(parent) != mNodes.cend()) { + IOIndex_t outputIdx = 0; + for (auto orderedChilds : parent->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + // If newNode is connected to it + if (ch_ptr == newNode) { + const auto val = std::make_pair(parent, outputIdx); + const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val); + + if (iter != mOutputNodes.cend()) { + // The first old (removed) output becomes the insertion point for newNode GraphView outputs + if (std::distance(newOutputsInsertionPoint, iter) <= 0) { + newOutputsInsertionPoint = mOutputNodes.erase(iter); + } + else { + mOutputNodes.erase(iter); + } + } + } } + ++outputIdx; + } } + } - for (const auto& node : oldNodes) { - node->removeView(oldG); + // Check if node outputs are outputs for the GraphView and add them to the output list if so + IOIndex_t outputIdx = 0; + for (auto orderedChilds : newNode->getOrderedChildren()) { + bool noInsideConnection = true; + for (auto ch_ptr : orderedChilds) { + if (mNodes.find(ch_ptr) != mNodes.end()) { + noInsideConnection = false; + break; + } } - for (const auto& node : newNodes) { - node->removeView(newG); + + if (noInsideConnection) { + const auto val = std::make_pair(newNode, outputIdx); + // Output may be already be present (see addChild() with a node already in GraphView) + if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { + newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); + } } - return true; + ++outputIdx; + } } +void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNode) { + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend(); -void Aidge::GraphView::updateInputNodes() { - mInputNodes.clear(); - for (const std::shared_ptr<Node>& go_ptr : mNodes) { - for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) { - if ((pa_ptr == nullptr) || - (mNodes.find(pa_ptr) == - mNodes.end())) { // Parent doesn't exist || Parent not in the graph - mInputNodes.insert(go_ptr); - break; + // Check if node inputs were inputs for the GraphView and remove them from the list if so + for (IOIndex_t inputIdx = 0; inputIdx < deletedNode->getParents().size(); ++inputIdx) { + const auto val = std::make_pair(deletedNode, inputIdx); + const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); + + if (iter != mInputNodes.cend()) { + // The first old (removed) input becomes the insertion point for new GraphView inputs + if (std::distance(newInputsInsertionPoint, iter) <= 0) { + newInputsInsertionPoint = mInputNodes.erase(iter); + } + else { + mInputNodes.erase(iter); } } } -} -void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) { - // add node_ptr to inputNode if it can - std::size_t filledWithKnownInputs = 0U; - bool wasAdded = mInputNodes.find(node) != mInputNodes.end(); - for (const std::shared_ptr<Node>& pa_ptr : node->getParents()) { - if ((pa_ptr == nullptr) || - (mNodes.find(pa_ptr) == - mNodes.end())) { // Parent doesn't exist || Parent not in the graph - mInputNodes.insert(node); - wasAdded = true; - break; - } - ++filledWithKnownInputs; - } - if (filledWithKnownInputs == node->nbInputs() && wasAdded) { - mInputNodes.erase(node); - } - // update other inputNodes - for (const std::shared_ptr<Node>& ch_ptr : - node->getChildren()) { // check if any child is in InputNodes too - if (mInputNodes.find(ch_ptr) != - mInputNodes.end()) { // it's a match! Must check if the inputNode found - // is still an inputNode - // change here - bool remove = true; - for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { - if (pa_ptr == nullptr || - mNodes.find(pa_ptr) == - mNodes - .end()) { // Parent doesn't exist || Parent not in the graph - remove = false; - break; + // Add child node inputs that become GraphView input following the removal of the node + // Inputs addition order follows deletedNode outputs order + for (auto orderedChilds : deletedNode->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + // Check that deletedNode child is in current GraphView + if (mNodes.find(ch_ptr) != mNodes.cend()) { + IOIndex_t inputIdx = 0; + for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { + // If newNode was connected to it + if (pa_ptr == deletedNode) { + const auto val = std::make_pair(ch_ptr, inputIdx); + if (std::find(mInputNodes.cbegin(), mInputNodes.cend(), val) == mInputNodes.cend()) { + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + newInputsInsertionPoint = std::next(newInputsInsertionPoint); + } + } + ++inputIdx; } } - if (remove) { - mInputNodes.erase(ch_ptr); - } } } -} + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend(); -void Aidge::GraphView::removeInputNode(const std::string nodeName) { - std::map<std::string, std::shared_ptr<Node>>::iterator it = - mNodeRegistry.find(nodeName); - if (it != mNodeRegistry.end()) { - const std::shared_ptr<Node> val = (*it).second; - if (mInputNodes.find(val) != mInputNodes.end()) { - mInputNodes.erase(val); + // Check if node outputs were outputs for the GraphView and remove them from the list if so + for (IOIndex_t outputIdx = 0; outputIdx < deletedNode->getOrderedChildren().size(); ++outputIdx) { + const auto val = std::make_pair(deletedNode, outputIdx); + const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val); + + if (iter != mOutputNodes.cend()) { + // The first old (removed) output becomes the insertion point for newNode GraphView outputs + if (std::distance(newOutputsInsertionPoint, iter) <= 0) { + newOutputsInsertionPoint = mOutputNodes.erase(iter); + } + else { + mOutputNodes.erase(iter); + } } } -} -void Aidge::GraphView::removeOutputNode(const std::string nodeName) { - std::map<std::string, std::shared_ptr<Node>>::iterator it = - mNodeRegistry.find(nodeName); - if (it != mNodeRegistry.end()) { - const std::shared_ptr<Node> val = (*it).second; - if (mOutputNodes.find(val) != mOutputNodes.end()) { - mOutputNodes.erase(val); + // Add parent node outputs that become GraphView output following the removal of the node + // Outputs addition order follows deletedNode inputs order + for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) { + if (mNodes.find(parent) != mNodes.end()) { + IOIndex_t outputIdx = 0; + for (auto orderedChilds : parent->getOrderedChildren()) { + bool noInsideConnection = true; + for (auto ch_ptr : orderedChilds) { + if (mNodes.find(ch_ptr) != mNodes.end()) { + noInsideConnection = false; + break; + } + } + + if (noInsideConnection) { + const auto val = std::make_pair(parent, outputIdx); + if (std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val) == mOutputNodes.cend()) { + newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); + } + } + ++outputIdx; + } } } } + std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const { std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName); @@ -759,46 +1049,132 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone std::map<NodePtr, NodePtr> oldToNewNodes; for (const std::shared_ptr<Node> &node_ptr : mNodes) { - oldToNewNodes[node_ptr] = cloneNode(node_ptr); + auto clonedNode = cloneNode(node_ptr); + if (clonedNode == nullptr) { + AIDGE_ASSERT(node_ptr->getChildren().size() <= 1, "deleted nodes in GraphView::clone() cannot have multiple children"); + AIDGE_ASSERT(node_ptr->nbData() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents"); + } + oldToNewNodes[node_ptr] = clonedNode; } // For each node, convert old node -> new node connections for (auto &oldToNewNode : oldToNewNodes) { - if (oldToNewNode.second == nullptr) + if (oldToNewNode.second == nullptr) { continue; // deleted node - - // Add new node to new GraphView - newGraph->add(oldToNewNode.second, false); + } // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr size_t parentId = 0; for (auto parent : oldToNewNode.first->inputs()) { - while (oldToNewNodes[parent.first] == nullptr) { - // Find next valid parent in line, going backward in the graph - assert(parent.first->nbData() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); - const auto& parents = parent.first->inputs(); + if (parent.first != nullptr) { + while (oldToNewNodes[parent.first] == nullptr) { + // Find next valid parent in line, going backward in the graph + AIDGE_INTERNAL_ASSERT(parent.first->getChildren().size() == 1); + AIDGE_INTERNAL_ASSERT(parent.first->nbData() <= 1); + const auto& parents = parent.first->dataInputs(); + + if (!parents.empty() && parents[0].first != nullptr // a valid parent exists + && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView + { + parent = parents[0]; + } + else { + break; + } + } - if (!parents.empty() && parents[0].first != nullptr // a valid parent exists - && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView - { - parent = parents[0]; + if (oldToNewNodes[parent.first]) { + AIDGE_INTERNAL_ASSERT(oldToNewNodes[parent.first]->nbOutputs() == parent.first->nbOutputs()); + oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); } - else { - break; + } + + ++parentId; + } + } + + // Once connected, add each new nodes to new GraphView + // This has to be done in a second step to ensure that new GraphView inputs/outputs + // are properly set (otherwise, some node's inputs/outputs may be wrongly registered as + // GraphView inputs/outputs because not yet connected to other nodes) + if (oldToNewNodes[mRootNode] != nullptr) { + // Add root node first if is still exists! + newGraph->add(oldToNewNodes[mRootNode], false); + } + + for (auto &oldToNewNode : oldToNewNodes) { + if (oldToNewNode.second == nullptr) + continue; // deleted node + + newGraph->add(oldToNewNode.second, false); + } + + // Update cloned graph inputs/outputs order to match initial graph order + auto newInputNodes = mInputNodes; + for (auto it = newInputNodes.begin(); it != newInputNodes.end(); ) { + // If input node was removed, find next valid input + while (oldToNewNodes[it->first] == nullptr) { + // Removed node should have only one connected output, otherwise cloning is invalid + AIDGE_INTERNAL_ASSERT(it->first->getChildren().size() <= 1); + bool found = false; + + if (it->first->getChildren().size() == 1) { + auto child = *it->first->getChildren().begin(); + + std::size_t inputIdx = 0; + for (auto parent : child->getParents()) { + if (parent == it->first) { + it->first = child; + it->second = inputIdx; + found = true; + break; + } + ++inputIdx; } } - if (oldToNewNodes[parent.first]) { - oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); + if (!found) { + break; } + } - ++parentId; + if (oldToNewNodes[it->first] == nullptr) { + it = newInputNodes.erase(it); + } + else { + it->first = oldToNewNodes[it->first]; + ++it; } } + newGraph->setOrderedInputs(newInputNodes); + + auto newOutputNodes = mOutputNodes; + for (auto it = newOutputNodes.begin(); it != newOutputNodes.end(); ) { + // If output node was removed, find previous valid output + while (oldToNewNodes[it->first] == nullptr) { + // Removed node should have only one connected data input, otherwise cloning is invalid + AIDGE_INTERNAL_ASSERT(it->first->nbData() <= 1); + auto parents = it->first->dataInputs(); + + if (!parents.empty() && parents[0].first != nullptr // a valid parent exists + && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView + { + *it = parents[0]; + } + else { + break; + } + } - // Update OutputNodes/inputNodes - newGraph->updateInputNodes(); - newGraph->updateOutputNodes(); + if (oldToNewNodes[it->first] == nullptr) { + it = newOutputNodes.erase(it); + } + else { + it->first = oldToNewNodes[it->first]; + ++it; + } + } + newGraph->setOrderedOutputs(newOutputNodes); return newGraph; } diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 5a7b05e469daab10a4abd468177a3ad137096f63..6f0cc55159b1cc72b87bb34230376eb140b7ab8a 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -11,22 +11,25 @@ #include "aidge/graph/Node.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/operator/Producer.hpp" #include <memory> #include <vector> + +#include "aidge/graph/GraphView.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" #include "aidge/utils/Types.h" Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) : mName(name), mOperator(op), - mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), - mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), - std::vector<std::weak_ptr<Node>>())), - mIdInChildren( - std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), - mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { + mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), + nullptr)), + mChildren(std::vector<std::vector<std::weak_ptr<Node>>>( + static_cast<std::size_t>(op->nbOutputs()), std::vector<std::weak_ptr<Node>>())), + mIdInChildren(std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), + std::vector<IOIndex_t>())), + mIdOutParents( + std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { // ctor } @@ -34,14 +37,15 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// -Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { +Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) { assert((ctors.size() == nbData()) && "Wrong number of arguments.\n"); - for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) { - assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); - (void) input; // avoid unused warning + for (std::pair<std::shared_ptr<Node>, IOIndex_t>& input : inputs()) { + assert((gk_IODefaultIndex == input.second) && + "At least one input connection is not free.\n"); + (void)input; // avoid unused warning } IOIndex_t i = 0; - for (const Connector &ctor : ctors) { + for (const Connector& ctor : ctors) { if (ctor.node() != nullptr) { // ctor must be associated with a node ctor.node()->addChild(shared_from_this(), ctor.index(), i++); } @@ -53,7 +57,7 @@ Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { // INNER /////////////////////////////////////////////////////// -void Aidge::Node::setName(const std::string &name) { mName = name; } +void Aidge::Node::setName(const std::string& name) { mName = name; } /////////////////////////////////////////////////////// // OPERATORS @@ -92,8 +96,8 @@ Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const { return nbFreeDataIn; } -std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::dataInputs() const { +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs() + const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData()); for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) { @@ -104,15 +108,15 @@ Aidge::Node::dataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); for (std::size_t i = 0; i < nbInputs(); ++i) { - res[i] = - std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); + res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); } return res; } -// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) { +// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> +// tensor) { // assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); // if (mParents[idx] != nullptr) { // mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]); @@ -128,20 +132,21 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::Node::outputs() const { std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs = - std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(mIdInChildren.size()); + std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>( + mIdInChildren.size()); for (std::size_t i = 0; i < mIdInChildren.size(); ++i) { listOutputs[i] = output(static_cast<IOIndex_t>(i)); } return listOutputs; } -std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::output(Aidge::IOIndex_t outId) const { +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::output( + Aidge::IOIndex_t outId) const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size()); for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) { - listOutputs[i] = - std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]); + listOutputs[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), + mIdInChildren[outId][i]); } return listOutputs; } @@ -180,7 +185,8 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) // TOPOLOGY /////////////////////////////////////////////////////// -void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) { +void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, + const IOIndex_t otherInId) { assert((otherInId < otherNode->nbInputs()) && "Input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); if (otherNode->input(otherInId).second != gk_IODefaultIndex) { @@ -196,33 +202,41 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou } void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId, - std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { - assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound."); + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + assert((otherInId.second < otherInId.first->nbInputs()) && + "Other graph input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes(); if (inNodes.size() == std::size_t(0)) { // no input Node printf("Cannot add GraphView to the Node. No input node detected.\n"); } else // inNodes.size() >= 1 { - assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node + assert((inNodes.find(otherInId.first) != + inNodes.end())); // assert it really is an input node addChildOp(otherInId.first, outId, otherInId.second); } } -void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) { - otherInId = (otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput(); - addChildOp(otherNode, outId, otherInId); +void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, + IOIndex_t otherInId) { + if (otherNode) { + otherInId = + (otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput(); + addChildOp(otherNode, outId, otherInId); + } } void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId, - std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { if (!otherInId.first) { assert((otherView->inputNodes().size() == 1U) && "Specify an input Node for the GraphView. More or less than one " "Node is not explicit."); otherInId.first = *(otherView->inputNodes().begin()); } - otherInId.second = (otherInId.second != gk_IODefaultIndex) ? otherInId.second : otherInId.first->getFirstFreeDataInput(); + otherInId.second = (otherInId.second != gk_IODefaultIndex) + ? otherInId.second + : otherInId.first->getFirstFreeDataInput(); addChildView(otherView, outId, otherInId); } @@ -255,8 +269,8 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) { std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { std::set<std::shared_ptr<Node>> children; - for (const auto &childrenOfOneOutput : mChildren) { - for (const auto &oneChild : childrenOfOneOutput) { + for (const auto& childrenOfOneOutput : mChildren) { + for (const auto& oneChild : childrenOfOneOutput) { children.insert(oneChild.lock()); } } @@ -264,7 +278,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { } std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { - std::vector<std::vector<std::shared_ptr<Node>>> children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); + std::vector<std::vector<std::shared_ptr<Node>>> children = + std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); for (std::size_t outId = 0; outId < mChildren.size(); ++outId) { children[outId] = getChildren(outId); } @@ -273,14 +288,16 @@ std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedCh std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const { assert((outId < nbOutputs()) && "Output index out of bound."); - std::vector<std::shared_ptr<Node>> children = std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); + std::vector<std::shared_ptr<Node>> children = + std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); for (std::size_t i = 0; i < mChildren[outId].size(); ++i) { - children.push_back(mChildren[outId][i].lock()); - } + children.push_back(mChildren[outId][i].lock()); + } return children; } -bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) { +bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, + const Aidge::IOIndex_t outId) { assert((outId < nbOutputs()) && "Child index out of bound."); bool removed = false; for (std::size_t j = 0; j < mChildren[outId].size(); ++j) { @@ -301,7 +318,8 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i); if (parent.first) { // number of children linked to the parent's output - while(parent.first->removeChild(shared_from_this(), parent.second) == true) {} + while (parent.first->removeChild(shared_from_this(), parent.second) == true) { + } } // every reference to this object as child has been removed // removing reference to parents. @@ -316,24 +334,23 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { mIdInChildren[i] = std::vector<IOIndex_t>(); } // removing this Node from every GraphView it belongs to - for (auto& graph : views()) { - // if keeping connections with LEarnable Parameters, then also remove them from graph - graph->remove(shared_from_this(), !includeLearnableParam); - } + // for (auto& graph : views()) { + // // if keeping connections with LEarnable Parameters, then also remove them from graph + // graph->remove(shared_from_this(), !includeLearnableParam); + // } } - /////////////////////////////////////////////////////// - // CLONE - /////////////////////////////////////////////////////// +/////////////////////////////////////////////////////// +// CLONE +/////////////////////////////////////////////////////// Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { return std::make_shared<Node>(mOperator, mName); } Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { - std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type) - ? mOperator - : mOperator->clone(); + std::shared_ptr<Operator> op = + (mOperator->type() == Producer_Op::Type) ? mOperator : mOperator->clone(); return std::make_shared<Node>(op, mName); } @@ -342,27 +359,25 @@ Aidge::NodePtr Aidge::Node::clone() const { return std::make_shared<Node>(mOperator->clone(), mName); } - -std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){ - +std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::NodePtr> nodeSee) { std::set<Aidge::NodePtr> out; nodeSee.insert(shared_from_this()); - if(delta == 0) { + if (delta == 0) { out.insert(shared_from_this()); - }else if (delta > 0){ - for (const NodePtr& node : getChildren()) { - if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance - for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ + } else if (delta > 0) { + for (const NodePtr& node : getChildren()) { + if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance + for (const NodePtr& ch : node->getNodeDelta(delta - 1, nodeSee)) { out.insert(ch); } } } - }else{ - for (const NodePtr& node : getParents()) { - if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance - for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ + } else { + for (const NodePtr& node : getParents()) { + if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance + for (const NodePtr& pr : node->getNodeDelta(delta + 1, nodeSee)) { out.insert(pr); } } diff --git a/src/graph/Testing.cpp b/src/graph/Testing.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f30ad6e25b81e1ce7768fcc201ddf00c2226eebf --- /dev/null +++ b/src/graph/Testing.cpp @@ -0,0 +1,133 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <algorithm> // std::shuffle, std::transform +#include <cstddef> +#include <memory> +#include <numeric> // std::iota +#include <random> // std::binomial_distribution, std::mt19937, std::discrete_distribution +#include <string> +#include <utility> // std::pair +#include <vector> + +#include "aidge/graph/Testing.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/utils/Types.h" + +std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomGraph::gen(std::mt19937::result_type seed, std::size_t nbNodes) const { + std::mt19937 gen(seed); + std::binomial_distribution<> dIn(maxIn - 1, avgIn/maxIn); + std::binomial_distribution<> dOut(maxOut - 1, avgOut/maxOut); + std::binomial_distribution<> dLink(1, density); + std::discrete_distribution<> dType(typesWeights.begin(), typesWeights.end()); + + std::vector<std::pair<IOIndex_t, IOIndex_t>> nbIOs; + std::vector<std::string> nodesType; + for (std::size_t i = 0; i < nbNodes; ++i) { + const auto nbIn = 1 + dIn(gen); + nbIOs.push_back(std::make_pair(nbIn, 1 + dOut(gen))); + nodesType.push_back(types[dType(gen)]); + } + + std::vector<std::size_t> nodesSeq(nbNodes); + std::iota(nodesSeq.begin(), nodesSeq.end(), static_cast<std::size_t>(0)); + // Don't use gen or seed here, must be different each time! + std::shuffle(nodesSeq.begin(), nodesSeq.end(), std::default_random_engine(std::random_device{}())); + + std::vector<NodePtr> nodes(nbNodes, nullptr); + for (auto idx : nodesSeq) { + const std::string name = nodesType[idx] + std::to_string(idx); + nodes[idx] = GenericOperator(nodesType[idx], nbIOs[idx].first, 0, nbIOs[idx].second, name); + } + + for (std::size_t i = 0; i < nbNodes; ++i) { + for (std::size_t j = (acyclic) ? i + 1 : 0; j < nbNodes; ++j) { + if (i == j) { + // Do not connected node to itself in case of cyclic graph! + continue; + } + + for (std::size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) { + for (std::size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) { + if (dLink(gen)) { + // Warning: connections can be set multiple time for the + // same node input! In this case, the previous connection + // is overwritten. This is the expected behavior. + nodes[i]->addChild(nodes[j], outId, inId); + if (nodes[i]->type() == omitType || nodes[j]->type() == omitType) { + // Let nodes[i]->addChild() overwrite the previous connection. + // Now we remove the new one! + nodes[i]->removeChild(nodes[j], outId); + nodes[j]->removeParent(inId); + } +/* + // Alternative: only add child if no node is omitted + // and remove the potential previous connection, like this: + if (nodes[i]->type() != omitType && nodes[j]->type() != omitType) { + nodes[i]->addChild(nodes[j], outId, inId); + } + else { + const auto prevIn = nodes[j]->input(inId); + + if (prevIn.first != nullptr) { + prevIn.first->removeChild(nodes[j], prevIn.second); + nodes[j]->removeParent(inId); + } + } +*/ + break; + } + } + } + } + } + + NodePtr rootNode = nullptr; + std::set<NodePtr> nodesSet; + for (std::size_t i = 0; i < nbNodes; ++i) { + if (nodes[i]->type() != omitType) { + if (rootNode == nullptr) { + rootNode = nodes[i]; + } + nodesSet.insert(nodes[i]); + } + } + + return std::make_pair(rootNode, nodesSet); +} + +std::string Aidge::nodePtrToType(NodePtr node) { + return node->type(); +} + +std::string Aidge::nodePtrToName(NodePtr node) { + return node->name(); +} + +std::set<std::string> Aidge::nodePtrTo(const std::set<NodePtr>& nodes, + std::string(*nodeTo)(NodePtr)) +{ + std::set<std::string> nodesStr; + std::transform(nodes.begin(), nodes.end(), std::inserter(nodesStr, nodesStr.begin()), nodeTo); + return nodesStr; +} + +std::vector<std::pair<std::string, Aidge::IOIndex_t>> Aidge::nodePtrTo( + const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes, + std::string(*nodeTo)(NodePtr)) +{ + std::vector<std::pair<std::string, IOIndex_t>> nodesStr; + std::transform(nodes.begin(), nodes.end(), std::back_inserter(nodesStr), + [nodeTo](const std::pair<NodePtr, IOIndex_t>& node) { + return std::make_pair(nodeTo(node.first), node.second); + }); + return nodesStr; +} diff --git a/src/operator/Add.cpp b/src/operator/Add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e638fd86da487565a89760925e45339213fa8f9 --- /dev/null +++ b/src/operator/Add.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Add.hpp" + +const std::string Aidge::Add_Op::Type = "Add"; \ No newline at end of file diff --git a/src/operator/Concat.cpp b/src/operator/Concat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eafcd126480df6da2c0127bdbb896d3ce98d0e0a --- /dev/null +++ b/src/operator/Concat.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Concat.hpp" + +const std::string Aidge::Concat_Op::Type = "Concat"; \ No newline at end of file diff --git a/src/operator/Div.cpp b/src/operator/Div.cpp index 273eac2e8fa9623e617d1be204ac2ae46d8da02d..85db3ac6ef66c837c86dbece288185deaca88ba6 100644 --- a/src/operator/Div.cpp +++ b/src/operator/Div.cpp @@ -11,6 +11,7 @@ #include <cassert> #include <cstddef> +#include <string> #include <vector> #include <utility> @@ -19,6 +20,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Div_Op::Type = "Div"; + void Aidge::Div_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32114f5bf9e0d160db9fdc2d1971481be0b4e703 --- /dev/null +++ b/src/operator/FC.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/FC.hpp" + +const std::string Aidge::FC_Op::Type = "FC"; \ No newline at end of file diff --git a/src/operator/Identity.cpp b/src/operator/Identity.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f57906dd4f3564b52cde16236bda87370e8f86d7 --- /dev/null +++ b/src/operator/Identity.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Identity.hpp" + +const std::string Aidge::Identity_Op::Type = "Identity"; \ No newline at end of file diff --git a/src/operator/LeakyReLU.cpp b/src/operator/LeakyReLU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32e050ee1595cf83b5cd0ffbfeba6153dc2243af --- /dev/null +++ b/src/operator/LeakyReLU.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/LeakyReLU.hpp" + +const std::string Aidge::LeakyReLU_Op::Type = "LeakyReLU"; \ No newline at end of file diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..666ed3921ed1190a91935bd9f38303e23963d912 --- /dev/null +++ b/src/operator/MatMul.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/MatMul.hpp" + +const std::string Aidge::MatMul_Op::Type = "MatMul"; \ No newline at end of file diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index bbc921d3c7b334223b2a92a8fbfee1ffae9c10e1..530357085a16ca3e834669cebd2d26882ca8ddab 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -12,63 +12,19 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" -Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, - std::vector<NodePtr> inputNodes, - std::vector<NodePtr> outputNodes) +Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph) : OperatorTensor(type, graph->dataInputs().size(), (graph->inputs().size() - graph->dataInputs().size()), graph->outputs().size()), mGraph(graph) { - // Fill inputsNodes and outputsNodes when there is no ambiguity - if (inputNodes.empty()) { - AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "need to specify internal nodes input mapping"); - inputNodes.push_back(*mGraph->inputNodes().begin()); + mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size()); + for (std::size_t i = 0; i < mInputs.size(); ++i) { + mInputs[i] = std::make_shared<Tensor>(); } - - if (outputNodes.empty()) { - AIDGE_ASSERT(mGraph->outputNodes().size() == 1, "need to specify internal nodes output mapping"); - outputNodes.push_back(*mGraph->outputNodes().begin()); - } - - AIDGE_ASSERT(mGraph->inputNodes().size() == inputNodes.size(), "wrong number of specified input nodes"); - AIDGE_ASSERT(mGraph->outputNodes().size() == outputNodes.size(), "wrong number of specified output nodes"); - - // Identify inputs that are outside the micro-graph - for (const auto& inputNode : inputNodes) { - AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph"); - const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = - inputNode->inputs(); - - int inputIdx = 0; // input idx relative to the current node - for (const auto& in : inputNodeinputs) { - if (in.first == nullptr || !mGraph->inView(in.first)) { - // The input is not connected inside the micro-graph - // (no connection to this input or connection outside the micro-graph) - // => it is therefore an input for the meta-operator - mInputOps.push_back(std::make_pair(std::dynamic_pointer_cast<OperatorTensor>(inputNode->getOperator()), inputIdx)); - } - - ++inputIdx; - } - } - - // The outputs of the output nodes are also the outputs of the meta-operator - for (const auto& outputNode : outputNodes) { - AIDGE_ASSERT(mGraph->inView(outputNode), "output node must be in the graph"); - const std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> outputNodeoutputs = - outputNode->outputs(); - - for (size_t outputIdx = 0; outputIdx < outputNodeoutputs.size(); ++outputIdx) { - mOutputOps.push_back(std::make_pair(std::dynamic_pointer_cast<OperatorTensor>(outputNode->getOperator()), outputIdx)); - } - } - - - AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size()); - AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size()); // Associate outputs to micro-graph outputs for custom implementation - for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { - const auto& outputOp = mOutputOps[outputIdx]; - mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); + mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->getOrderedOutputs().size()); + for (size_t outputIdx = 0; outputIdx < mOutputs.size(); ++outputIdx) { + const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; + mOutputs[outputIdx] = std::dynamic_pointer_cast<Tensor>(outputOp.first->getOperator()->getRawOutput(outputOp.second)); } } @@ -77,8 +33,8 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI return mImpl->getNbRequiredData(inputIdx); } else { - const auto& inputOp = mInputOps[inputIdx]; - return inputOp.first->getNbRequiredData(inputOp.second); + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + return inputOp.first->getOperator()->getNbRequiredData(inputOp.second); } } @@ -87,8 +43,8 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co return mImpl->getNbConsumedData(inputIdx); } else { - const auto& inputOp = mInputOps[inputIdx]; - return inputOp.first->getNbConsumedData(inputOp.second); + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + return inputOp.first->getOperator()->getNbConsumedData(inputOp.second); } } @@ -97,8 +53,8 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c return mImpl->getNbProducedData(outputIdx); } else { - const auto& outputOp = mOutputOps[outputIdx]; - return outputOp.first->getNbProducedData(outputOp.second); + const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; + return outputOp.first->getOperator()->getNbProducedData(outputOp.second); } } diff --git a/src/operator/Mul.cpp b/src/operator/Mul.cpp index 2e3e77288bf1e0613f0aa572e3c50e94599a902f..bc268263e8a6e2ec7c9944faa31da84dc50c4f53 100644 --- a/src/operator/Mul.cpp +++ b/src/operator/Mul.cpp @@ -19,6 +19,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Mul_Op::Type = "Mul"; + void Aidge::Mul_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/src/operator/Pow.cpp b/src/operator/Pow.cpp index c213a47a4a590026c07625aeb532d303ca8dbced..de1f0c3694f51fbd5b365573f61d3e3e2b9109ff 100644 --- a/src/operator/Pow.cpp +++ b/src/operator/Pow.cpp @@ -19,6 +19,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Pow_Op::Type = "Pow"; + void Aidge::Pow_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..443f2fa7d8a60cd25ccb622f2dad5b4926b88eea --- /dev/null +++ b/src/operator/Producer.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Producer.hpp" + +const std::string Aidge::Producer_Op::Type = "Producer"; \ No newline at end of file diff --git a/src/operator/ReLU.cpp b/src/operator/ReLU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f7874acfe7d865ea8c56d4bca02b51864480df6 --- /dev/null +++ b/src/operator/ReLU.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/ReLU.hpp" + +const std::string Aidge::ReLU_Op::Type = "ReLU"; \ No newline at end of file diff --git a/src/operator/Scaling.cpp b/src/operator/Scaling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c121e1268c1e1a62f793f38c6d816e7c6b48c25 --- /dev/null +++ b/src/operator/Scaling.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Scaling.hpp" + +const std::string Aidge::Scaling_Op::Type = "Scaling"; \ No newline at end of file diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a25e290dff35e4257d486613a5fe06894119d367 --- /dev/null +++ b/src/operator/Slice.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Slice.hpp" + +const std::string Aidge::Slice_Op::Type = "Slice"; \ No newline at end of file diff --git a/src/operator/Softmax.cpp b/src/operator/Softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e88ff4bb4ec6e2cb1357d578c2d07cc4edcb59f7 --- /dev/null +++ b/src/operator/Softmax.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Softmax.hpp" + +const std::string Aidge::Softmax_Op::Type = "Softmax"; \ No newline at end of file diff --git a/src/operator/Sqrt.cpp b/src/operator/Sqrt.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbcaba42619762f8fd00bb2f6e0aa0de11d92960 --- /dev/null +++ b/src/operator/Sqrt.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Sqrt.hpp" + +const std::string Aidge::Sqrt_Op::Type = "Sqrt"; \ No newline at end of file diff --git a/src/operator/Sub.cpp b/src/operator/Sub.cpp index 8175f1b7ae5bb5eccd36267c1d739f764bd3c236..639eaf798c1c2a9a6685e8b8d2c4a2cb00a4b57a 100644 --- a/src/operator/Sub.cpp +++ b/src/operator/Sub.cpp @@ -19,6 +19,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Sub_Op::Type = "Sub"; + void Aidge::Sub_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/unit_tests/graph/Test_Connector.cpp b/unit_tests/graph/Test_Connector.cpp index a7cee610e0014dc024271a008ed964fa67d367ea..79acce9281039f9f3c67b7235d8999b6c7173685 100644 --- a/unit_tests/graph/Test_Connector.cpp +++ b/unit_tests/graph/Test_Connector.cpp @@ -16,6 +16,7 @@ #include "aidge/operator/GenericOperator.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/OpArgs.hpp" +#include "aidge/graph/Testing.hpp" using namespace Aidge; @@ -112,7 +113,9 @@ TEST_CASE("GraphGeneration from Connector", "[GraphView]") { x= (*node09)({x}); x = (*node10)({a, x}); std::shared_ptr<GraphView> gv = generateGraph({x}); - gv->save("GraphGeneration"); + // gv->save("GraphGeneration"); + REQUIRE(nodePtrTo(gv->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(gv->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_matmul1", 0}})); } TEST_CASE("Connector connection GraphView", "[Connector]") { @@ -131,6 +134,9 @@ TEST_CASE("Connector connection GraphView", "[Connector]") { GenericOperator("g_conv3", 1, 0, 1), GenericOperator("g_matmul1", 2, 0, 1) }); + REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv1", 0}})); + REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_matmul1", 0}})); + x = (*prod)({}); x = (*g)({x}); std::shared_ptr<GraphView> g2 = generateGraph({x}); @@ -151,10 +157,14 @@ TEST_CASE("Connector connection GraphView", "[Connector]") { GenericOperator("g_concat", 3, 0, 1), GenericOperator("g_conv3", 1, 0, 1) }); + REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"ElemWise", 0}, {"ElemWise", 1}, {"ElemWise", 2}})); + REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv3", 0}})); x = (*g)({x, y, z}); std::shared_ptr<GraphView> gv = generateGraph({x}); - gv->save("MultiInputSequentialConnector"); + REQUIRE(nodePtrTo(gv->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(gv->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv3", 0}})); + // gv->save("MultiInputSequentialConnector"); REQUIRE(gv->inputNodes().size() == 0U); } } @@ -169,7 +179,9 @@ TEST_CASE("Connector Mini-graph", "[Connector]") { } y = (*GenericOperator("ElemWise", 2, 0, 1))({y, x}); std::shared_ptr<GraphView> g = generateGraph({y}); - g->save("TestGraph"); + REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"ElemWise", 0}})); + // g->save("TestGraph"); } TEST_CASE("Structural descrition - Sequential", "[GraphView]") { diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index bb726bd4d92b5674d0e19ea3138e165e1329959a..acbea04a27a0b6be22105bb73fda53fedf621235 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -9,6 +9,7 @@ * ********************************************************************************/ +#include <algorithm> // std::sort #include <cassert> #include <map> #include <memory> @@ -20,20 +21,199 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Testing.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" using namespace Aidge; -TEST_CASE("[core/graph] GraphView(Constructor)") { +TEST_CASE("genRandomGraph", "[GraphView][randomGen]") { + const size_t nbTests = 100; + size_t nbUnicity = 0; + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + RandomGraph randGraph; + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); + const auto g2 = std::make_shared<GraphView>("g2"); + const bool unicity2 = g2->add(randGraph.gen(seed, 10)); + + // g1->save("./genRandomGraph1"); + // g2->save("./genRandomGraph2"); + + REQUIRE(unicity1 == unicity2); + + if (unicity1) { + REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName)); + ++nbUnicity; + + // Check that inputs/outputs are the same regardless of the order + auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName); + auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName); + auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName); + auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName); + std::sort(orderedInputs1.begin(), orderedInputs1.end()); + std::sort(orderedInputs2.begin(), orderedInputs2.end()); + std::sort(orderedOutputs1.begin(), orderedOutputs1.end()); + std::sort(orderedOutputs2.begin(), orderedOutputs2.end()); + + REQUIRE(orderedInputs1 == orderedInputs2); + REQUIRE(orderedOutputs1 == orderedOutputs2); + REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName)); + } + } + + printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests); +} + +TEST_CASE("clone", "[GraphView][clone]") { + const size_t nbTests = 100; + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + RandomGraph randGraph; + const auto g1 = std::make_shared<GraphView>("g1"); + g1->add(randGraph.gen(seed, 10)); + // g1 -> save("GraphView_clone"); + const auto g2 = g1->clone(); + + REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + } +} + +NodePtr nodeDel(NodePtr node) { + if (node->type() == "DelFictive") { + return nullptr; + } + return node->clone(); +} + +TEST_CASE("clone_with_delete", "[GraphView][cloneDelete]") { + const size_t nbTests = 100; + size_t nbClonedWithDelete = 0; + + // Note: initial seed is chosen such that for nbTests=100, the generated + // graphs keep the same inputs/outputs despites the deleted nodes + // (meaning the deleted nodes are not input/output of the graph). + // Otherwise, the last two REQUIRE are not garanteed to be true! + // Warning: distributions are not required to behave the same way by the standard, + // therefore the seed has to work for both GCC and MSVC... + // See https://stackoverflow.com/questions/38532927/why-gcc-and-msvc-stdnormal-distribution-are-different + std::mt19937::result_type seed(243); + + for (int test = 0; test < nbTests; ++test) { + RandomGraph randGraph; + randGraph.types = {"Fictive", "DelFictive"}; + randGraph.typesWeights = {0.9, 0.1}; + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); + + if (unicity1) { + randGraph.omitType = "DelFictive"; + const auto g2 = std::make_shared<GraphView>("g2"); + const bool unicity2 = g2->add(randGraph.gen(seed, 10)); + + // g1->save("./clone_with_delete1"); + // g2->save("./clone_with_delete2"); + + try { + const auto gCloned = g1->cloneCallback(&nodeDel); + + REQUIRE(nodePtrTo(gCloned->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(gCloned->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + REQUIRE(nodePtrTo(gCloned->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + ++nbClonedWithDelete; + } + catch (const std::runtime_error& error) { + // pass + } + } + + ++seed; + } + + printf("nbClonedWithDelete = %zu/%zu\n", nbClonedWithDelete, nbTests); +} + +TEST_CASE("remove", "[GraphView][remove]") { + const size_t nbTests = 100; + size_t nbTested = 0; + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + RandomGraph randGraph; + randGraph.types = {"Fictive", "DelFictive"}; + randGraph.typesWeights = {0.8, 0.2}; + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); + + if (unicity1) { + // g1->save("./remove1_before"); + const auto nodes = g1->getNodes(); + int step = 1; + for (auto node : nodes) { + if (node->type() == "DelFictive") { + g1->remove(node, false); + // g1->save("./remove1_after" + std::to_string(step)); + step++; + } + } + + randGraph.omitType = "DelFictive"; + const auto g2 = std::make_shared<GraphView>("g2"); + g2->add(randGraph.gen(seed, 10)); + + // g1->save("./remove1"); + // g2->save("./remove2"); + + REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + // Order not garanteed, because when a node is removed, it can create new GraphView inputs/outputs + // Their order thus depends on the deletion order! + //REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + //REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + + // Check that inputs/outputs are the same regardless of the order + auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName); + auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName); + auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName); + auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName); + std::sort(orderedInputs1.begin(), orderedInputs1.end()); + std::sort(orderedInputs2.begin(), orderedInputs2.end()); + std::sort(orderedOutputs1.begin(), orderedOutputs1.end()); + std::sort(orderedOutputs2.begin(), orderedOutputs2.end()); + + REQUIRE(orderedInputs1 == orderedInputs2); + REQUIRE(orderedOutputs1 == orderedOutputs2); + ++nbTested; + } + } + + printf("nbTested = %zu/%zu\n", nbTested, nbTests); +} + +TEST_CASE("[core/graph] GraphView(Constructor)", "[GraphView][constructor()]") { std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1"); REQUIRE(g0 != nullptr); REQUIRE(g1 != nullptr); } -TEST_CASE("[core/graph] GraphView(add)") { +TEST_CASE("[core/graph] GraphView(add)", "[GraphView][add]") { SECTION("Node alone") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1"); @@ -48,6 +228,9 @@ TEST_CASE("[core/graph] GraphView(add)") { g->add(GOp5); std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 1, 1, "Gop6"); g->add(GOp6); + // g->save("node_alone"); + REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop3", 0}, {"Gop4", 0}, {"Gop5", 0}, {"Gop6", 0}, {"Gop6", 1}})); + REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop2", 0}, {"Gop5", 0}, {"Gop6", 0}})); } SECTION("Several Nodes") { @@ -58,10 +241,14 @@ TEST_CASE("[core/graph] GraphView(add)") { GOp1parent->addChild(GOp1, 0, 0); g->add(GOp1); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); + REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}})); // there should be no deplicates g->add(GOp1); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); + REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}})); } SECTION("Initializer list ofr Node") { @@ -221,7 +408,7 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { } } - SECTION("disconnect data iput + learnable parameters") { + SECTION("disconnect data input + learnable parameters") { std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index ef0c4e7f72d3148eccb97896a3d6e3d5ae5ad6e1..68e2d4d4d5b4fe1b40f83c087eb61c7865d3db75 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -14,6 +14,7 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Testing.hpp" #include <cstddef> using namespace Aidge; @@ -26,8 +27,8 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { REQUIRE(microGraph->getNodes().size() == 2); REQUIRE(microGraph->inputNodes().size() == 2); // 2 because Conv has inputs outside the meta-op (Producers for weight and bias) - // Order not garanteed by the GraphView - //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->type() == "Pad"); + REQUIRE(nodePtrTo(microGraph->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"Pad", 0}, {"Conv", 1}, {"Conv", 2}})); + REQUIRE(nodePtrTo(microGraph->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"Conv", 0}})); REQUIRE(microGraph->outputNodes().size() == 1); REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv"); REQUIRE(op->nbInputs() == 3); @@ -43,8 +44,7 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { REQUIRE(opTensor->outputDimsForwarded()); REQUIRE(std::static_pointer_cast<Tensor>(opTensor->getRawOutput(0))->dims() == std::vector<size_t>({2,3,5,5})); REQUIRE(std::static_pointer_cast<Tensor>(opTensor->getRawInput(0)) == myInput); - // Order not garanteed by the GraphView - //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->getRawInput(0) == myInput); + REQUIRE(microGraph->getOrderedInputs()[0].first->getOperator()->getRawInput(0) == myInput); REQUIRE(opTensor->getRawOutput(0) == (*microGraph->outputNodes().begin())->getOperator()->getRawOutput(0)); //op->getOperator()->updateConsummerProducer(); // require implementation diff --git a/unit_tests/recipies/Test_FuseBatchNorm.cpp b/unit_tests/recipies/Test_FuseBatchNorm.cpp deleted file mode 100644 index 5d9c02d5582e3c56aba9d374d7087946c7d94bde..0000000000000000000000000000000000000000 --- a/unit_tests/recipies/Test_FuseBatchNorm.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ -/* -#include <catch2/catch_test_macros.hpp> -#include <set> - - -//#include "aidge/backend/cpu/operator/BatchNormImpl.hpp" -//#include "aidge/backend/cpu/operator/ConvImpl.hpp" - - - -#include "aidge/operator/Conv.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/graph/OpArgs.hpp" -#include "aidge/operator/BatchNorm.hpp" -#include "aidge/utils/Recipies.hpp" - -//#include "aidge/backend/TensorImpl.hpp" -//#include "aidge/backend/cpu.hpp" -//#include "aidge/" - -#include <cstddef> - - -namespace Aidge { - - - TEST_CASE("[FuseBatchNorm] conv") { - auto g1 = Sequential({ - Producer({16, 3, 224, 224}, "dataProvider"), - Conv(3, 32, {3, 3}, "conv1"), - BatchNorm<2>() - }); - - g1->setDataType(DataType::Float32); - g1->setBackend("cpu"); - g1->forwardDims(); - - // std::set<std::string> availableBackends = Tensor::getAvailableBackends(); - // if (availableBackends.find("cpu") != availableBackends.end()){ - // g1->setBackend("cpu"); - // newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr)); - // }else{ - // printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); - // } - - fuseBatchNorm(g1); - - SECTION("Check resulting nodes") { - // REQUIRE(g1->getNodes().size() == 2); - // REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); - // REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); - // REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); - // REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); - // REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); - } - } - -} -*/ \ No newline at end of file diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index 0c65db98917e33a11f4b7bac678b271b1a10fb94..968826230dfdf85290ee377aee155e06855c4b28 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -61,7 +61,6 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // Transform GraphView inplace fuseMulAdd(g); - g->save("bonjour"); // Check new GraphView std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); diff --git a/unit_tests/recipies/Test_removeFlatten.cpp b/unit_tests/recipies/Test_removeFlatten.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d0ff29dae19ba2dd8009441c39da53bf44378f0 --- /dev/null +++ b/unit_tests/recipies/Test_removeFlatten.cpp @@ -0,0 +1,49 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <set> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/recipies/Recipies.hpp" + +namespace Aidge { + + +TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { + // generate the original GraphView + auto flatten = GenericOperator("Flatten", 1, 0, 1, "myFlatten"); + auto fc = FC(10, 50, "myFC"); + + flatten -> addChild(fc); + + auto g = std::make_shared<GraphView>(); + g->add({fc, flatten}); + + // Check original graph + // g -> save("before_remove_flatten"); + + // use recipie + removeFlatten(g); + + // Check transformed graph + // g -> save("after_remove_flatten"); + + REQUIRE(g->getOrderedInputs().size() == 1); + REQUIRE(g->getOrderedOutputs().size() == 1); + REQUIRE(g->getOrderedInputs()[0].first == fc); + REQUIRE(g->getOrderedOutputs()[0].first == fc); +} + +} // namespace Aidge \ No newline at end of file