diff --git a/README.md b/README.md index 992344a796a4634a25d2127fc49b57adeae45863..5b07e147cb05c2fa1a6d275d567dda218b131996 100644 --- a/README.md +++ b/README.md @@ -6,16 +6,19 @@ You can find here the C++ code of the Core library of Aidge. ## Pip installation -To install aidge_core using pip, make sure to set the desired install path : -``` bash -export AIDGE_INSTALL = '<path_to_aidge>/install' -``` -Then run in your python environnement : + +To install aidge_core using pip, run the following command in your python environment : ``` bash pip install . -v ``` +**Note:** you can specify a custom install folder by setting an environment variable: + +``` bash +export AIDGE_INSTALL='<path_to_aidge>/install' +``` + ## Standard C++ Compilation Create two directories ``build`` and ``ìnstall``. diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index b326e0748c2c77612dd79122fe891a6207d945dc..8898bc5a7ac6ce771cab8402933d464c1f04316f 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -61,5 +61,15 @@ class test_operator_binding(unittest.TestCase): self.generic_operator.add_parameter("l_str", ["ok"]) self.assertEqual(self.generic_operator.get_parameter("l_str"), ["ok"]) + def test_compute_output_dims(self): + in_dims=[25, 25] + input = aidge_core.Producer(in_dims, name="In") + genOp = aidge_core.GenericOperator("genOp", 1, 1, 1, name="genOp") + _ = aidge_core.sequential([input, genOp]) + self.assertListEqual(genOp.get_operator().output(0).dims(), []) + genOp.get_operator().set_compute_output_dims(lambda x:x) + genOp.get_operator().compute_output_dims() + self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 
d10270b62bb75412a6cbd9203b9b7a3fe220e5aa..453e30a8636d86794c96723350bff615af090e3e 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -14,11 +14,13 @@ #include <cstddef> #include <vector> +#include <memory> #include "aidge/utils/Types.h" namespace Aidge { class OperatorImpl { public: + virtual void forward(){}; virtual void backward(){}; diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 7a2b4bac008a82d0454a6dd057d8bf78c7605926..1f1eeafa859b116606613392a13a65ad398669ad 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -348,6 +348,37 @@ public: */ void updateOutputNodes(); + /** + * @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones. + * @return std::shared_ptr<GraphView> + */ + inline std::shared_ptr<GraphView> cloneSharedOperators() const { + return cloneCallback(&Node::cloneSharedOperators); + } + + /** + * @brief Clone the GraphView with shared Producers. All the other Operators are copied. + * @return std::shared_ptr<GraphView> + */ + inline std::shared_ptr<GraphView> cloneSharedProducers() const { + return cloneCallback(&Node::cloneSharedProducers); + } + + /** + * @brief Clone the GraphView. Everything is cloned: Nodes and Operators. + * @return std::shared_ptr<GraphView> + */ + inline std::shared_ptr<GraphView> clone() const { + return cloneCallback(&Node::clone); + } + + /** + * @brief Clone the current GraphView using a callback function for the Node cloning, allowing to specify how each Node should be cloned or replaced by another Node type, or removed (i.e. replaced by identity). When a Node is removed, the clone() method automatically finds the next valid parent in line, going backward in the graph and connects it if that makes sense without ambiguity (effectively treating the removed Node as an identity operation). 
+ * @param cloneNode Callback function to clone a node + * @return std::shared_ptr<GraphView> + */ + std::shared_ptr<GraphView> cloneCallback(NodePtr(*cloneNode)(NodePtr)) const; + private: /////////////////////////////////////////////////////// // TENSOR MANAGEMENT diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 340a8318cbd0d59b7710bce7d46b7acd1670dd5b..dbe017fc7f8935e83ff1672392992c75a2e0658c 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -350,6 +350,55 @@ public: */ void resetConnections(bool includeLearnableParam = false); + /////////////////////////////////////////////////////// + // CLONE + /////////////////////////////////////////////////////// + + /** + * @brief Clone the current Node. The Operator attribute of the new Node is not copied but shared with the current Node. The new node has no connection. + * @return NodePtr + */ + NodePtr cloneSharedOperators() const; + + /** + * @brief Clone the Node. Every attribute is copied, even Operator pointer except for Producers for which it is shared. The new Node has no connection. + * @return NodePtr + */ + NodePtr cloneSharedProducers() const; + + /** + * @brief Clone the Node and its Operator. The new Node has no connection. + * @return NodePtr + */ + NodePtr clone() const; + + /** + * @brief Callback function to clone the Node keeping the same Operator object instance. The new Node has no connection. + * @param node Node to clone. + * @return NodePtr + */ + static NodePtr cloneSharedOperators(NodePtr node) { + return node->cloneSharedOperators(); + } + + /** + * @brief Callback function to clone the Node. Every attribute is copied, even Operator pointer except for Producers for which it is shared. The new Node has no connection. + * @param node Node to clone. 
+ * @return NodePtr + */ + static NodePtr cloneSharedProducers(NodePtr node) { + return node->cloneSharedProducers(); + } + + /** + * @brief Callback function to clone the Node and its Operator. The new Node has no connection. + * @param node Node to clone. + * @return NodePtr + */ + static NodePtr clone(NodePtr node) { + return node->clone(); + } + private: /////////////////////////////////////////////////////// // OPERATORS diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index ff3d1888c3bc70b61a3d4da42908d40de2d1d73e..303092911ae369473c1f3d6b7f122e3068d77028 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -32,14 +32,13 @@ class Add_Op : public Operator, public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, NUM> mInputs; - const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(shared_from_this()); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: static constexpr const char* Type = "Add"; constexpr Add_Op() - : Operator(Type), - mOutput(std::make_shared<Tensor>()) + : Operator(Type) { assert(NUM > 0 && "Add should have at least one input"); for (std::size_t i = 0; i<NUM; ++i) { @@ -48,6 +47,31 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Add_Op(const Add_Op<NUM>& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + assert(NUM > 0 && "Add should have at least one input"); + for (std::size_t i = 0; i<NUM; ++i) { + mInputs[i] = std::make_shared<Tensor>(); + } + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Add_Op<NUM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Add_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Add_Op>(*this); + } + // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : // (strcmp(inputName, "weight") ? mInputs[1] : diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index bf76bd45893b43043b81cd6563c500be27c66b42..2fbff53c30e376e80d07f0859851057177bf0868 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -58,11 +58,32 @@ public: : Operator(Type), Parameterizable_(param<AvgPoolingParam::StrideDims>(stride_dims), param<AvgPoolingParam::KernelDims>(kernel_dims), - param<AvgPoolingParam::PaddingDims>(padding_dims)), - mOutput(std::make_shared<Tensor>()) { + param<AvgPoolingParam::PaddingDims>(padding_dims)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + AvgPooling_Op(const AvgPooling_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<AvgPooling_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::AvgPooling_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<AvgPooling_Op<DIM>>(*this); + } + constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 1 && "operators supports only 3 inputs"); (void) inputIdx; // avoid unused warning diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index d808e0ce6fb041f96cbf6bb1b418ba033bf52f82..c95ecac92d14be6d56edd7abda6c20b011e65aba 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -51,11 +51,32 @@ public: constexpr BatchNorm_Op(float epsilon, float momentum) : Operator(Type), Parameterizable_(param<BatchNormParam::Epsilon>(epsilon), - param<BatchNormParam::Momentum>(momentum)), - mOutput(std::make_shared<Tensor>()) { + param<BatchNormParam::Momentum>(momentum)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + BatchNorm_Op(const BatchNorm_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<BatchNorm_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::BatchNorm_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<BatchNorm_Op<DIM>>(*this); + } + // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : // (strcmp(inputName, "weight") ? 
mInputs[1] : diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 96804f8117d6037fe8e833f44e81c507e099ffff..7113adb41d051d67e22456bbac05c00aa15333ab 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -61,11 +61,32 @@ public: param<ConvParam::InChannels>(in_channels), param<ConvParam::OutChannels>(out_channels), param<ConvParam::KernelDims>(kernel_dims), - param<ConvParam::PaddingDims>(padding_dims)), - mOutput(std::make_shared<Tensor>()) { + param<ConvParam::PaddingDims>(padding_dims)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Conv_Op(const Conv_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Conv_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Conv_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Conv_Op<DIM>>(*this); + } + // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : // (strcmp(inputName, "weight") ? 
mInputs[1] : diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 95a2ff55b70dbed9299fb3dca98fb9b0e700d210..12d15328cbabbe5b066fa2fb375adecd7935c889 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -66,11 +66,32 @@ class ConvDepthWise_Op : public Operator, param<ConvDepthWiseParam::DilationDims>(dilation_dims), param<ConvDepthWiseParam::Channels>(0), param<ConvDepthWiseParam::KernelDims>(kernel_dims), - param<ConvDepthWiseParam::PaddingDims>(padding_dims)), - mOutput(std::make_shared<Tensor>()) { + param<ConvDepthWiseParam::PaddingDims>(padding_dims)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + ConvDepthWise_Op(const ConvDepthWise_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<ConvDepthWise_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::ConvDepthWise_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<ConvDepthWise_Op<DIM>>(*this); + } + constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index db92dc9c735416d250fa32e2f9010b21b8f808c0..73cdab54c2cfade6fbd397d33d537b16cb5245f1 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -51,12 +51,33 @@ public: : Operator(Type), Parameterizable_( param<FCParam::OutChannels>(out_channels), - param<FCParam::NoBias>(noBias)), - mOutput(std::make_shared<Tensor>()) + param<FCParam::NoBias>(noBias)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + FC_Op(const FC_Op& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<FC_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::FC_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<FC_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 12fb7e16741e9f7ad96d51b0b847b91265c3a7d2..184100174714df5fc059e374cb85549f6bfd4135 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -16,6 +16,7 @@ #include <vector> #include <string> #include <cassert> +#include <cstring> #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" @@ -28,12 +29,15 @@ class GenericOperator_Op : public Operator, public Registrable<GenericOperator_Op, std::string, std::unique_ptr<OperatorImpl>(std::shared_ptr<GenericOperator_Op>)> { private: + using ComputeDimsFunc = std::function<std::vector<std::vector<size_t>>(const std::vector<std::vector<size_t>>&)>; + CParameter mParams; IOIndex_t mNbDataIn; IOIndex_t mNbIn; IOIndex_t mNbOut; std::vector<std::shared_ptr<Tensor>> mInputs; std::vector<std::shared_ptr<Tensor>> mOutputs; + ComputeDimsFunc mComputeOutputDims; public: GenericOperator_Op(const char *type, IOIndex_t nbDataIn, IOIndex_t nbIn, IOIndex_t nbOut) @@ -49,6 +53,32 @@ class GenericOperator_Op } } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. 
+ */ + GenericOperator_Op(const GenericOperator_Op& op) + : Operator(op.type().c_str()), mParams(op.mParams), mNbDataIn(op.mNbDataIn), mNbIn(op.mNbIn), mNbOut(op.mNbOut) + { + // cpy-ctor + mInputs = std::vector<std::shared_ptr<Tensor>>(mNbIn); + for (std::size_t i = 0; i < mNbIn; ++i) { + mInputs[i] = std::make_shared<Tensor>(); + } + mOutputs = std::vector<std::shared_ptr<Tensor>>(mNbOut); + for (std::size_t i = 0; i < mNbOut; ++i) { + mOutputs[i] = std::make_shared<Tensor>(*op.mOutputs[i]); + } + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::GenericOperator_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<GenericOperator_Op>(*this); + } + /** * @brief Get the Parameter object identified by its name. * @tparam T expected parameter type. @@ -84,23 +114,55 @@ class GenericOperator_Op mParams.Add<T>(key, std::forward<T>(value)); } + // Helper functions that can be used with setComputeOutputDims(): + static const ComputeDimsFunc Identity; + + void setComputeOutputDims(ComputeDimsFunc func) { + mComputeOutputDims = func; + } std::string getParameterType(std::string const &key) { return mParams.getParamType(key); } std::vector<std::string> getParametersName() { return mParams.getParametersName(); } // Override Virtual Opertor methods - void associateInput(const IOIndex_t /*inputIdx*/, std::shared_ptr<Data> /*data*/) override final { - printf("Info: using associateInput() on a GenericOperator.\n"); + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < mNbIn && "operators supports only x inputs"); + + if (strcmp(data->type(), Tensor::Type) == 0) { + // TODO: associate input only if of type Tensor, otherwise do nothing + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } } void computeOutputDims() override final { - assert(false && "Cannot compute output dim of a GenericOperator"); + if (mComputeOutputDims) { + 
std::vector<std::vector<size_t>> inputsDims(mNbIn, std::vector<size_t>()); + for (std::size_t i = 0; i < mNbIn; ++i) { + if (mInputs[i]) { + inputsDims[i] = mInputs[i]->dims(); + } + } + + const auto& outputsDims = mComputeOutputDims(inputsDims); + assert(outputsDims.size() == mNbOut && "The provided ComputeDimsFunc function returns the wrong number of outputs"); + for (std::size_t i = 0; i < mNbOut; ++i) { + mOutputs[i]->resize(outputsDims[i]); + } + } + else { + assert(false && "Cannot compute output dim of a GenericOperator"); + } } bool outputDimsForwarded() const override final { - assert(false && "GenericOperator cannot forward dims"); - return false; + if (mComputeOutputDims) { + return !(mOutputs[0]->empty()); + } + else { + assert(false && "GenericOperator cannot forward dims"); + return false; + } } std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index 1dff2550a42245351afab5b8bb1a708a8d0d8c0b..dc9548515134a68ad28a8b58213b536cd43fc406 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -53,6 +53,28 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + LeakyReLU_Op(const LeakyReLU_Op& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<LeakyReLU_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::LeakyReLU_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<LeakyReLU_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); (void) inputIdx; // avoid unused warning diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index 77ab2c972a636aefd8aede428c025dba2bc0c545..ff22823fd9c620b133c1f9a1200e463f71b49e92 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -49,12 +49,33 @@ public: MatMul_Op(DimSize_t out_channels) : Operator(Type), Parameterizable_( - param<MatMulParam::OutChannels>(out_channels)), - mOutput(std::make_shared<Tensor>()) + param<MatMulParam::OutChannels>(out_channels)) { setDatatype(DataType::Float32); } + + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + MatMul_Op(const MatMul_Op& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<MatMul_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::MatMul_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<MatMul_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 2 && "operators supports only 2 inputs"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index 073243e801c6e1297129424b0c93b1a7c4f112f3..775583fd4c2132a5474d136c60c1b53b47ea4c3d 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -63,6 +63,28 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + MaxPooling_Op(const MaxPooling_Op<DIM>& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<MaxPooling_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::MaxPooling_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<MaxPooling_Op<DIM>>(*this); + } + constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 1 && "operators supports only 3 inputs"); (void) inputIdx; // avoid unused warning diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 35a59b56cbf5c10a78116f81de96a8baddc03ff0..9e12b159888923cfea10dd02b7b267a46abcb3b7 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -21,6 +21,25 @@ public: : Operator("MetaOp") { } + + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + MetaOperator(const MetaOperator& op) + : Operator("MetaOp") + { + // cpy-ctor + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::MetaOperator + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<MetaOperator>(*this); + } + ~MetaOperator() = default; }; } diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 122a42a42f38309aa1cd1661324fcc6f5c2d3fcc..3ac651cfd6f700a129e36fb461f948f50137cfd6 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -35,8 +35,18 @@ private: public: Operator() = delete; Operator(const char* type) : mType(type) {} + virtual std::shared_ptr<Operator> clone() const = 0; virtual ~Operator(); + Operator(const Operator& op): + std::enable_shared_from_this<Operator>() + { + mType = op.mType; + mImpl = nullptr; + // Implementation is never cloned. It is up to the non-abstract Operator copy-constructor to create a new implementation matching the copied Operator implementation. 
+ // See https://gitlab.eclipse.org/eclipse/aidge/aidge_core/-/merge_requests/8#note_1214050 for the discussion. + // Hooks are not copied. + } public: diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index de885d47c0f9a5b6ffc2b38a05c1bc3e05ac21c3..681ed96892e216094f9392df01f5a10f66609638 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -29,15 +29,14 @@ class Producer_Op public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( const Producer_Op &)> { private: - std::shared_ptr<Tensor> mOutput; + std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: static constexpr const char* Type = "Producer"; template <std::size_t DIM> Producer_Op(const std::array<DimSize_t, DIM>& dims) - : Operator(Type), - mOutput(std::make_shared<Tensor>()) + : Operator(Type) { //ctor setDatatype(DataType::Float32); @@ -51,6 +50,27 @@ public: setDatatype(tensor->dataType()); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Producer_Op(const Producer_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Producer_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Producer_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Producer_Op>(*this); + } + void associateInput(const IOIndex_t /*inputIdx*/, std::shared_ptr<Data> /*data*/) override final { assert(false && "Producer operator takes no input"); } diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 141bd3ae12c7875a90d2549a24e5c141f3ff6aba..cebfa5718886ec26871462f48edcdbc28117da59 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -42,6 +42,27 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + ReLU_Op(const ReLU_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<ReLU_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::ReLU_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<ReLU_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); (void) inputIdx; // avoid unused warning diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index e158ecd7567eb683558d9e09a6cf03e5cc35ce42..e3cba81a490d3b4b28dd3754df7d274eb2e3519a 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -55,6 +55,28 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. 
+ */ + Scaling_Op(const Scaling_Op& op) + : Operator(Type), + Parameterizable_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Scaling_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Scaling_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Scaling_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); @@ -84,7 +106,7 @@ public: } - inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { assert((inputIdx == 0) && "Scaling Operator has only 1 input"); (void) inputIdx; // avoid unused warning return mInput; diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 64e713b331bbbbf612ee5102ba0ea82fb108350e..ffaf0001fbaadf7dc700fca43d77b9998ab26eb2 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -42,6 +42,27 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Softmax_Op(const Softmax_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Softmax_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Softmax_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Softmax_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); (void) inputIdx; // avoid unused warning diff --git a/include/aidge/recipies/LabelGraph.hpp b/include/aidge/recipies/LabelGraph.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9dd77e5e9f397260cf936cf77b15616c17ea33b8 --- /dev/null +++ b/include/aidge/recipies/LabelGraph.hpp @@ -0,0 +1,35 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_RECIPIES_LABELGRAPH_H_ +#define AIDGE_RECIPIES_LABELGRAPH_H_ + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" + +namespace Aidge { +NodePtr nodeLabel(NodePtr node); + +/** + * @brief Generate the graph for the pixel-wise labels corresponding to a data graph, taking into account the scaling changes (padding, stride, pooling...). 
+ * @details Right now, the behavior is to replace the following operators: + * - Conv: MaxPooling + * - ConvDepthWise: MaxPooling + * - AvgPooling: MaxPooling + * - MaxPooling: MaxPooling + * - all others: identity (removed) + * @param graph Data graph + * @return Computing graph for the labels derived from the data graph + */ +std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph); +} // namespace Aidge + +#endif /* AIDGE_RECIPIES_LABELGRAPH_H_ */ diff --git a/include/aidge/utils/CParameter.hpp b/include/aidge/utils/CParameter.hpp index 7d60ed239ae58666833c4ce227aaf16542679036..7246bc3c7555c12402e864f62416b714052320d7 100644 --- a/include/aidge/utils/CParameter.hpp +++ b/include/aidge/utils/CParameter.hpp @@ -14,6 +14,7 @@ #include <map> #include <vector> +#include <string> #include <type_traits> #include <typeinfo> #include <assert.h> @@ -41,11 +42,6 @@ private: throw std::bad_cast(); } public: - // not copyable, not movable - CParameter(CParameter const &) = delete; - CParameter(CParameter &&) = delete; - CParameter &operator=(CParameter const &) = delete; - CParameter &operator=(CParameter &&) = delete; CParameter() : m_Params({}){}; ~CParameter() = default; diff --git a/include/aidge/utils/Parameter.hpp b/include/aidge/utils/Parameter.hpp index b3d137f9d74c8b23ec200055bab4511dc24533d1..2b48e833533da5b8bb4a5f4f134860e89717804a 100644 --- a/include/aidge/utils/Parameter.hpp +++ b/include/aidge/utils/Parameter.hpp @@ -94,6 +94,12 @@ public: (void)p; // avoid unused warning } + Parameterizable(const Parameterizable& params): + mParams(params.mParams) + { + // cpy-ctor (required for Operator cpy-ctor) + } + // Compile-time access with enum template <PARAM_ENUM paramEnum> constexpr typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type& get() { diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp index 
bec59eaf2cecdc7f64d1da07580116c4b3334992..dfd2cfedec5aa291f11cf7c2a93d750c3d91145f 100644 --- a/python_binding/operator/pybind_GenericOperator.cpp +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -11,6 +11,7 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> +#include <pybind11/functional.h> #include <stdio.h> #include "aidge/backend/OperatorImpl.hpp" @@ -59,7 +60,11 @@ void init_GenericOperator(py::module& m) { throw py::key_error("Failed to convert parameter type " + key + ", this issue may come from typeid function which gave an unknown key : [" + paramType + "]. Please open an issue asking to add the support for this key."); } return res; - }); + }) + .def_readonly_static("identity", &GenericOperator_Op::Identity) + .def("compute_output_dims", &GenericOperator_Op::computeOutputDims) + .def("set_compute_output_dims", &GenericOperator_Op::setComputeOutputDims, py::arg("computation_function")) + ; m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nbDataIn"), py::arg("nbIn"), py::arg("nbOut"), py::arg("name") = ""); diff --git a/setup.py b/setup.py index 0b0f66e9132d66cdb6385d7f8c6c69ae0cc5d0e3..16305afdfdfa5de2e328460d9e96c77eb96a9d98 100644 --- a/setup.py +++ b/setup.py @@ -62,11 +62,11 @@ class CMakeBuild(build_ext): os.chdir(str(build_temp)) - # Impose to use the executable of the python + # Impose to use the executable of the python # used to launch setup.py to setup PythonInterp param_py = "-DPYTHON_EXECUTABLE=" + sys.executable - - install_path = f"{build_temp}/install" if "AIDGE_INSTALL" not in os.environ else os.environ["AIDGE_INSTALL"] + + install_path = os.path.join(sys.prefix, "lib", "libAidge") if "AIDGE_INSTALL" not in os.environ else os.environ["AIDGE_INSTALL"] self.spawn(['cmake', str(cwd), param_py, '-DTEST=OFF', f'-DCMAKE_INSTALL_PREFIX:PATH={install_path}']) if not self.dry_run: @@ -83,11 +83,11 @@ class CMakeBuild(build_ext): for file in files: if file.endswith('.so') and (root != 
str(aidge_package.absolute())): currentFile=os.path.join(root, file) - shutil.copy(currentFile, str(aidge_package.absolute())) + shutil.copy(currentFile, str(aidge_package.absolute())) # Copy version.txt in aidge_package os.chdir(os.path.dirname(__file__)) - shutil.copy("version.txt", str(aidge_package.absolute())) + shutil.copy("version.txt", str(aidge_package.absolute())) if __name__ == '__main__': diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 486a1ffe6cec4f37bb88cbfc5664ce843c4caa2b..03b2a9adb439eb00d0ba59a13fead4f25d617b36 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -519,17 +519,17 @@ void Aidge::GraphView::link(std::string /*name1_inID*/, printf("Not implemented yet.\n"); } -void Aidge::GraphView::insertParent(NodePtr childNode, - NodePtr newParentNode, - IOIndex_t childInputTensorIdx, - IOIndex_t newParentInputTensorIdx, +void Aidge::GraphView::insertParent(NodePtr childNode, + NodePtr newParentNode, + IOIndex_t childInputTensorIdx, + IOIndex_t newParentInputTensorIdx, IOIndex_t newParentOutputTensorIdx){ NodePtr currentParentNode = childNode->getParent(childInputTensorIdx); const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second; - // Remove child from current parent & current Parent from child + // Remove child from current parent & current Parent from child currentParentNode->removeChild(childNode, currentParentOutputTensorIdx); - // Add child + // Add child currentParentNode->addChild(newParentNode,currentParentOutputTensorIdx, newParentInputTensorIdx); newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx); @@ -679,3 +679,54 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { } } } + +std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const { + std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName); + + // Map for old node -> new node correspondence + 
std::map<NodePtr, NodePtr> oldToNewNodes; + + for (const std::shared_ptr<Node> &node_ptr : mNodes) { + oldToNewNodes[node_ptr] = cloneNode(node_ptr); + } + + // For each node, convert old node -> new node connections + for (auto &oldToNewNode : oldToNewNodes) { + if (oldToNewNode.second == nullptr) + continue; // deleted node + + // Add new node to new GraphView + newGraph->add(oldToNewNode.second, false); + + // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr + size_t parentId = 0; + for (auto parent : oldToNewNode.first->inputs()) { + while (oldToNewNodes[parent.first] == nullptr) { + // Find next valid parent in line, going backward in the graph + assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); + const auto& parents = parent.first->inputs(); + + if (!parents.empty() && parents[0].first != nullptr // a valid parent exists + && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView + { + parent = parents[0]; + } + else { + break; + } + } + + if (oldToNewNodes[parent.first]) { + oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); + } + + ++parentId; + } + } + + // Update OutputNodes/inputNodes + newGraph->updateInputNodes(); + newGraph->updateOutputNodes(); + + return newGraph; +} diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index abf572831d8f0b5c2c5eb836ea46e05b8114da55..54fdac808642f3ae603e237737e265ba394fccbd 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -321,6 +321,26 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { } } + /////////////////////////////////////////////////////// + // CLONE + /////////////////////////////////////////////////////// + +Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { + return std::make_shared<Node>(mOperator, mName); +} + +Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { + std::shared_ptr<Operator> op = 
(mOperator->type() == Producer_Op::Type) + ? mOperator + : mOperator->clone(); + + return std::make_shared<Node>(op, mName); +} + +Aidge::NodePtr Aidge::Node::clone() const { + return std::make_shared<Node>(mOperator->clone(), mName); +} + ///////////////////////////////////////////////////////////////////////////////////////////// // private diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..192036651cfbe2df71139dd63ca3d71f07300964 --- /dev/null +++ b/src/operator/GenericOperator.cpp @@ -0,0 +1,17 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <vector> + +#include "aidge/operator/GenericOperator.hpp" + +const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::Identity + = [](const std::vector<std::vector<size_t>>& inputsDims) { return inputsDims; }; diff --git a/src/recipies/LabelGraph.cpp b/src/recipies/LabelGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ac2cbf6ca65c7ecbced9596efb71c2052405984 --- /dev/null +++ b/src/recipies/LabelGraph.cpp @@ -0,0 +1,56 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/recipies/LabelGraph.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" + +Aidge::NodePtr Aidge::nodeLabel(NodePtr node) { + // Conv => MaxPooling + if (node->type() == Conv_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<ConvParam::KernelDims>(), op->get<ConvParam::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // ConvDepthWise => MaxPooling + if (node->type() == ConvDepthWise_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<ConvDepthWise_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<ConvDepthWiseParam::KernelDims>(), op->get<ConvDepthWiseParam::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // AvgPooling => MaxPooling + if (node->type() == AvgPooling_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<AvgPooling_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<AvgPoolingParam::KernelDims>(), op->get<AvgPoolingParam::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // MaxPooling => MaxPooling + if (node->type() == MaxPooling_Op<2>::Type) { + return node->clone(); + } + + // By default, remove the node from the graph + return nullptr; +} + +std::shared_ptr<Aidge::GraphView> Aidge::labelGraph(std::shared_ptr<GraphView> graph) { + return graph->cloneCallback(&nodeLabel); +} diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 319370ebad95869efd450eade58a2ecd36075090..4b929286ba494a452c7f9cb71ce944c7d576c03a 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ 
b/unit_tests/graph/Test_GraphView.cpp @@ -332,6 +332,234 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { } } +TEST_CASE("[GraphView] clone") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("clone_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->clone(); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("clone_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + 
REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check new connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) != g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) != g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) != g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) != g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) != 
g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) != g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider2->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + +TEST_CASE("[GraphView] cloneSharedProducers") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("cloneSharedProducers_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + 
REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->cloneSharedProducers(); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("cloneSharedProducers_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator()); + 
REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check new connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider2->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == 
g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + +TEST_CASE("[GraphView] cloneSharedOperators") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("cloneSharedOperators_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->cloneSharedOperators(); + g2->forwardDims(); + 
g2->save("cloneSharedOperators_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() == g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() == g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() == g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + 
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + + TEST_CASE("[core/graph] GraphView(insertParent)") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1"); @@ -352,7 +580,7 @@ TEST_CASE("[core/graph] GraphView(insertParent)") { std::set<NodePtr> expectedConv1Children = {conv3, newConv}; std::set<NodePtr> expectedNewConvChildren = {conv2}; - + REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0)); REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0)); @@ -374,4 +602,4 @@ TEST_CASE("[core/graph] GraphView(insertParent)") { REQUIRE((conv1->getChildren()) == expectedConv1Children2); } -} \ No newline at end of file +} diff --git a/unit_tests/recipies/Test_LabelGraph.cpp b/unit_tests/recipies/Test_LabelGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..873ad68f3198c6b6adf44d8c7ae31e667c63a18d --- /dev/null +++ b/unit_tests/recipies/Test_LabelGraph.cpp @@ -0,0 +1,154 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/recipies/LabelGraph.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/graph/OpArgs.hpp" +#include <cstddef> + +using namespace Aidge; + +TEST_CASE("[LabelGraph] conv") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("LabelGraph_conv_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_conv_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } +} + +TEST_CASE("[LabelGraph] deleted node") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(32, 64, {3, 3}, "conv2"), + Conv(64, 10, {1, 1}, "conv3", {2, 2}) + }); + + 
g1->save("LabelGraph_deleted_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_deleted_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } + + SECTION("Check dimensions") { + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 220, 220})); + REQUIRE(g2->getNode("conv3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 110, 110})); + } +} + +TEST_CASE("[LabelGraph] deleted nodes") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(32, 64, {3, 3}, "conv2"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(64, 10, {1, 1}, "conv3") + }); + + g1->save("LabelGraph_deleteds_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_deleteds_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + 
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } +} + +TEST_CASE("[LabelGraph] pooling") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + AvgPooling({2, 2}, "pool1"), + MaxPooling({2, 2}, "pool2"), + MaxPooling({2, 2}, "pool3", {2, 2}) + }); + + g1->save("LabelGraph_deleted_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("pool1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_pooling"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("pool1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0) == g2->getNode("pool2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0) == g2->getNode("pool3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool3")->getOperator()->type() == "MaxPooling"); + } + + SECTION("Check dimensions") { + REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 223, 223})); + REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(g2->getNode("pool3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 111, 111})); + } +}