diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp
index d10270b62bb75412a6cbd9203b9b7a3fe220e5aa..453e30a8636d86794c96723350bff615af090e3e 100644
--- a/include/aidge/backend/OperatorImpl.hpp
+++ b/include/aidge/backend/OperatorImpl.hpp
@@ -14,11 +14,13 @@
 #include <cstddef>
 #include <vector>
+#include <memory>

 #include "aidge/utils/Types.h"

 namespace Aidge {
 class OperatorImpl {
 public:
+    virtual void forward(){};
     virtual void backward(){};
diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 7a2b4bac008a82d0454a6dd057d8bf78c7605926..1f1eeafa859b116606613392a13a65ad398669ad 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -348,6 +348,37 @@ public:
      */
     void updateOutputNodes();

+    /**
+     * @brief Clone the GraphView with shared Operators: the new GraphView contains cloned Nodes, but these refer to the same Operators as the original ones.
+     * @return std::shared_ptr<GraphView>
+     */
+    inline std::shared_ptr<GraphView> cloneSharedOperators() const {
+        return cloneCallback(&Node::cloneSharedOperators);
+    }
+
+    /**
+     * @brief Clone the GraphView with shared Producers. All the other Operators are copied.
+     * @return std::shared_ptr<GraphView>
+     */
+    inline std::shared_ptr<GraphView> cloneSharedProducers() const {
+        return cloneCallback(&Node::cloneSharedProducers);
+    }
+
+    /**
+     * @brief Clone the GraphView. Everything is cloned: Nodes and Operators.
+     * @return std::shared_ptr<GraphView>
+     */
+    inline std::shared_ptr<GraphView> clone() const {
+        return cloneCallback(&Node::clone);
+    }
+
+    /**
+     * @brief Clone the current GraphView using a callback function for the Node cloning. The callback specifies how each Node should be cloned, replaced by another Node type, or removed (i.e. replaced by identity). When a Node is removed, cloneCallback() automatically finds the next valid parent in line, going backward in the graph, and connects it when this can be done without ambiguity (effectively treating the removed Node as an identity operation).
+     * @param cloneNode Callback function to clone a node
+     * @return std::shared_ptr<GraphView>
+     */
+    std::shared_ptr<GraphView> cloneCallback(NodePtr(*cloneNode)(NodePtr)) const;
+
 private:
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp
index 340a8318cbd0d59b7710bce7d46b7acd1670dd5b..dbe017fc7f8935e83ff1672392992c75a2e0658c 100644
--- a/include/aidge/graph/Node.hpp
+++ b/include/aidge/graph/Node.hpp
@@ -350,6 +350,55 @@ public:
      */
     void resetConnections(bool includeLearnableParam = false);

+    ///////////////////////////////////////////////////////
+    // CLONE
+    ///////////////////////////////////////////////////////
+
+    /**
+     * @brief Clone the current Node. The Operator attribute of the new Node is not copied but shared with the current Node. The new Node has no connection.
+     * @return NodePtr
+     */
+    NodePtr cloneSharedOperators() const;
+
+    /**
+     * @brief Clone the Node. Every attribute is copied and the Operator is cloned as well, except for Producers, whose Operator is shared. The new Node has no connection.
+     * @return NodePtr
+     */
+    NodePtr cloneSharedProducers() const;
+
+    /**
+     * @brief Clone the Node and its Operator. The new Node has no connection.
+     * @return NodePtr
+     */
+    NodePtr clone() const;
+
+    /**
+     * @brief Callback function to clone the Node keeping the same Operator object instance. The new Node has no connection.
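+     * This overload exists mainly to be passed to GraphView::cloneCallback(); a short usage sketch (the graph variable is illustrative):
+     * @code
+     * std::shared_ptr<GraphView> copy = graph->cloneCallback(&Node::cloneSharedOperators);
+     * @endcode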
+     * @param node Node to clone.
+     * @return NodePtr
+     */
+    static NodePtr cloneSharedOperators(NodePtr node) {
+        return node->cloneSharedOperators();
+    }
+
+    /**
+     * @brief Callback function to clone the Node. Every attribute is copied and the Operator is cloned as well, except for Producers, whose Operator is shared. The new Node has no connection.
+     * @param node Node to clone.
+     * @return NodePtr
+     */
+    static NodePtr cloneSharedProducers(NodePtr node) {
+        return node->cloneSharedProducers();
+    }
+
+    /**
+     * @brief Callback function to clone the Node and its Operator. The new Node has no connection.
+     * @param node Node to clone.
+     * @return NodePtr
+     */
+    static NodePtr clone(NodePtr node) {
+        return node->clone();
+    }
+
 private:
///////////////////////////////////////////////////////
// OPERATORS
diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp
index ff3d1888c3bc70b61a3d4da42908d40de2d1d73e..303092911ae369473c1f3d6b7f122e3068d77028 100644
--- a/include/aidge/operator/Add.hpp
+++ b/include/aidge/operator/Add.hpp
@@ -32,14 +32,13 @@ class Add_Op : public Operator,
 public:
     // FIXME: change accessibility
     std::array<std::shared_ptr<Tensor>, NUM> mInputs;
-    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(shared_from_this());
+    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();

 public:
     static constexpr const char* Type = "Add";

     constexpr Add_Op()
-        : Operator(Type),
-          mOutput(std::make_shared<Tensor>())
+        : Operator(Type)
     {
         assert(NUM > 0 && "Add should have at least one input");
         for (std::size_t i = 0; i<NUM; ++i) {
@@ -48,6 +47,31 @@ public:
         setDatatype(DataType::Float32);
     }

+    /**
+     * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated).
+     * @param op Operator to copy.
+     */
+    Add_Op(const Add_Op<NUM>& op)
+        : Operator(Type),
+          mOutput(std::make_shared<Tensor>(*op.mOutput))
+    {
+        // cpy-ctor
+        assert(NUM > 0 && "Add should have at least one input");
+        for (std::size_t i = 0; i<NUM; ++i) {
+            mInputs[i] = std::make_shared<Tensor>();
+        }
+        setDatatype(op.mOutput->dataType());
+        mImpl = op.mImpl ? Registrar<Add_Op<NUM>>::create(mOutput->getImpl()->backend())(*this) : nullptr;
+    }
+
+    /**
+     * @brief Clone the operator using its copy-constructor.
+     * @see Operator::Add_Op
+     */
+    std::shared_ptr<Operator> clone() const override {
+        return std::make_shared<Add_Op>(*this);
+    }
+
     // Data operator[](const char* inputName) override final {
     //     std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] :
     //         (strcmp(inputName, "weight") ? mInputs[1] :
diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp
index bf76bd45893b43043b81cd6563c500be27c66b42..2fbff53c30e376e80d07f0859851057177bf0868 100644
--- a/include/aidge/operator/AvgPooling.hpp
+++ b/include/aidge/operator/AvgPooling.hpp
@@ -58,11 +58,32 @@ public:
         : Operator(Type),
           Parameterizable_(param<AvgPoolingParam::StrideDims>(stride_dims),
                            param<AvgPoolingParam::KernelDims>(kernel_dims),
-                           param<AvgPoolingParam::PaddingDims>(padding_dims)),
-          mOutput(std::make_shared<Tensor>()) {
+                           param<AvgPoolingParam::PaddingDims>(padding_dims)) {
         setDatatype(DataType::Float32);
     }

+    /**
+     * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated).
+     * @param op Operator to copy.
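+     * A minimal sketch of the intended use (assuming an existing AvgPooling_Op<2> instance named op):
+     * @code
+     * AvgPooling_Op<2> copy(op); // parameters and output copied, inputs left unset
+     * @endcode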
+ */ + AvgPooling_Op(const AvgPooling_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<AvgPooling_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::AvgPooling_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<AvgPooling_Op<DIM>>(*this); + } + constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 1 && "operators supports only 3 inputs"); (void) inputIdx; // avoid unused warning diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index 6861c1359737f3f344f0c7d9b2d12c9ff35b88ad..f1a6ae8f52141839f72211f23511a0607e2138b6 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -51,11 +51,32 @@ public: constexpr BatchNorm_Op(float epsilon, float momentum) : Operator(Type), Parameterizable_(param<BatchNormParam::Epsilon>(epsilon), - param<BatchNormParam::Momentum>(momentum)), - mOutput(std::make_shared<Tensor>()) { + param<BatchNormParam::Momentum>(momentum)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + BatchNorm_Op(const BatchNorm_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<BatchNorm_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::BatchNorm_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<BatchNorm_Op<DIM>>(*this); + } + // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : // (strcmp(inputName, "weight") ? mInputs[1] : diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 1edc94b96763cc163646037a8bd069023511df67..e95b46ae5583df9e6b471dc4005d0d9c4636ca9b 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -61,11 +61,32 @@ public: param<ConvParam::InChannels>(in_channels), param<ConvParam::OutChannels>(out_channels), param<ConvParam::KernelDims>(kernel_dims), - param<ConvParam::PaddingDims>(padding_dims)), - mOutput(std::make_shared<Tensor>()) { + param<ConvParam::PaddingDims>(padding_dims)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Conv_Op(const Conv_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Conv_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
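+     * The clone has no input associated; inputs must be re-associated and dimensions recomputed before use. A sketch (op and someTensor are illustrative):
+     * @code
+     * std::shared_ptr<Operator> cloned = op.clone();
+     * cloned->associateInput(0, someTensor);
+     * @endcode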
+ * @see Operator::Conv_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Conv_Op<DIM>>(*this); + } + // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : // (strcmp(inputName, "weight") ? mInputs[1] : diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 95a2ff55b70dbed9299fb3dca98fb9b0e700d210..12d15328cbabbe5b066fa2fb375adecd7935c889 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -66,11 +66,32 @@ class ConvDepthWise_Op : public Operator, param<ConvDepthWiseParam::DilationDims>(dilation_dims), param<ConvDepthWiseParam::Channels>(0), param<ConvDepthWiseParam::KernelDims>(kernel_dims), - param<ConvDepthWiseParam::PaddingDims>(padding_dims)), - mOutput(std::make_shared<Tensor>()) { + param<ConvDepthWiseParam::PaddingDims>(padding_dims)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + ConvDepthWise_Op(const ConvDepthWise_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<ConvDepthWise_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::ConvDepthWise_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<ConvDepthWise_Op<DIM>>(*this); + } + constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index db92dc9c735416d250fa32e2f9010b21b8f808c0..73cdab54c2cfade6fbd397d33d537b16cb5245f1 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -51,12 +51,33 @@ public: : Operator(Type), Parameterizable_( param<FCParam::OutChannels>(out_channels), - param<FCParam::NoBias>(noBias)), - mOutput(std::make_shared<Tensor>()) + param<FCParam::NoBias>(noBias)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + FC_Op(const FC_Op& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<FC_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
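+     * Since clone() returns a generic Operator pointer, a cast is needed to reach the typed interface again (a sketch, with op an existing FC_Op):
+     * @code
+     * auto fcClone = std::static_pointer_cast<FC_Op>(op.clone());
+     * @endcode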
+ * @see Operator::FC_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<FC_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 12fb7e16741e9f7ad96d51b0b847b91265c3a7d2..184100174714df5fc059e374cb85549f6bfd4135 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -16,6 +16,7 @@ #include <vector> #include <string> #include <cassert> +#include <cstring> #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" @@ -28,12 +29,15 @@ class GenericOperator_Op : public Operator, public Registrable<GenericOperator_Op, std::string, std::unique_ptr<OperatorImpl>(std::shared_ptr<GenericOperator_Op>)> { private: + using ComputeDimsFunc = std::function<std::vector<std::vector<size_t>>(const std::vector<std::vector<size_t>>&)>; + CParameter mParams; IOIndex_t mNbDataIn; IOIndex_t mNbIn; IOIndex_t mNbOut; std::vector<std::shared_ptr<Tensor>> mInputs; std::vector<std::shared_ptr<Tensor>> mOutputs; + ComputeDimsFunc mComputeOutputDims; public: GenericOperator_Op(const char *type, IOIndex_t nbDataIn, IOIndex_t nbIn, IOIndex_t nbOut) @@ -49,6 +53,32 @@ class GenericOperator_Op } } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + GenericOperator_Op(const GenericOperator_Op& op) + : Operator(op.type().c_str()), mParams(op.mParams), mNbDataIn(op.mNbDataIn), mNbIn(op.mNbIn), mNbOut(op.mNbOut) + { + // cpy-ctor + mInputs = std::vector<std::shared_ptr<Tensor>>(mNbIn); + for (std::size_t i = 0; i < mNbIn; ++i) { + mInputs[i] = std::make_shared<Tensor>(); + } + mOutputs = std::vector<std::shared_ptr<Tensor>>(mNbOut); + for (std::size_t i = 0; i < mNbOut; ++i) { + mOutputs[i] = std::make_shared<Tensor>(*op.mOutputs[i]); + } + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::GenericOperator_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<GenericOperator_Op>(*this); + } + /** * @brief Get the Parameter object identified by its name. * @tparam T expected parameter type. 
@@ -84,23 +114,55 @@ class GenericOperator_Op
         mParams.Add<T>(key, std::forward<T>(value));
     }

+    // Helper functions that can be used with setComputeOutputDims():
+    static const ComputeDimsFunc Identity;
+
+    void setComputeOutputDims(ComputeDimsFunc func) {
+        mComputeOutputDims = func;
+    }

     std::string getParameterType(std::string const &key) { return mParams.getParamType(key); }

     std::vector<std::string> getParametersName() { return mParams.getParametersName(); }

     // Override Virtual Opertor methods
-    void associateInput(const IOIndex_t /*inputIdx*/, std::shared_ptr<Data> /*data*/) override final {
-        printf("Info: using associateInput() on a GenericOperator.\n");
+    void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
+        assert(inputIdx < mNbIn && "input index out of range");
+
+        if (strcmp(data->type(), Tensor::Type) == 0) {
+            // TODO: associate input only if of type Tensor, otherwise do nothing
+            mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
+        }
     }

     void computeOutputDims() override final {
-        assert(false && "Cannot compute output dim of a GenericOperator");
+        if (mComputeOutputDims) {
+            std::vector<std::vector<size_t>> inputsDims(mNbIn, std::vector<size_t>());
+            for (std::size_t i = 0; i < mNbIn; ++i) {
+                if (mInputs[i]) {
+                    inputsDims[i] = mInputs[i]->dims();
+                }
+            }
+
+            const auto& outputsDims = mComputeOutputDims(inputsDims);
+            assert(outputsDims.size() == mNbOut && "The provided ComputeDimsFunc function returns the wrong number of outputs");
+            for (std::size_t i = 0; i < mNbOut; ++i) {
+                mOutputs[i]->resize(outputsDims[i]);
+            }
+        }
+        else {
+            assert(false && "Cannot compute output dim of a GenericOperator");
+        }
     }

     bool outputDimsForwarded() const override final {
-        assert(false && "GenericOperator cannot forward dims");
-        return false;
+        if (mComputeOutputDims) {
+            return !(mOutputs[0]->empty());
+        }
+        else {
+            assert(false && "GenericOperator cannot forward dims");
+            return false;
+        }
     }

     std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp
index 1dff2550a42245351afab5b8bb1a708a8d0d8c0b..dc9548515134a68ad28a8b58213b536cd43fc406 100644
--- a/include/aidge/operator/LeakyReLU.hpp
+++ b/include/aidge/operator/LeakyReLU.hpp
@@ -53,6 +53,28 @@ public:
         setDatatype(DataType::Float32);
     }

+    /**
+     * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated).
+     * @param op Operator to copy.
+     */
+    LeakyReLU_Op(const LeakyReLU_Op& op)
+        : Operator(Type),
+          Parameterizable_(op),
+          mOutput(std::make_shared<Tensor>(*op.mOutput))
+    {
+        // cpy-ctor
+        setDatatype(op.mOutput->dataType());
+        mImpl = op.mImpl ? Registrar<LeakyReLU_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
+    }
+
+    /**
+     * @brief Clone the operator using its copy-constructor.
+ * @see Operator::LeakyReLU_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<LeakyReLU_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); (void) inputIdx; // avoid unused warning diff --git a/include/aidge/operator/Matmul.hpp b/include/aidge/operator/Matmul.hpp index 639b366912060b3e085510f312d94568e6b65f03..54bbcb267f346fd79a2b9e3a8aca571ed2e6ba91 100644 --- a/include/aidge/operator/Matmul.hpp +++ b/include/aidge/operator/Matmul.hpp @@ -49,12 +49,33 @@ public: Matmul_Op(DimSize_t out_channels) : Operator(Type), Parameterizable_( - param<MatmulParam::OutChannels>(out_channels)), - mOutput(std::make_shared<Tensor>()) + param<MatmulParam::OutChannels>(out_channels)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Matmul_Op(const Matmul_Op& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Matmul_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Matmul_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Matmul_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 2 && "operators supports only 2 inputs"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index 073243e801c6e1297129424b0c93b1a7c4f112f3..775583fd4c2132a5474d136c60c1b53b47ea4c3d 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -63,6 +63,28 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + MaxPooling_Op(const MaxPooling_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<MaxPooling_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::MaxPooling_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<MaxPooling_Op<DIM>>(*this); + } + constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 1 && "operators supports only 3 inputs"); (void) inputIdx; // avoid unused warning diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 35a59b56cbf5c10a78116f81de96a8baddc03ff0..9e12b159888923cfea10dd02b7b267a46abcb3b7 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -21,6 +21,25 @@ public: : Operator("MetaOp") { } + + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). 
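+     * MetaOperator currently holds no state of its own beyond its type, so the copy amounts to constructing a fresh "MetaOp".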
+     * @param op Operator to copy.
+     */
+    MetaOperator(const MetaOperator& op)
+        : Operator("MetaOp")
+    {
+        // cpy-ctor
+        (void) op; // avoid unused warning
+    }
+
+    /**
+     * @brief Clone the operator using its copy-constructor.
+     * @see Operator::MetaOperator
+     */
+    std::shared_ptr<Operator> clone() const override {
+        return std::make_shared<MetaOperator>(*this);
+    }
+
     ~MetaOperator() = default;
 };
 }
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 122a42a42f38309aa1cd1661324fcc6f5c2d3fcc..3ac651cfd6f700a129e36fb461f948f50137cfd6 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -35,8 +35,18 @@ private:

 public:
     Operator() = delete;
     Operator(const char* type) : mType(type) {}
+    virtual std::shared_ptr<Operator> clone() const = 0;
     virtual ~Operator();

+    Operator(const Operator& op):
+        std::enable_shared_from_this<Operator>()
+    {
+        mType = op.mType;
+        mImpl = nullptr;
+        // Implementation is never cloned. It is up to the non-abstract Operator copy-constructor to create a new implementation matching the copied Operator implementation.
+        // See https://gitlab.eclipse.org/eclipse/aidge/aidge_core/-/merge_requests/8#note_1214050 for the discussion.
+        // Hooks are not copied.
+    }

 public:
diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp
index acdc69b69ab86b25a11d889980b9236e41928316..fbab24a0d23712b138c41e969372701fdb3d749e 100644
--- a/include/aidge/operator/Producer.hpp
+++ b/include/aidge/operator/Producer.hpp
@@ -29,15 +29,14 @@ class Producer_Op
     public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>(
                                         const Producer_Op &)> {
 private:
-    std::shared_ptr<Tensor> mOutput;
+    std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();

 public:
     static constexpr const char* Type = "Producer";

     template <std::size_t DIM>
     Producer_Op(const std::array<DimSize_t, DIM>& dims)
-        : Operator(Type),
-          mOutput(std::make_shared<Tensor>())
+        : Operator(Type)
     {
         //ctor
         setDatatype(DataType::Float32);
@@ -51,6 +50,27 @@ public:
         setDatatype(tensor->dataType());
     }

+    /**
+     * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated).
+     * @param op Operator to copy.
+     */
+    Producer_Op(const Producer_Op& op)
+        : Operator(Type),
+          mOutput(std::make_shared<Tensor>(*op.mOutput))
+    {
+        // cpy-ctor
+        setDatatype(op.mOutput->dataType());
+        mImpl = op.mImpl ? Registrar<Producer_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
+    }
+
+    /**
+     * @brief Clone the operator using its copy-constructor.
+     * @see Operator::Producer_Op
+     */
+    std::shared_ptr<Operator> clone() const override {
+        return std::make_shared<Producer_Op>(*this);
+    }
+
     void associateInput(const IOIndex_t /*inputIdx*/, std::shared_ptr<Data> /*data*/) override final {
         assert(false && "Producer operator takes no input");
     }
diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp
index 141bd3ae12c7875a90d2549a24e5c141f3ff6aba..cebfa5718886ec26871462f48edcdbc28117da59 100644
--- a/include/aidge/operator/ReLU.hpp
+++ b/include/aidge/operator/ReLU.hpp
@@ -42,6 +42,27 @@ public:
         setDatatype(DataType::Float32);
     }

+    /**
+     * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated).
+     * @param op Operator to copy.
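+     * When the copied operator has an implementation, a new implementation is created for the same backend through the Registrar; implementations are never shared (see the Operator copy-constructor).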
+ */ + ReLU_Op(const ReLU_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<ReLU_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::ReLU_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<ReLU_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); (void) inputIdx; // avoid unused warning diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index e158ecd7567eb683558d9e09a6cf03e5cc35ce42..e3cba81a490d3b4b28dd3754df7d274eb2e3519a 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -55,6 +55,28 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Scaling_Op(const Scaling_Op& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Scaling_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Scaling_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Scaling_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); @@ -84,7 +106,7 @@ public: } - inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { assert((inputIdx == 0) && "Scaling Operator has only 1 input"); (void) inputIdx; // avoid unused warning return mInput; diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 64e713b331bbbbf612ee5102ba0ea82fb108350e..ffaf0001fbaadf7dc700fca43d77b9998ab26eb2 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -42,6 +42,27 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Softmax_Op(const Softmax_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Softmax_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
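+     * Cloning also works through a base-class pointer, preserving the dynamic type (a sketch):
+     * @code
+     * std::shared_ptr<Operator> base = std::make_shared<Softmax_Op>();
+     * std::shared_ptr<Operator> copy = base->clone(); // still a Softmax_Op
+     * @endcode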
+     * @see Operator::Softmax_Op
+     */
+    std::shared_ptr<Operator> clone() const override {
+        return std::make_shared<Softmax_Op>(*this);
+    }
+
     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
         assert(inputIdx == 0 && "operator supports only 1 input");
         (void) inputIdx; // avoid unused warning
diff --git a/include/aidge/recipies/LabelGraph.hpp b/include/aidge/recipies/LabelGraph.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..9dd77e5e9f397260cf936cf77b15616c17ea33b8
--- /dev/null
+++ b/include/aidge/recipies/LabelGraph.hpp
@@ -0,0 +1,35 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_RECIPIES_LABELGRAPH_H_
+#define AIDGE_RECIPIES_LABELGRAPH_H_
+
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+
+namespace Aidge {
+NodePtr nodeLabel(NodePtr node);
+
+/**
+ * @brief Generate the graph for the pixel-wise labels corresponding to a data graph, taking into account the scaling changes (padding, stride, pooling...).
+ * @details Right now, the behavior is to replace the following operators:
+ * - Conv: MaxPooling
+ * - ConvDepthWise: MaxPooling
+ * - AvgPooling: MaxPooling
+ * - MaxPooling: MaxPooling
+ * - all others: identity (removed)
+ * @param graph Data graph
+ * @return Computing graph for the labels derived from the data graph
+ */
+std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph);
+} // namespace Aidge
+
+#endif /* AIDGE_RECIPIES_LABELGRAPH_H_ */
diff --git a/include/aidge/utils/CParameter.hpp b/include/aidge/utils/CParameter.hpp
index 7d60ed239ae58666833c4ce227aaf16542679036..7246bc3c7555c12402e864f62416b714052320d7 100644
--- a/include/aidge/utils/CParameter.hpp
+++ b/include/aidge/utils/CParameter.hpp
@@ -14,6 +14,7 @@

 #include <map>
 #include <vector>
+#include <string>
 #include <type_traits>
 #include <typeinfo>
 #include <assert.h>
@@ -41,11 +42,6 @@ private:
         throw std::bad_cast();
     }
 public:
-    // not copyable, not movable
-    CParameter(CParameter const &) = delete;
-    CParameter(CParameter &&) = delete;
-    CParameter &operator=(CParameter const &) = delete;
-    CParameter &operator=(CParameter &&) = delete;
     CParameter() : m_Params({}){};
     ~CParameter() = default;
diff --git a/include/aidge/utils/Parameter.hpp b/include/aidge/utils/Parameter.hpp
index b0c6e35950187f17d991cfe5b2c9bd2b09f1e70f..a475576170915182e25dbaa193ca8a7a3853c0e0 100644
--- a/include/aidge/utils/Parameter.hpp
+++ b/include/aidge/utils/Parameter.hpp
@@ -94,6 +94,12 @@ public:
         (void)p; // avoid unused warning
     }

+    Parameterizable(const Parameterizable& params):
+        mParams(params.mParams)
+    {
+        // cpy-ctor (required for Operator cpy-ctor)
+    }
+
     // Compile-time access with enum
     template <PARAM_ENUM paramEnum>
     constexpr typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type& get() {
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index e50ea20e680e6ab874b14c14b23f77b27286f367..bbf895285e0e00d1132eb1f46c7e67a455d705d7 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -682,4 +682,55 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
     mOutputNodes.erase(val);
   }
 }
-}
\ No newline at end of file
+}
+
+std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const {
+    std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);
+
+    // Map for old node -> new node correspondence
+    std::map<NodePtr, NodePtr> oldToNewNodes;
+
+    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
+        oldToNewNodes[node_ptr] = cloneNode(node_ptr);
+    }
+
+    // For each node, convert old node -> new node connections
+    for (auto &oldToNewNode : oldToNewNodes) {
+        if (oldToNewNode.second == nullptr)
+            continue;  // deleted node
+
+        // Add new node to new GraphView
+        newGraph->add(oldToNewNode.second, false);
+
+        // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr
+        size_t parentId = 0;
+        for (auto parent : oldToNewNode.first->inputs()) {
+            while (oldToNewNodes[parent.first] == nullptr) {
+                // Find next valid parent in line, going backward in the graph
+                assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::cloneCallback() cannot have multiple data inputs");
+                const auto& parents = parent.first->inputs();
+
+                if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
+                    && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
+                {
+                    parent = parents[0];
+                }
+                else {
+                    break;
+                }
+            }
+
+            if (oldToNewNodes[parent.first]) {
+                oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId);
+            }
+
+            ++parentId;
+        }
+    }
+
+    // Update inputNodes/outputNodes
+    newGraph->updateInputNodes();
+    newGraph->updateOutputNodes();
+
+    return newGraph;
+}
diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp
index abf572831d8f0b5c2c5eb836ea46e05b8114da55..54fdac808642f3ae603e237737e265ba394fccbd 100644
--- a/src/graph/Node.cpp
+++ b/src/graph/Node.cpp
@@ -321,6 +321,26 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) {
     }
 }

+///////////////////////////////////////////////////////
+// CLONE
+///////////////////////////////////////////////////////
+
+Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
+    return std::make_shared<Node>(mOperator, mName);
+}
+
+Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
+    std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type)
+        ? mOperator
+        : mOperator->clone();
+
+    return std::make_shared<Node>(op, mName);
+}
+
+Aidge::NodePtr Aidge::Node::clone() const {
+    return std::make_shared<Node>(mOperator->clone(), mName);
+}
+
 /////////////////////////////////////////////////////////////////////////////////////////////
 // private
diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..192036651cfbe2df71139dd63ca3d71f07300964
--- /dev/null
+++ b/src/operator/GenericOperator.cpp
@@ -0,0 +1,17 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <vector> + +#include "aidge/operator/GenericOperator.hpp" + +const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::Identity + = [](const std::vector<std::vector<size_t>>& inputsDims) { return inputsDims; }; diff --git a/src/recipies/LabelGraph.cpp b/src/recipies/LabelGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ac2cbf6ca65c7ecbced9596efb71c2052405984 --- /dev/null +++ b/src/recipies/LabelGraph.cpp @@ -0,0 +1,56 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/recipies/LabelGraph.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" + +Aidge::NodePtr Aidge::nodeLabel(NodePtr node) { + // Conv => MaxPooling + if (node->type() == Conv_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<ConvParam::KernelDims>(), op->get<ConvParam::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // ConvDepthWise => MaxPooling + if (node->type() == ConvDepthWise_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<ConvDepthWise_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<ConvDepthWiseParam::KernelDims>(), op->get<ConvDepthWiseParam::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // AvgPooling => MaxPooling + if (node->type() == AvgPooling_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<AvgPooling_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<AvgPoolingParam::KernelDims>(), op->get<AvgPoolingParam::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // MaxPooling => MaxPooling + if (node->type() == MaxPooling_Op<2>::Type) { + return node->clone(); + } + + // By default, remove the node from the graph + return nullptr; +} + +std::shared_ptr<Aidge::GraphView> Aidge::labelGraph(std::shared_ptr<GraphView> graph) { + return graph->cloneCallback(&nodeLabel); +} diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 319370ebad95869efd450eade58a2ecd36075090..4b929286ba494a452c7f9cb71ce944c7d576c03a 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -332,6 +332,234 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { } } +TEST_CASE("[GraphView] clone") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("clone_g1"); + + SECTION("Check input-output connections") { + 
REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->clone(); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("clone_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check new connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) != g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) != g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) != g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) != g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) != g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) != 
g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider2->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + +TEST_CASE("[GraphView] cloneSharedProducers") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("cloneSharedProducers_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->cloneSharedProducers(); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("cloneSharedProducers_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator()); + 
REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check new connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider2->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + +TEST_CASE("[GraphView] cloneSharedOperators") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("cloneSharedOperators_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == 
g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->cloneSharedOperators(); + g2->forwardDims(); + g2->save("cloneSharedOperators_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() == g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() == g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() == g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + + TEST_CASE("[core/graph] GraphView(insertParent)") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1"); @@ -352,7 +580,7 @@ TEST_CASE("[core/graph] GraphView(insertParent)") { std::set<NodePtr> expectedConv1Children = {conv3, newConv}; 
std::set<NodePtr> expectedNewConvChildren = {conv2}; - + REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0)); REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0)); @@ -374,4 +602,4 @@ TEST_CASE("[core/graph] GraphView(insertParent)") { REQUIRE((conv1->getChildren()) == expectedConv1Children2); } -} \ No newline at end of file +} diff --git a/unit_tests/recipies/Test_LabelGraph.cpp b/unit_tests/recipies/Test_LabelGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..873ad68f3198c6b6adf44d8c7ae31e667c63a18d --- /dev/null +++ b/unit_tests/recipies/Test_LabelGraph.cpp @@ -0,0 +1,154 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/recipies/LabelGraph.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/graph/OpArgs.hpp" +#include <cstddef> + +using namespace Aidge; + +TEST_CASE("[LabelGraph] conv") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("LabelGraph_conv_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_conv_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } +} + +TEST_CASE("[LabelGraph] deleted node") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(32, 64, {3, 3}, "conv2"), + Conv(64, 10, {1, 1}, "conv3", {2, 2}) + }); + + g1->save("LabelGraph_deleted_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_deleted_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); 
+ REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } + + SECTION("Check dimensions") { + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 220, 220})); + REQUIRE(g2->getNode("conv3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 110, 110})); + } +} + +TEST_CASE("[LabelGraph] deleted nodes") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(32, 64, {3, 3}, "conv2"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(64, 10, {1, 1}, "conv3") + }); + + g1->save("LabelGraph_deleteds_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_deleteds_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } +} + +TEST_CASE("[LabelGraph] pooling") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + AvgPooling({2, 2}, "pool1"), + MaxPooling({2, 2}, "pool2"), + MaxPooling({2, 2}, "pool3", {2, 2}) + }); + + g1->save("LabelGraph_deleted_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("pool1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_pooling"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("pool1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0) == g2->getNode("pool2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0) == g2->getNode("pool3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool3")->getOperator()->type() == "MaxPooling"); + } + + SECTION("Check dimensions") { + REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 223, 223})); + REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(g2->getNode("pool3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 111, 111})); + } +}