diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 9bc5a4cfbc3de3983dfd50f8e54f832da4a47a5a..e22bde7c6a2fee047ffe6fb0b570388d1ad67d7d 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -292,7 +292,8 @@ public: /** * @brief Include a set of Nodes to the current GraphView object. - * The second element in the otherNodes pair is the start node. + * The first element of the otherNodes pair is the start node and + * the second is the remaining nodes to add. * @param otherNodes * @param includeLearnableParam * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). @@ -391,17 +392,24 @@ public: IOIndex_t newParentInputTensorIdx, IOIndex_t newParentOutputTensorIdx); - /** * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible. * Both sets should include all the necessary Producers. - * @details Replaced Nodes are removed from any GraphView pointing at them all. - * The oldNodes set should have only one input/output - * Tensor for automatic connections of newNodes set. - * @param oldNodes actual set of shared_ptr<Node> to replace. - * @param newNodes new set of shared_ptr<Node>. - * @return true - * @return false + * @details There are 3 cases of replacement: + * Case 1: same number of input/output connections for oldNodes and newNodes sets. + * - input/output connections are replacated according to their IDs. + * Case 2: different number of input/output connections for oldNodes and newNodes sets. + * - only a single parent/child node for the newNodes set, every input/output is + * connected to it. 
+ * - several parents/children nodes for newNodes set => impossible to know, return false + * Case 3: newNodes set is empty + * - same number of input/output connections in oldNodes, parents and children are linked according + * to these connections IDs + * - different number of input/output connections in oldNodes => return false + * @param oldNodes + * @param newNodes + * @return true replacement has been performed + * @return false no replacement has been performed */ static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes); @@ -443,14 +451,6 @@ public: */ IOIndex_t getNbFreeDataInputs() const; -protected: - /** - * @brief Update inputs/outputs of the GraphView, with no particular order. - * This function DOES NOT preserve inputs/outputs order and should NOT BE USED. - * It is here only to leave time to adapt the replace() function. - */ - void updateInputsOutputsNodes_DEPRECATED(); - private: /////////////////////////////////////////////////////// // TENSOR MANAGEMENT diff --git a/include/aidge/graph/Testing.hpp b/include/aidge/graph/Testing.hpp index 06bbd0f554bab0b9a5ef123dc287c2397258fae3..ecacdf66298cb83c919ad447c82463206836a3e9 100644 --- a/include/aidge/graph/Testing.hpp +++ b/include/aidge/graph/Testing.hpp @@ -12,11 +12,11 @@ #ifndef AIDGE_CORE_GRAPH_TESTING_H_ #define AIDGE_CORE_GRAPH_TESTING_H_ +#include <cstddef> #include <vector> #include <set> -#include <random> -#include <algorithm> -#include <utility> +#include <random> // std::mt19937::result_type +#include <utility> // std::pair #include "aidge/graph/Node.hpp" #include "aidge/utils/Types.h" @@ -31,11 +31,11 @@ struct RandomGraph { /// @brief Connection density (between 0 and 1) float density = 0.5; /// @brief Max number of inputs per node (regardless if they are connected or not) - size_t maxIn = 5; + std::size_t maxIn = 5; /// @brief Average number of inputs per node (regardless if they are connected or not) float avgIn = 1.5; /// @brief Max number of outputs 
per node (regardless if they are connected or not) - size_t maxOut = 2; + std::size_t maxOut = 2; /// @brief Average number of outputs per node (regardless if they are connected or not) float avgOut = 1.1; /// @brief List of node types that should be generated in the graph (as GenericOperator) @@ -47,11 +47,11 @@ struct RandomGraph { /** * Generate a DAG according to the parameters of the class. - * @param seed Random seed. For an identical seed, an identical topology is + * @param seed Random seed. For an identical seed, an identical topology is * generated, but with a random node ordering in the return set of nodes. * @param nbNodes Number of nodes to generate. */ - std::pair<NodePtr, std::set<NodePtr>> gen(std::mt19937::result_type seed, size_t nbNodes) const; + std::pair<NodePtr, std::set<NodePtr>> gen(std::mt19937::result_type seed, std::size_t nbNodes) const; }; std::string nodePtrToType(NodePtr node); @@ -61,6 +61,7 @@ std::set<std::string> nodePtrTo(const std::set<NodePtr>& nodes, std::vector<std::pair<std::string, IOIndex_t>> nodePtrTo( const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes, std::string(*nodeTo)(NodePtr) = nodePtrToType); -} + +} // namespace Aidge #endif /* AIDGE_CORE_GRAPH_TESTING_H_ */ diff --git a/include/aidge/graphRegex/GraphStrInterpreter.hpp b/include/aidge/graphRegex/GraphStrInterpreter.hpp index 98dca0e9f84de0be2614aed0e47c9d86ae552674..38e89b3733e1a07062661fa520485f92fbd7f026 100644 --- a/include/aidge/graphRegex/GraphStrInterpreter.hpp +++ b/include/aidge/graphRegex/GraphStrInterpreter.hpp @@ -1,7 +1,6 @@ #ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ #define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ -#include <iostream> #include <sstream> #include <memory> #include <algorithm> diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 4e91182686f436718e97a739cf50da6f25c31608..3ecedf72756a43ed011ac7063dddedef2e81dae2 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -30,7 +30,7 
@@ namespace Aidge { class Add_Op : public OperatorTensor, public Registrable<Add_Op, std::string, std::unique_ptr<OperatorImpl>(const Add_Op&)> { public: - static constexpr const char* Type = "Add"; + static const std::string Type; Add_Op(const IOIndex_t nbIn) : OperatorTensor(Type, nbIn, 0, 1) @@ -81,10 +81,10 @@ public: mOutputs[0]->setBackend(name, device); } - static const std::vector<std::string> getInputsName(){ + static const std::vector<std::string> getInputsName() { return {"data_input_0", "data_input_n"}; } - static const std::vector<std::string> getOutputsName(){ + static const std::vector<std::string> getOutputsName() { return {"data_output"}; } }; diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index e3202f76bdfe6d9fa1591880faec300ca8eee614..2d550d1734f7d68d9061dd67d9b89984e70c8509 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -36,7 +36,7 @@ class AvgPooling_Op : public OperatorTensor, std::array<DimSize_t, DIM>> { public: - static constexpr const char *Type = "AvgPooling"; + static const std::string Type; AvgPooling_Op() = delete; @@ -147,6 +147,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string AvgPooling_Op<DIM>::Type = "AvgPooling"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index e9e1c770447fa42bb93b395654f0688cd039e702..076739198a816e47990b9a594ef9703fb39a4302 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -33,7 +33,7 @@ class BatchNorm_Op : public OperatorTensor, public Registrable<BatchNorm_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const BatchNorm_Op<DIM> &)>, public StaticAttributes<BatchNormAttr, float, float> { public: - static constexpr const char *Type = 
"BatchNorm"; + static const std::string Type; BatchNorm_Op() = delete; @@ -82,9 +82,9 @@ public: associated &= !(getInput(i)->empty()); } if (associated) { - const DimSize_t nbChannels = getInput(0)->dims()[1]; + const DimSize_t nbFeatures = getInput(0)->dims()[1]; for (std::size_t i = nbData(); i < nbInputs(); ++i) { - if(getInput(i)->size() != nbChannels) { + if(getInput(i)->size() != nbFeatures) { // /!\ Input size should be handled BEFORE calling this function // This should raise an error getInput(i)->resize(std::array<DimSize_t, 1>({getInput(0)->dims()[1]})); @@ -123,16 +123,20 @@ public: } }; +template <DimIdx_t DIM> +const std::string BatchNorm_Op<DIM>::Type = "BatchNorm"; + template <DimSize_t DIM> -inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F, +inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures, + const float epsilon = 1.0e-5F, const float momentum = 0.1F, const std::string& name = "") { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported"); auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name); - addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale"); - addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift"); - addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean"); - addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance"); + addProducer(batchNorm, 1, std::array<DimSize_t,1>({nbFeatures}), "scale"); + addProducer(batchNorm, 2, std::array<DimSize_t,1>({nbFeatures}), "shift"); + addProducer(batchNorm, 3, std::array<DimSize_t,1>({nbFeatures}), "batch_mean"); + addProducer(batchNorm, 4, std::array<DimSize_t,1>({nbFeatures}), "batch_variance"); return batchNorm; } } // namespace Aidge diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 9d0eed63d27644b3e1cef0a5e1e8144f91aa1784..ca91172f65f5509cf24f32ac463f96474b292e3e 100644 --- 
a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -32,7 +32,7 @@ class Concat_Op : public OperatorTensor, public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>, public StaticAttributes<ConcatAttr, DimSize_t> { public: - static constexpr const char* Type = "Concat"; + static const std::string Type; using Attributes_ = StaticAttributes<ConcatAttr, DimSize_t>; template <ConcatAttr e> diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index fc16359aafd8ff3376b424fcd7e981fecb25a8bb..1ebda2c599023c300e258c2c45123d23a478a351 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -36,7 +36,7 @@ class Conv_Op : public OperatorTensor, DimSize_t, std::array<DimSize_t, DIM>> { public: - static constexpr const char *Type = "Conv"; + static const std::string Type; Conv_Op() = delete; @@ -196,6 +196,9 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel } }; +template <DimIdx_t DIM> +const std::string Conv_Op<DIM>::Type = "Conv"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> Conv(DimSize_t inChannels, DimSize_t outChannels, diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index e03524643763e4e03edf91a51f6e41aafe2fe972..c97bbd21e664c0365b081da4e57dd3200e37ef8c 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -37,7 +37,7 @@ class ConvDepthWise_Op : public OperatorTensor, DimSize_t, std::array<DimSize_t, DIM>> { public: - static constexpr const char *Type = "ConvDepthWise"; + static const std::string Type; ConvDepthWise_Op() = delete; @@ -190,6 +190,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string ConvDepthWise_Op<DIM>::Type = "ConvDepthWise"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> ConvDepthWise(const DimSize_t nbChannels, const 
std::array<DimSize_t, DIM> &kernelDims, diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp index 5bb0efd803f64f549b34fe775bdb2f42590f38f6..323f3058bdf9e6311bbfdf93f6675ee651efa27a 100644 --- a/include/aidge/operator/Div.hpp +++ b/include/aidge/operator/Div.hpp @@ -29,7 +29,7 @@ class Div_Op : public OperatorTensor, public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> { public: - static constexpr const char* Type = "Div"; + static const std::string Type; Div_Op() : OperatorTensor(Type, 2, 0, 1) {} diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 2e7b3a22f333c103e1ebda632ff798cbd47e5b46..545e923fb08a8d71077340da2b0d2b3f052abc4b 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -35,7 +35,7 @@ class FC_Op : public OperatorTensor, std::unique_ptr<OperatorImpl>(const FC_Op &)>, public StaticAttributes<FCAttr, DimSize_t, bool> { public: - static constexpr const char* Type = "FC"; + static const std::string Type; FC_Op() = delete; diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 56e57a14ac2a22c91fd9f300c9e85fda998285fc..6adf031051b554878fc165c59f2aff0c81e35a9a 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -36,7 +36,7 @@ private: ComputeDimsFunc mComputeOutputDims; public: - GenericOperator_Op(const char *type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut) + GenericOperator_Op(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut) : OperatorTensor(type, nbData, nbParam, nbOut) {} @@ -125,7 +125,7 @@ public: * @param name (optional) name of the Operator. * @return std::shared_ptr<Node> Node associated with the Generic Operator. 
*/ -inline std::shared_ptr<Node> GenericOperator(const char *type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut, +inline std::shared_ptr<Node> GenericOperator(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut, const std::string& name = "") { return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbData, nbParam, nbOut), name); } diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp index 55d7f492ffd1e64ace2ab820749356144d675e9c..5f13e1d3a0a75e449b8daa8d1ea1c72ac0fe3e51 100644 --- a/include/aidge/operator/Identity.hpp +++ b/include/aidge/operator/Identity.hpp @@ -37,7 +37,7 @@ namespace Aidge { class Identity_Op : public OperatorTensor, public Registrable<Identity_Op, std::string, std::unique_ptr<OperatorImpl>(const Identity_Op&)> { public: - static constexpr const char* Type = "Identity"; + static const std::string Type; Identity_Op() : OperatorTensor(Type, 1, 0, 0) @@ -105,9 +105,11 @@ public: } void setBackend(const std::string& name, int device = 0) override final { // setBackend do nothing, Identity node has no backend it just pass the same Tensor + (void) name; } void setDataType(const DataType& dataType) const override final { // setDatatype do nothing, Identity node has no backend it just pass the same Tensor + (void) dataType; } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index a8bc0a477c823069318bc8f392140a99ffbc7bd7..b8e95b07d81b68aa865e55cba55c7c49c061f63b 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -33,7 +33,7 @@ class LeakyReLU_Op : public OperatorTensor, public Registrable<LeakyReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const LeakyReLU_Op&)>, public StaticAttributes<LeakyReLUAttr, float> { public: - static constexpr const char* Type = "LeakyReLU"; + static const std::string Type; LeakyReLU_Op() = delete; 
diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index 3d1ed900a171caf340e953063ed397fbbb175904..5b733f6a57edb08cc35f912960398486f48acd27 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -35,7 +35,7 @@ class MatMul_Op : public OperatorTensor, std::unique_ptr<OperatorImpl>(const MatMul_Op &)>, public StaticAttributes<MatMulAttr, DimSize_t> { public: - static constexpr const char* Type = "MatMul"; + static const std::string Type; MatMul_Op() = delete; diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index 4494371803825e5380c54fef7741084a97c5a975..1cfa29b949ee4d2ebf5293069b553a5f829ffb39 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -36,7 +36,7 @@ class MaxPooling_Op : public OperatorTensor, std::array<DimSize_t, DIM>, bool> { public: - static constexpr const char *Type = "MaxPooling"; + static const std::string Type; MaxPooling_Op() = delete; @@ -117,6 +117,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string MaxPooling_Op<DIM>::Type = "MaxPooling"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> MaxPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 3427890b048b74c7ba1968c6fe3ea7369883c1c6..ba1ed5f16043364a14993420868a1974f5785598 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -40,7 +40,7 @@ public: /** * @brief Clone the operator using its copy-constructor. 
- * @see Operator::MatMul_Op + * @see Operator::MetaOperator_Op */ std::shared_ptr<Operator> clone() const override { return std::make_shared<MetaOperator_Op>(*this); diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp index 41049c43f110975cc2f2eda0a7f0d4c92c81c1ad..33011ed1e25d5a1e78326f6ccff161c9299ee9b4 100644 --- a/include/aidge/operator/Mul.hpp +++ b/include/aidge/operator/Mul.hpp @@ -31,7 +31,7 @@ namespace Aidge { class Mul_Op : public OperatorTensor, public Registrable<Mul_Op, std::string, std::unique_ptr<OperatorImpl>(const Mul_Op&)> { public: - static constexpr const char* Type = "Mul"; + static const std::string Type; Mul_Op() : OperatorTensor(Type, 2, 0, 1) {} diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index c20fffb9ca1d63aded66841b17aad12444e07a6a..ffd627ee7eca17c4870dc914c3ff53e2a5a24ec3 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -44,7 +44,7 @@ private: public: Operator() = delete; - Operator(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data) + Operator(const std::string& type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data) : mType(type), mOperatorType(operatorType), mNbData(nbData), diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index 126e5d467d0f341a8c5b8c5d16d188ebe92135d0..b956da474311b5863690f5a5e40329e443f1345a 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -40,7 +40,7 @@ protected: public: OperatorTensor() = delete; - OperatorTensor(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam, + OperatorTensor(const std::string& type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut) : Operator(type, nbData, nbParam, 
nbOut, OperatorType::Tensor), mInputs(std::vector<std::shared_ptr<Tensor>>(nbData + nbParam, nullptr)), diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp index 9f49cb9a9db918c232e5590d4f013c899082b80a..11cbc34093d3d4649e1b746f370602ffa9c5712f 100644 --- a/include/aidge/operator/Pad.hpp +++ b/include/aidge/operator/Pad.hpp @@ -37,7 +37,7 @@ class Pad_Op : public OperatorTensor, PadBorderType, double> { public: - static constexpr const char *Type = "Pad"; + static const std::string Type; Pad_Op() = delete; @@ -110,6 +110,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string Pad_Op<DIM>::Type = "Pad"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, 2*DIM> &beginEndTuples, const std::string& name = "", diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index 464c49909e2cb15e2da88e8533ce10459860d875..d89776d2172fe25689593c2036929a746b974376 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -29,7 +29,7 @@ namespace Aidge { class Pow_Op : public OperatorTensor, public Registrable<Pow_Op, std::string, std::unique_ptr<OperatorImpl>(const Pow_Op&)> { public: - static constexpr const char* Type = "Pow"; + static const std::string Type; Pow_Op() : OperatorTensor(Type, 2, 0, 1) {} diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index a6a350a5dc1886e052d7ca566bcf8b31a2d130be..51ce579f6b7e1a9c691bb021d88b5bd77d975459 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -29,7 +29,7 @@ class Producer_Op public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( const Producer_Op &)> { public: - static constexpr const char* Type = "Producer"; + static const std::string Type; template <std::size_t DIM> Producer_Op(const std::array<DimSize_t, DIM>& dims) diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp 
index 8a8f3f8544e8581518850bbfb8b2981f2a039fc5..d6a8c2b6189023f3ca03c85c892e74724eab36d0 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -28,7 +28,7 @@ namespace Aidge { class ReLU_Op : public OperatorTensor, public Registrable<ReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const ReLU_Op&)> { public: - static constexpr const char* Type = "ReLU"; + static const std::string Type; ReLU_Op() : OperatorTensor(Type, 1, 0, 1) {} diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 6c49b784849b766615bb8baf4e8c7b0ba216d489..0770329065e3a0a15939516181428c8c4b2e7986 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -32,7 +32,7 @@ class Scaling_Op : public OperatorTensor, public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, public StaticAttributes<ScalingAttr, float, size_t, bool> { public: - static constexpr const char* Type = "Scaling"; + static const std::string Type; Scaling_Op() = delete; diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index e48de88c8726ba548a84d42abfe4feed2cb6d220..95e0c72eb38587221732e8ec357b1c106d468275 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -31,7 +31,7 @@ class Slice_Op public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, public StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>> { public: - static constexpr const char *Type = "Slice"; + static const std::string Type; Slice_Op() = delete; diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index d4716dc5cb0908c0c64c332f50177280ebf0fa62..913b58cb5347aea9f13b46ef28b27cdfae182756 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -29,7 +29,7 @@ namespace Aidge { class Softmax_Op : public OperatorTensor, public Registrable<Softmax_Op, 
std::string, std::unique_ptr<OperatorImpl>(const Softmax_Op&)> { public: - static constexpr const char* Type = "Softmax"; + static const std::string Type; Softmax_Op() : OperatorTensor(Type, 1, 0, 1) {} diff --git a/include/aidge/operator/Sqrt.hpp b/include/aidge/operator/Sqrt.hpp index b679b3d6e2dd62d5290cc89f39a1025b18e8c37a..b95cdfe85051f59202f84e76c15103d23c9edb93 100644 --- a/include/aidge/operator/Sqrt.hpp +++ b/include/aidge/operator/Sqrt.hpp @@ -34,7 +34,7 @@ public: const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: - static constexpr const char* Type = "Sqrt"; + static const std::string Type; Sqrt_Op() : OperatorTensor(Type, 1, 0, 1) {} diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp index 7eb7c8f9b1756d061b7abe30dc002785fc2c7573..9b84cf3d28c4e5ce0ed9b84c128ef3c05c1c6f35 100644 --- a/include/aidge/operator/Sub.hpp +++ b/include/aidge/operator/Sub.hpp @@ -34,7 +34,7 @@ public: const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: - static constexpr const char* Type = "Sub"; + static const std::string Type; Sub_Op() : OperatorTensor(Type, 2, 0, 1) {} diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index ff0b9e0dfcb0d1c5e5567a938b1ca74faf242bed..411a2e1b6ae78065a79b92f25c23dac13e341997 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -25,7 +25,7 @@ void declare_BatchNormOp(py::module& m) { .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); - m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); + m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nbFeatures"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); } void 
init_BatchNorm(py::module &m) { diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp index 71231b8218ac6af28c97ec29039301bc25b2d195..2200cd3fec1450011d6e0b5197f8b99b4dfeb4c3 100644 --- a/python_binding/operator/pybind_Conv.cpp +++ b/python_binding/operator/pybind_Conv.cpp @@ -11,7 +11,6 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> -#include <iostream> #include <string> #include <vector> #include <array> diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 8ea10427ca97f6435c370a843590302dd12eccd7..9b788fc5e972c1020994b31d5fc89459d0f4a915 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -74,7 +74,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { std::string givenName = (node_ptr->name().empty()) ? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>" - : node_ptr->name() + " <sub><em>" + currentType + "</em></sub>"; + : "\"" + node_ptr->name() + "\\n<sub><em>( " + currentType + "#" + std::to_string(typeCounter[currentType]) + " )</em></sub>\""; namePtrTable[node_ptr] = (currentType + "_" + std::to_string(typeCounter[currentType])); @@ -117,14 +117,14 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { size_t inputIdx = 0; for (auto input : mInputNodes) { - std::fprintf(fp, "input%lu((in#%lu)):::inputCls-->|→%u|%s\n", inputIdx, inputIdx, + std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|→%u|%s\n", inputIdx, inputIdx, input.second, namePtrTable[input.first].c_str()); ++inputIdx; } size_t outputIdx = 0; for (auto output : mOutputNodes) { - std::fprintf(fp, "%s-->|%u→|output%lu((out#%lu)):::outputCls\n", + std::fprintf(fp, "%s--->|%u→|output%lu((out#%lu)):::outputCls\n", namePtrTable[output.first].c_str(), output.second, outputIdx, outputIdx); ++outputIdx; @@ -694,128 +694,151 @@ void Aidge::GraphView::insertParent(NodePtr childNode, add(newParentNode); } - bool Aidge::GraphView::replace(const 
std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) { - // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // How to distinguish it from data input? // TODO: Parameter Tensors could be identified with their dimensions // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. // It also avoids specifying each producer since they are automatically included + // (1) create GraphViews from both sets of Nodes auto oldG = std::make_shared<GraphView>("oldG"); oldG->add(oldNodes, false); auto newG = std::make_shared<GraphView>("newG"); newG->add(newNodes, false); - if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) { - return false; - } - if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) || - (newG->outputNodes().size() != 1))) { - return false; + const auto oldOI = oldG->getOrderedInputs(); + const auto oldOO = oldG->getOrderedOutputs(); + const auto newOI = newG->getOrderedInputs(); + const auto newOO = newG->getOrderedOutputs(); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOI.size()); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOO.size()); + + // keep in memory every parent + for (std::size_t i = 0; i < oldOI.size(); ++i) { + auto inputParent = oldOI[i].first -> input(oldOI[i].second); + inputParents[i]= inputParent; + // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); } - - // there is at least one inputNode in the old/new GraphView - std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin()); - std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin()); - - // find Node to link to new input Node - //compute number of input for firstPreviousInputNode 
not in oldNodes set - std::size_t nbExternalInputs = 0; - std::shared_ptr<Node> externalInput = nullptr; - IOIndex_t externalInputId = gk_IODefaultIndex; - for (const auto& input : firstPreviousInputNode->inputs()) { - if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG - nbExternalInputs++; - externalInput = input.first; - externalInputId = input.second; + for (std::size_t i = 0; i < oldOO.size();) { + auto outputChildList = oldOO[i].first -> output(oldOO[i].second); + if (outputChildList.empty()) { + outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); + ++i; } - } - if (nbExternalInputs > 1) { - AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set"); - } - - if (oldG->inputNodes().size() > 1){ - // one or no input has been identified. Checking every input points to the same source - for (const auto& previousInputNode : oldG->inputNodes()) { - for (const auto& input : previousInputNode->inputs()) { - if (oldNodes.find(input.first) == oldNodes.end()) { - if ( (externalInput != input.first) || (externalInputId != input.second) ) { - return false; // an inputNode points to an external Node different from the registered one - } + else { + for (const auto& child : outputChildList) { + if (oldNodes.find(child.first) == oldNodes.cend()) { + outputChildren[i] = child; + ++i; } } } } - if (firstPreviousOutputNode->nbOutputs() != 1) { - return false; - } - - // find Node to replicate output connections - std::shared_ptr<Node> newOutputNode = newNodes.empty() ? 
externalInput : *(newG->outputNodes().begin()); - - auto copyOutputs = firstPreviousOutputNode->outputs(); - // manage Views for newNodes // only keep common views to each node for the new set + // set of common GraphView for oldNodes' Nodes std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); for (const auto& nodePtr : oldNodes) { - const auto nodeView = nodePtr->views(); - std::set<std::shared_ptr<GraphView>> intersection; - std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), - nodeView.begin(), nodeView.end(), - std::inserter(intersection, intersection.begin())); - commonGraphViews = intersection; + const auto nodeView = nodePtr->views(); + std::set<std::shared_ptr<GraphView>> intersection; + std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), + nodeView.begin(), nodeView.end(), + std::inserter(intersection, intersection.begin())); + commonGraphViews = intersection; } commonGraphViews.erase(oldG); commonGraphViews.erase(newG); - // clean Nodes to replace - // Do not include common Nodes to avoid cleaning Producers linked to newNodes - std::set<std::shared_ptr<Node>> nodesToClean; - std::set_difference(oldNodes.begin(), oldNodes.end(), - newNodes.begin(), newNodes.end(), - std::inserter(nodesToClean, nodesToClean.begin())); - for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } - - // copy output connections - if (newOutputNode) { - for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) { - auto outputPairs = copyOutputs[o]; - for (const auto& onePair : outputPairs) { - newOutputNode->addChild(onePair.first, o, onePair.second); - } + if ((newNodes.size() > 0) && (oldOI.size() != newOI.size()) && (oldOO.size() != newOO.size())) { + for (const auto& nodePtr : oldNodes) { + nodePtr->removeView(oldG); + } + for (const auto& nodePtr : newNodes) { + nodePtr->removeView(newG); + } + return false; + } + + for (const auto& nodePtr : oldNodes) { + for (const auto& g : 
commonGraphViews) { + g -> remove(nodePtr, false); + g -> updateInputsOutputsDelete(nodePtr); } + nodePtr -> resetConnections(true); } - // copy input connections - if (!newNodes.empty() && externalInput) { - for (const auto& newInputNode : newG->inputNodes()) { - IOIndex_t inputId = 0; - for (const auto& input : newInputNode->inputs()) { - if (newNodes.find(input.first) == newNodes.end()) { - externalInput->addChild(newInputNode, externalInputId, inputId); + if ((oldOI.size() == newOI.size()) && + (oldOO.size() == newOO.size())) { + // Case 1 + for (std::size_t i = 0; i < oldOI.size(); ++i) { + if (inputParents[i].first) { + inputParents[i].first -> addChild(newOI[i].first, inputParents[i].second, newOI[i].second); + } + } + for (std::size_t o = 0; o < oldOO.size(); ++o) { + if (outputChildren[o].first) { + newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); + } + } + } + else { + // get the number of Parents for oldG->inputNodes() + // get the number of Children for oldG->outputNodes() + if (newNodes.size() == 0) { + // Case 3 + if (oldOI.size() == oldOO.size()) { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + if (inputParents[i].first) + inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); + } + } + else if (oldOI.size() == 1) { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second); + } + } + } + else if ( // for tiling-like cases. 
The number of inputNodes changes but not outputNodes + ((oldOI.size() == 1) || (newOI.size() == 1)) && // (oldOI.size() == newOI.size()) already handled in Case 1 + ((oldOO.size() == newOO.size())) + ) { + // Case 2 + if ((oldOI.size() == 1)) { + for (std::size_t i = 0; i < newOI.size(); ++i) { + inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second); + } + } else { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second); + } + } + for (std::size_t o = 0; o < oldOO.size(); ++o) { + if (outputChildren[o].first) { + newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); } - inputId++; } } + else { + for (const auto& nodePtr : oldNodes) { + nodePtr->removeView(oldG); + } + for (const auto& nodePtr : newNodes) { + nodePtr->removeView(newG); + } + return false; + } } - - // insert new Nodes in the right GraphViews - for (const auto& graphPtr : commonGraphViews) { - graphPtr->add(newNodes, false); - if (newNodes.empty()) { - // TODO: FIXME: this function should not be called anymore! - graphPtr->updateInputsOutputsNodes_DEPRECATED(); + for (const auto& nodePtr : newNodes) { + for (const auto& g : commonGraphViews) { + g -> add(nodePtr); } } - - for (const auto& node : oldNodes) { - node->removeView(oldG); + for (const auto& nodePtr : oldNodes) { + nodePtr -> removeView(oldG); } - for (const auto& node : newNodes) { - node->removeView(newG); + for (const auto& nodePtr : newNodes) { + nodePtr -> removeView(newG); } return true; } @@ -824,22 +847,22 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { // Can be called several times with the same node, e.g. when addChild() is // called on a node already part of the GraphView. In this case, inputs/outputs // need to be updated! 
- std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end(); + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend(); // Remove inputs that are not input anymore because connected to newNode for (auto orderedChilds : newNode->getOrderedChildren()) { for (auto ch_ptr : orderedChilds) { // Check that newNode child is in current GraphView - if (mNodes.find(ch_ptr) != mNodes.end()) { + if (mNodes.find(ch_ptr) != mNodes.cend()) { IOIndex_t inputIdx = 0; for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { // If newNode is connected to it if (pa_ptr == newNode) { const auto val = std::make_pair(ch_ptr, inputIdx); - const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val); + const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); // Check that it was not already the case (if node UPDATE) - if (iter != mInputNodes.end()) { + if (iter != mInputNodes.cend()) { // newNode is linked to an actual inputNode to an input connection // The first old (removed) input becomes the insertion point for newNode GraphView inputs if (std::distance(newInputsInsertionPoint, iter) <= 0) { newInputsInsertionPoint = mInputNodes.erase(iter); @@ -855,55 +878,45 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { } } - // Check if node inputs are inputs for the GraphView and add them to the input list if so - // Inputs addition order follows node inputs order - // Inputs are inserted at the position of the first input removed - IOIndex_t inputIdx = 0U; - for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { - if ((pa_ptr == nullptr) || - (mNodes.find(pa_ptr) == - mNodes.end())) { // Parent doesn't exist || Parent not in the graph - const auto val = std::make_pair(newNode, inputIdx); - if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { - newInputsInsertionPoint = 
mInputNodes.insert(newInputsInsertionPoint, val); - newInputsInsertionPoint = std::next(newInputsInsertionPoint); - } - } - ++inputIdx; - } - - // (if node UPDATE) - // newNode may already exists in the graph and may have been updated - // Check and remove inputs that are not inputs anymore - inputIdx = 0U; - for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { - if ((pa_ptr != nullptr) && - (mNodes.find(pa_ptr) != - mNodes.end())) { - const auto val = std::make_pair(newNode, inputIdx); - auto it = std::find(mInputNodes.begin(), mInputNodes.end(), val); - if (it != mInputNodes.end()) { - mInputNodes.erase(it); - } + // Manage newNode parents + // Check if any input connection is an input for the GraphView + IOIndex_t inputIdx = 0U; + for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { + const auto val = std::make_pair(newNode, inputIdx); + const auto it = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); + if ((pa_ptr == nullptr) || + (mNodes.find(pa_ptr) == mNodes.cend())) { + // Parent doesn't exist || Parent not in the graph + if (it == mInputNodes.cend()) { + // If node's inputs are inputs for the GraphView: add them to the input list + // Addition rule: + // - Inputs addition order follows node inputs order + // - Inputs are inserted at the position of the first input removed + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + newInputsInsertionPoint = std::next(newInputsInsertionPoint); + } + } else if (it != mInputNodes.cend()) { + // Parent already in the graph SO edge is not an input anymore for the graph + mInputNodes.erase(it); + } + ++inputIdx; } - ++inputIdx; - } - std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end(); + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend(); // Remove outputs that are not output anymore because connected to newNode for (const std::shared_ptr<Node>& parent : 
newNode->getParents()) { // Check that newNode parent is in current GraphView - if (mNodes.find(parent) != mNodes.end()) { + if (mNodes.find(parent) != mNodes.cend()) { IOIndex_t outputIdx = 0; for (auto orderedChilds : parent->getOrderedChildren()) { for (auto ch_ptr : orderedChilds) { // If newNode is connected to it if (ch_ptr == newNode) { const auto val = std::make_pair(parent, outputIdx); - const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val); + const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val); - if (iter != mOutputNodes.end()) { + if (iter != mOutputNodes.cend()) { // The first old (removed) output becomes the insertion point for newNode GraphView outputs if (std::distance(newOutputsInsertionPoint, iter) <= 0) { newOutputsInsertionPoint = mOutputNodes.erase(iter); @@ -943,14 +956,14 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { } void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNode) { - std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end(); + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend(); // Check if node inputs were inputs for the GraphView and remove them from the list if so for (IOIndex_t inputIdx = 0; inputIdx < deletedNode->getParents().size(); ++inputIdx) { const auto val = std::make_pair(deletedNode, inputIdx); - const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val); + const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); - if (iter != mInputNodes.end()) { + if (iter != mInputNodes.cend()) { // The first old (removed) input becomes the insertion point for new GraphView inputs if (std::distance(newInputsInsertionPoint, iter) <= 0) { newInputsInsertionPoint = mInputNodes.erase(iter); @@ -966,13 +979,13 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo for (auto 
orderedChilds : deletedNode->getOrderedChildren()) { for (auto ch_ptr : orderedChilds) { // Check that deletedNode child is in current GraphView - if (mNodes.find(ch_ptr) != mNodes.end()) { + if (mNodes.find(ch_ptr) != mNodes.cend()) { IOIndex_t inputIdx = 0; for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { // If newNode was connected to it if (pa_ptr == deletedNode) { const auto val = std::make_pair(ch_ptr, inputIdx); - if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { + if (std::find(mInputNodes.cbegin(), mInputNodes.cend(), val) == mInputNodes.cend()) { newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); newInputsInsertionPoint = std::next(newInputsInsertionPoint); } @@ -982,15 +995,15 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo } } } - - std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end(); + + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend(); // Check if node outputs were outputs for the GraphView and remove them from the list if so for (IOIndex_t outputIdx = 0; outputIdx < deletedNode->getOrderedChildren().size(); ++outputIdx) { const auto val = std::make_pair(deletedNode, outputIdx); - const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val); + const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val); - if (iter != mOutputNodes.end()) { + if (iter != mOutputNodes.cend()) { // The first old (removed) output becomes the insertion point for newNode GraphView outputs if (std::distance(newOutputsInsertionPoint, iter) <= 0) { newOutputsInsertionPoint = mOutputNodes.erase(iter); @@ -1004,7 +1017,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo // Add parent node outputs that become GraphView output following the removal of the node // Outputs addition order follows deletedNode 
inputs order for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) { - if (parent != nullptr && mNodes.find(parent) != mNodes.end()) { + if (mNodes.find(parent) != mNodes.end()) { IOIndex_t outputIdx = 0; for (auto orderedChilds : parent->getOrderedChildren()) { bool noInsideConnection = true; @@ -1017,7 +1030,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo if (noInsideConnection) { const auto val = std::make_pair(parent, outputIdx); - if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { + if (std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val) == mOutputNodes.cend()) { newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); } @@ -1028,40 +1041,6 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo } } -void Aidge::GraphView::updateInputsOutputsNodes_DEPRECATED() { - mInputNodes.clear(); - for (const std::shared_ptr<Node>& go_ptr : mNodes) { - IOIndex_t inputIdx = 0; - for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) { - if ((pa_ptr == nullptr) || - (mNodes.find(pa_ptr) == - mNodes.end())) { // Parent doesn't exist || Parent not in the graph - mInputNodes.push_back(std::make_pair(go_ptr, inputIdx)); - } - - ++inputIdx; - } - } - - mOutputNodes.clear(); - for (const std::shared_ptr<Node>& go_ptr : mNodes) { - IOIndex_t outputIdx = 0; - for (auto orderedChilds : go_ptr->getOrderedChildren()) { - bool noInsideConnection = true; - for (auto ch_ptr : orderedChilds) { - if (mNodes.find(ch_ptr) != mNodes.end()) { - noInsideConnection = false; - break; - } - } - - if (noInsideConnection) { - mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx)); - } - ++outputIdx; - } - } -} std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const { std::shared_ptr<GraphView> newGraph = 
std::make_shared<GraphView>(mName); diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 5a7b05e469daab10a4abd468177a3ad137096f63..6f0cc55159b1cc72b87bb34230376eb140b7ab8a 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -11,22 +11,25 @@ #include "aidge/graph/Node.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/operator/Producer.hpp" #include <memory> #include <vector> + +#include "aidge/graph/GraphView.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" #include "aidge/utils/Types.h" Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) : mName(name), mOperator(op), - mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), - mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), - std::vector<std::weak_ptr<Node>>())), - mIdInChildren( - std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), - mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { + mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), + nullptr)), + mChildren(std::vector<std::vector<std::weak_ptr<Node>>>( + static_cast<std::size_t>(op->nbOutputs()), std::vector<std::weak_ptr<Node>>())), + mIdInChildren(std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), + std::vector<IOIndex_t>())), + mIdOutParents( + std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { // ctor } @@ -34,14 +37,15 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// -Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { +Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) { assert((ctors.size() == nbData()) && "Wrong 
number of arguments.\n"); - for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) { - assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); - (void) input; // avoid unused warning + for (std::pair<std::shared_ptr<Node>, IOIndex_t>& input : inputs()) { + assert((gk_IODefaultIndex == input.second) && + "At least one input connection is not free.\n"); + (void)input; // avoid unused warning } IOIndex_t i = 0; - for (const Connector &ctor : ctors) { + for (const Connector& ctor : ctors) { if (ctor.node() != nullptr) { // ctor must be associated with a node ctor.node()->addChild(shared_from_this(), ctor.index(), i++); } @@ -53,7 +57,7 @@ Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { // INNER /////////////////////////////////////////////////////// -void Aidge::Node::setName(const std::string &name) { mName = name; } +void Aidge::Node::setName(const std::string& name) { mName = name; } /////////////////////////////////////////////////////// // OPERATORS @@ -92,8 +96,8 @@ Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const { return nbFreeDataIn; } -std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::dataInputs() const { +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs() + const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData()); for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) { @@ -104,15 +108,15 @@ Aidge::Node::dataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); for (std::size_t i = 0; i < nbInputs(); ++i) { - res[i] = - 
std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); + res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); } return res; } -// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) { +// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> +// tensor) { // assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); // if (mParents[idx] != nullptr) { // mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]); @@ -128,20 +132,21 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::Node::outputs() const { std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs = - std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(mIdInChildren.size()); + std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>( + mIdInChildren.size()); for (std::size_t i = 0; i < mIdInChildren.size(); ++i) { listOutputs[i] = output(static_cast<IOIndex_t>(i)); } return listOutputs; } -std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::output(Aidge::IOIndex_t outId) const { +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::output( + Aidge::IOIndex_t outId) const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size()); for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) { - listOutputs[i] = - std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]); + listOutputs[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), + mIdInChildren[outId][i]); } return listOutputs; } @@ -180,7 +185,8 @@ void 
Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) // TOPOLOGY /////////////////////////////////////////////////////// -void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) { +void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, + const IOIndex_t otherInId) { assert((otherInId < otherNode->nbInputs()) && "Input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); if (otherNode->input(otherInId).second != gk_IODefaultIndex) { @@ -196,33 +202,41 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou } void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId, - std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { - assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound."); + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + assert((otherInId.second < otherInId.first->nbInputs()) && + "Other graph input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes(); if (inNodes.size() == std::size_t(0)) { // no input Node printf("Cannot add GraphView to the Node. No input node detected.\n"); } else // inNodes.size() >= 1 { - assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node + assert((inNodes.find(otherInId.first) != + inNodes.end())); // assert it really is an input node addChildOp(otherInId.first, outId, otherInId.second); } } -void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) { - otherInId = (otherInId != gk_IODefaultIndex) ? 
otherInId : otherNode->getFirstFreeDataInput(); - addChildOp(otherNode, outId, otherInId); +void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, + IOIndex_t otherInId) { + if (otherNode) { + otherInId = + (otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput(); + addChildOp(otherNode, outId, otherInId); + } } void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId, - std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { if (!otherInId.first) { assert((otherView->inputNodes().size() == 1U) && "Specify an input Node for the GraphView. More or less than one " "Node is not explicit."); otherInId.first = *(otherView->inputNodes().begin()); } - otherInId.second = (otherInId.second != gk_IODefaultIndex) ? otherInId.second : otherInId.first->getFirstFreeDataInput(); + otherInId.second = (otherInId.second != gk_IODefaultIndex) + ? otherInId.second + : otherInId.first->getFirstFreeDataInput(); addChildView(otherView, outId, otherInId); } @@ -255,8 +269,8 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) { std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { std::set<std::shared_ptr<Node>> children; - for (const auto &childrenOfOneOutput : mChildren) { - for (const auto &oneChild : childrenOfOneOutput) { + for (const auto& childrenOfOneOutput : mChildren) { + for (const auto& oneChild : childrenOfOneOutput) { children.insert(oneChild.lock()); } } @@ -264,7 +278,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { } std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { - std::vector<std::vector<std::shared_ptr<Node>>> children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); + std::vector<std::vector<std::shared_ptr<Node>>> children = + std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); for 
(std::size_t outId = 0; outId < mChildren.size(); ++outId) { children[outId] = getChildren(outId); } @@ -273,14 +288,16 @@ std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedCh std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const { assert((outId < nbOutputs()) && "Output index out of bound."); - std::vector<std::shared_ptr<Node>> children = std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); + std::vector<std::shared_ptr<Node>> children = + std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); for (std::size_t i = 0; i < mChildren[outId].size(); ++i) { - children.push_back(mChildren[outId][i].lock()); - } + children.push_back(mChildren[outId][i].lock()); + } return children; } -bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) { +bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, + const Aidge::IOIndex_t outId) { assert((outId < nbOutputs()) && "Child index out of bound."); bool removed = false; for (std::size_t j = 0; j < mChildren[outId].size(); ++j) { @@ -301,7 +318,8 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i); if (parent.first) { // number of children linked to the parent's output - while(parent.first->removeChild(shared_from_this(), parent.second) == true) {} + while (parent.first->removeChild(shared_from_this(), parent.second) == true) { + } } // every reference to this object as child has been removed // removing reference to parents. 
@@ -316,24 +334,23 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { mIdInChildren[i] = std::vector<IOIndex_t>(); } // removing this Node from every GraphView it belongs to - for (auto& graph : views()) { - // if keeping connections with LEarnable Parameters, then also remove them from graph - graph->remove(shared_from_this(), !includeLearnableParam); - } + // for (auto& graph : views()) { + // // if keeping connections with LEarnable Parameters, then also remove them from graph + // graph->remove(shared_from_this(), !includeLearnableParam); + // } } - /////////////////////////////////////////////////////// - // CLONE - /////////////////////////////////////////////////////// +/////////////////////////////////////////////////////// +// CLONE +/////////////////////////////////////////////////////// Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { return std::make_shared<Node>(mOperator, mName); } Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { - std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type) - ? mOperator - : mOperator->clone(); + std::shared_ptr<Operator> op = + (mOperator->type() == Producer_Op::Type) ? 
mOperator : mOperator->clone(); return std::make_shared<Node>(op, mName); } @@ -342,27 +359,25 @@ Aidge::NodePtr Aidge::Node::clone() const { return std::make_shared<Node>(mOperator->clone(), mName); } - -std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){ - +std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::NodePtr> nodeSee) { std::set<Aidge::NodePtr> out; nodeSee.insert(shared_from_this()); - if(delta == 0) { + if (delta == 0) { out.insert(shared_from_this()); - }else if (delta > 0){ - for (const NodePtr& node : getChildren()) { - if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance - for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ + } else if (delta > 0) { + for (const NodePtr& node : getChildren()) { + if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance + for (const NodePtr& ch : node->getNodeDelta(delta - 1, nodeSee)) { out.insert(ch); } } } - }else{ - for (const NodePtr& node : getParents()) { - if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance - for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ + } else { + for (const NodePtr& node : getParents()) { + if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance + for (const NodePtr& pr : node->getNodeDelta(delta + 1, nodeSee)) { out.insert(pr); } } diff --git a/src/graph/Testing.cpp b/src/graph/Testing.cpp index ad193f000b787c458b29934e8b95c122ef15f409..f30ad6e25b81e1ce7768fcc201ddf00c2226eebf 100644 --- a/src/graph/Testing.cpp +++ b/src/graph/Testing.cpp @@ -9,46 +9,56 @@ * ********************************************************************************/ +#include <algorithm> // std::shuffle, std::transform +#include <cstddef> +#include <memory> +#include <numeric> // std::iota +#include <random> // std::binomial_distribution, std::mt19937, std::discrete_distribution +#include <string> +#include <utility> // std::pair +#include <vector> + #include "aidge/graph/Testing.hpp" #include 
"aidge/operator/GenericOperator.hpp" +#include "aidge/utils/Types.h" -std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomGraph::gen(std::mt19937::result_type seed, size_t nbNodes) const { +std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomGraph::gen(std::mt19937::result_type seed, std::size_t nbNodes) const { std::mt19937 gen(seed); std::binomial_distribution<> dIn(maxIn - 1, avgIn/maxIn); std::binomial_distribution<> dOut(maxOut - 1, avgOut/maxOut); std::binomial_distribution<> dLink(1, density); std::discrete_distribution<> dType(typesWeights.begin(), typesWeights.end()); - std::vector<std::pair<int, int>> nbIOs; + std::vector<std::pair<IOIndex_t, IOIndex_t>> nbIOs; std::vector<std::string> nodesType; - for (size_t i = 0; i < nbNodes; ++i) { + for (std::size_t i = 0; i < nbNodes; ++i) { const auto nbIn = 1 + dIn(gen); nbIOs.push_back(std::make_pair(nbIn, 1 + dOut(gen))); nodesType.push_back(types[dType(gen)]); } - std::vector<int> nodesSeq(nbNodes); - std::iota(nodesSeq.begin(), nodesSeq.end(), 0); + std::vector<std::size_t> nodesSeq(nbNodes); + std::iota(nodesSeq.begin(), nodesSeq.end(), static_cast<std::size_t>(0)); // Don't use gen or seed here, must be different each time! std::shuffle(nodesSeq.begin(), nodesSeq.end(), std::default_random_engine(std::random_device{}())); std::vector<NodePtr> nodes(nbNodes, nullptr); for (auto idx : nodesSeq) { const std::string name = nodesType[idx] + std::to_string(idx); - nodes[idx] = GenericOperator(nodesType[idx].c_str(), nbIOs[idx].first, 0, nbIOs[idx].second, name.c_str()); + nodes[idx] = GenericOperator(nodesType[idx], nbIOs[idx].first, 0, nbIOs[idx].second, name); } - for (size_t i = 0; i < nbNodes; ++i) { - for (size_t j = (acyclic) ? i + 1 : 0; j < nbNodes; ++j) { + for (std::size_t i = 0; i < nbNodes; ++i) { + for (std::size_t j = (acyclic) ? i + 1 : 0; j < nbNodes; ++j) { if (i == j) { // Do not connected node to itself in case of cyclic graph! 
continue; } - for (size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) { - for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) { + for (std::size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) { + for (std::size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) { if (dLink(gen)) { - // Warning: connections can be set multiple time for the + // Warning: connections can be set multiple times for the // same node input! In this case, the previous connection // is overwritten. This is the expected behavior. nodes[i]->addChild(nodes[j], outId, inId); @@ -82,7 +92,7 @@ std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomGraph::gen(std: NodePtr rootNode = nullptr; std::set<NodePtr> nodesSet; - for (size_t i = 0; i < nbNodes; ++i) { + for (std::size_t i = 0; i < nbNodes; ++i) { if (nodes[i]->type() != omitType) { if (rootNode == nullptr) { rootNode = nodes[i]; diff --git a/src/operator/Add.cpp b/src/operator/Add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e638fd86da487565a89760925e45339213fa8f9 --- /dev/null +++ b/src/operator/Add.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Add.hpp" + +const std::string Aidge::Add_Op::Type = "Add"; \ No newline at end of file diff --git a/src/operator/Concat.cpp b/src/operator/Concat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eafcd126480df6da2c0127bdbb896d3ce98d0e0a --- /dev/null +++ b/src/operator/Concat.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Concat.hpp" + +const std::string Aidge::Concat_Op::Type = "Concat"; \ No newline at end of file diff --git a/src/operator/Div.cpp b/src/operator/Div.cpp index 273eac2e8fa9623e617d1be204ac2ae46d8da02d..85db3ac6ef66c837c86dbece288185deaca88ba6 100644 --- a/src/operator/Div.cpp +++ b/src/operator/Div.cpp @@ -11,6 +11,7 @@ #include <cassert> #include <cstddef> +#include <string> #include <vector> #include <utility> @@ -19,6 +20,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Div_Op::Type = "Div"; + void Aidge::Div_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32114f5bf9e0d160db9fdc2d1971481be0b4e703 --- /dev/null +++ b/src/operator/FC.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * 
This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/FC.hpp" + +const std::string Aidge::FC_Op::Type = "FC"; \ No newline at end of file diff --git a/src/operator/Identity.cpp b/src/operator/Identity.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f57906dd4f3564b52cde16236bda87370e8f86d7 --- /dev/null +++ b/src/operator/Identity.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Identity.hpp" + +const std::string Aidge::Identity_Op::Type = "Identity"; \ No newline at end of file diff --git a/src/operator/LeakyReLU.cpp b/src/operator/LeakyReLU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32e050ee1595cf83b5cd0ffbfeba6153dc2243af --- /dev/null +++ b/src/operator/LeakyReLU.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/LeakyReLU.hpp" + +const std::string Aidge::LeakyReLU_Op::Type = "LeakyReLU"; \ No newline at end of file diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..666ed3921ed1190a91935bd9f38303e23963d912 --- /dev/null +++ b/src/operator/MatMul.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/MatMul.hpp" + +const std::string Aidge::MatMul_Op::Type = "MatMul"; \ No newline at end of file diff --git a/src/operator/Mul.cpp b/src/operator/Mul.cpp index 2e3e77288bf1e0613f0aa572e3c50e94599a902f..bc268263e8a6e2ec7c9944faa31da84dc50c4f53 100644 --- a/src/operator/Mul.cpp +++ b/src/operator/Mul.cpp @@ -19,6 +19,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Mul_Op::Type = "Mul"; + void Aidge::Mul_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/src/operator/Pow.cpp b/src/operator/Pow.cpp index c213a47a4a590026c07625aeb532d303ca8dbced..de1f0c3694f51fbd5b365573f61d3e3e2b9109ff 100644 --- a/src/operator/Pow.cpp +++ b/src/operator/Pow.cpp @@ -19,6 +19,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Pow_Op::Type = "Pow"; + void Aidge::Pow_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) 
|| !getInput(1)) { diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..443f2fa7d8a60cd25ccb622f2dad5b4926b88eea --- /dev/null +++ b/src/operator/Producer.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Producer.hpp" + +const std::string Aidge::Producer_Op::Type = "Producer"; \ No newline at end of file diff --git a/src/operator/ReLU.cpp b/src/operator/ReLU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f7874acfe7d865ea8c56d4bca02b51864480df6 --- /dev/null +++ b/src/operator/ReLU.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/ReLU.hpp" + +const std::string Aidge::ReLU_Op::Type = "ReLU"; \ No newline at end of file diff --git a/src/operator/Scaling.cpp b/src/operator/Scaling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c121e1268c1e1a62f793f38c6d816e7c6b48c25 --- /dev/null +++ b/src/operator/Scaling.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Scaling.hpp" + +const std::string Aidge::Scaling_Op::Type = "Scaling"; \ No newline at end of file diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a25e290dff35e4257d486613a5fe06894119d367 --- /dev/null +++ b/src/operator/Slice.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Slice.hpp" + +const std::string Aidge::Slice_Op::Type = "Slice"; \ No newline at end of file diff --git a/src/operator/Softmax.cpp b/src/operator/Softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e88ff4bb4ec6e2cb1357d578c2d07cc4edcb59f7 --- /dev/null +++ b/src/operator/Softmax.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Softmax.hpp" + +const std::string Aidge::Softmax_Op::Type = "Softmax"; \ No newline at end of file diff --git a/src/operator/Sqrt.cpp b/src/operator/Sqrt.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbcaba42619762f8fd00bb2f6e0aa0de11d92960 --- /dev/null +++ b/src/operator/Sqrt.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Sqrt.hpp" + +const std::string Aidge::Sqrt_Op::Type = "Sqrt"; \ No newline at end of file diff --git a/src/operator/Sub.cpp b/src/operator/Sub.cpp index 8175f1b7ae5bb5eccd36267c1d739f764bd3c236..639eaf798c1c2a9a6685e8b8d2c4a2cb00a4b57a 100644 --- a/src/operator/Sub.cpp +++ b/src/operator/Sub.cpp @@ -19,6 +19,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Sub_Op::Type = "Sub"; + void Aidge::Sub_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/unit_tests/graph/Test_Connector.cpp b/unit_tests/graph/Test_Connector.cpp index a3bcc5783bd1de89ebf82d2b6078a26bdd49eaa7..79acce9281039f9f3c67b7235d8999b6c7173685 100644 --- a/unit_tests/graph/Test_Connector.cpp +++ b/unit_tests/graph/Test_Connector.cpp @@ -113,7 +113,7 @@ TEST_CASE("GraphGeneration from Connector", "[GraphView]") { x= (*node09)({x}); x = (*node10)({a, x}); std::shared_ptr<GraphView> gv = generateGraph({x}); - gv->save("GraphGeneration"); + // gv->save("GraphGeneration"); REQUIRE(nodePtrTo(gv->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({})); REQUIRE(nodePtrTo(gv->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_matmul1", 0}})); } @@ -164,7 +164,7 @@ TEST_CASE("Connector connection GraphView", "[Connector]") { std::shared_ptr<GraphView> gv = generateGraph({x}); REQUIRE(nodePtrTo(gv->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({})); REQUIRE(nodePtrTo(gv->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv3", 0}})); - gv->save("MultiInputSequentialConnector"); + // gv->save("MultiInputSequentialConnector"); REQUIRE(gv->inputNodes().size() == 0U); } } @@ -181,7 +181,7 @@ TEST_CASE("Connector Mini-graph", 
"[Connector]") { std::shared_ptr<GraphView> g = generateGraph({y}); REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({})); REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"ElemWise", 0}})); - g->save("TestGraph"); + // g->save("TestGraph"); } TEST_CASE("Structural descrition - Sequential", "[GraphView]") { diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 784a618c8ed38aea527ea460e221fd1ba0082741..acbea04a27a0b6be22105bb73fda53fedf621235 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -9,6 +9,7 @@ * ********************************************************************************/ +#include <algorithm> // std::sort #include <cassert> #include <map> #include <memory> @@ -27,20 +28,7 @@ using namespace Aidge; -class GraphView_Test : public GraphView { -public: - GraphView_Test(const std::string& name="") - : GraphView(name) - { - // ctor - } - - void updateInputsOutputsNodes_DEPRECATED_Test() { - GraphView::updateInputsOutputsNodes_DEPRECATED(); - } -}; - -TEST_CASE("genRandomGraph") { +TEST_CASE("genRandomGraph", "[GraphView][randomGen]") { const size_t nbTests = 100; size_t nbUnicity = 0; @@ -49,13 +37,13 @@ TEST_CASE("genRandomGraph") { const std::mt19937::result_type seed(rd()); RandomGraph randGraph; - const auto g1 = std::make_shared<GraphView_Test>("g1"); + const auto g1 = std::make_shared<GraphView>("g1"); const bool unicity1 = g1->add(randGraph.gen(seed, 10)); const auto g2 = std::make_shared<GraphView>("g2"); const bool unicity2 = g2->add(randGraph.gen(seed, 10)); - g1->save("./genRandomGraph1"); - g2->save("./genRandomGraph2"); + // g1->save("./genRandomGraph1"); + // g2->save("./genRandomGraph2"); REQUIRE(unicity1 == unicity2); @@ -67,8 +55,6 @@ TEST_CASE("genRandomGraph") { REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName)); 
++nbUnicity; - // Test deprecated function - g1->updateInputsOutputsNodes_DEPRECATED_Test(); // Check that inputs/outputs are the same regardless of the order auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName); auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName); @@ -89,7 +75,7 @@ TEST_CASE("genRandomGraph") { printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests); } -TEST_CASE("clone") { +TEST_CASE("clone", "[GraphView][clone]") { const size_t nbTests = 100; for (int test = 0; test < nbTests; ++test) { @@ -99,7 +85,7 @@ TEST_CASE("clone") { RandomGraph randGraph; const auto g1 = std::make_shared<GraphView>("g1"); g1->add(randGraph.gen(seed, 10)); - + // g1 -> save("GraphView_clone"); const auto g2 = g1->clone(); REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); @@ -115,7 +101,7 @@ NodePtr nodeDel(NodePtr node) { return node->clone(); } -TEST_CASE("clone_with_delete") { +TEST_CASE("clone_with_delete", "[GraphView][cloneDelete]") { const size_t nbTests = 100; size_t nbClonedWithDelete = 0; @@ -140,8 +126,8 @@ TEST_CASE("clone_with_delete") { const auto g2 = std::make_shared<GraphView>("g2"); const bool unicity2 = g2->add(randGraph.gen(seed, 10)); - g1->save("./clone_with_delete1"); - g2->save("./clone_with_delete2"); + // g1->save("./clone_with_delete1"); + // g2->save("./clone_with_delete2"); try { const auto gCloned = g1->cloneCallback(&nodeDel); @@ -162,7 +148,7 @@ TEST_CASE("clone_with_delete") { printf("nbClonedWithDelete = %zu/%zu\n", nbClonedWithDelete, nbTests); } -TEST_CASE("remove") { +TEST_CASE("remove", "[GraphView][remove]") { const size_t nbTests = 100; size_t nbTested = 0; @@ -177,13 +163,13 @@ TEST_CASE("remove") { const bool unicity1 = g1->add(randGraph.gen(seed, 10)); if (unicity1) { - g1->save("./remove1_before"); + // g1->save("./remove1_before"); const auto nodes = g1->getNodes(); int step = 1; for (auto node : nodes) { if (node->type() == "DelFictive") { 
g1->remove(node, false); - g1->save("./remove1_after" + std::to_string(step)); + // g1->save("./remove1_after" + std::to_string(step)); step++; } } @@ -192,8 +178,8 @@ TEST_CASE("remove") { const auto g2 = std::make_shared<GraphView>("g2"); g2->add(randGraph.gen(seed, 10)); - g1->save("./remove1"); - g2->save("./remove2"); + // g1->save("./remove1"); + // g2->save("./remove2"); REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); // Order not garanteed, because when a node is removed, it can create new GraphView inputs/outputs @@ -220,14 +206,14 @@ TEST_CASE("remove") { printf("nbTested = %zu/%zu\n", nbTested, nbTests); } -TEST_CASE("[core/graph] GraphView(Constructor)") { +TEST_CASE("[core/graph] GraphView(Constructor)", "[GraphView][constructor()]") { std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1"); REQUIRE(g0 != nullptr); REQUIRE(g1 != nullptr); } -TEST_CASE("[core/graph] GraphView(add)") { +TEST_CASE("[core/graph] GraphView(add)", "[GraphView][add]") { SECTION("Node alone") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1"); @@ -242,7 +228,7 @@ TEST_CASE("[core/graph] GraphView(add)") { g->add(GOp5); std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 1, 1, "Gop6"); g->add(GOp6); - g->save("node_alone"); + // g->save("node_alone"); REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop3", 0}, {"Gop4", 0}, {"Gop5", 0}, {"Gop6", 0}, {"Gop6", 1}})); REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop2", 0}, {"Gop5", 0}, {"Gop6", 0}})); } @@ -422,7 +408,7 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { } } - SECTION("disconnect data iput + learnable parameters") { + SECTION("disconnect data input + learnable 
parameters") { std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); diff --git a/unit_tests/recipies/Test_FuseBatchNorm.cpp b/unit_tests/recipies/Test_FuseBatchNorm.cpp deleted file mode 100644 index 5d9c02d5582e3c56aba9d374d7087946c7d94bde..0000000000000000000000000000000000000000 --- a/unit_tests/recipies/Test_FuseBatchNorm.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ -/* -#include <catch2/catch_test_macros.hpp> -#include <set> - - -//#include "aidge/backend/cpu/operator/BatchNormImpl.hpp" -//#include "aidge/backend/cpu/operator/ConvImpl.hpp" - - - -#include "aidge/operator/Conv.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/graph/OpArgs.hpp" -#include "aidge/operator/BatchNorm.hpp" -#include "aidge/utils/Recipies.hpp" - -//#include "aidge/backend/TensorImpl.hpp" -//#include "aidge/backend/cpu.hpp" -//#include "aidge/" - -#include <cstddef> - - -namespace Aidge { - - - TEST_CASE("[FuseBatchNorm] conv") { - auto g1 = Sequential({ - Producer({16, 3, 224, 224}, "dataProvider"), - Conv(3, 32, {3, 3}, "conv1"), - BatchNorm<2>() - }); - - g1->setDataType(DataType::Float32); - g1->setBackend("cpu"); - g1->forwardDims(); - - // std::set<std::string> availableBackends = Tensor::getAvailableBackends(); - // if (availableBackends.find("cpu") != availableBackends.end()){ - // g1->setBackend("cpu"); - // 
newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr)); - // }else{ - // printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); - // } - - fuseBatchNorm(g1); - - SECTION("Check resulting nodes") { - // REQUIRE(g1->getNodes().size() == 2); - // REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); - // REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); - // REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); - // REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); - // REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); - } - } - -} -*/ \ No newline at end of file diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index 0c65db98917e33a11f4b7bac678b271b1a10fb94..968826230dfdf85290ee377aee155e06855c4b28 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -61,7 +61,6 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // Transform GraphView inplace fuseMulAdd(g); - g->save("bonjour"); // Check new GraphView std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); diff --git a/unit_tests/recipies/Test_removeFlatten.cpp b/unit_tests/recipies/Test_removeFlatten.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d0ff29dae19ba2dd8009441c39da53bf44378f0 --- /dev/null +++ b/unit_tests/recipies/Test_removeFlatten.cpp @@ -0,0 +1,49 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <set> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/recipies/Recipies.hpp" + +namespace Aidge { + + +TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { + // generate the original GraphView + auto flatten = GenericOperator("Flatten", 1, 0, 1, "myFlatten"); + auto fc = FC(10, 50, "myFC"); + + flatten -> addChild(fc); + + auto g = std::make_shared<GraphView>(); + g->add({fc, flatten}); + + // Check original graph + // g -> save("before_remove_flatten"); + + // use recipie + removeFlatten(g); + + // Check transformed graph + // g -> save("after_remove_flatten"); + + REQUIRE(g->getOrderedInputs().size() == 1); + REQUIRE(g->getOrderedOutputs().size() == 1); + REQUIRE(g->getOrderedInputs()[0].first == fc); + REQUIRE(g->getOrderedOutputs()[0].first == fc); +} + +} // namespace Aidge \ No newline at end of file