From 2ac231d8a2e633b4439607a6065ebe919c264d29 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Thu, 19 Oct 2023 08:35:44 +0000 Subject: [PATCH] [Upd] Add Operator to support an undefined number of inputs. Remove templates --- include/aidge/aidge.hpp | 2 +- include/aidge/operator/Add.hpp | 66 ++++++++++++------------- python_binding/operator/pybind_Add.cpp | 12 ++--- unit_tests/recipies/Test_FuseMulAdd.cpp | 4 +- 4 files changed, 41 insertions(+), 43 deletions(-) diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 16fa9967c..8a1b50a0e 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -33,7 +33,7 @@ #include "aidge/operator/Add.hpp" #include "aidge/operator/AvgPooling.hpp" #include "aidge/operator/BatchNorm.hpp" -#include "aidge/operator/Concat.hpp" +// #include "aidge/operator/Concat.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/FC.hpp" diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 65c7e8ce0..ceb058dbd 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -16,7 +16,7 @@ #include <vector> #include <cmath> #include <memory> -#include <array> +#include <vector> #include "aidge/utils/Registrar.hpp" #include "aidge/operator/Operator.hpp" @@ -26,24 +26,23 @@ namespace Aidge { -template <std::size_t NUM> class Add_Op : public Operator, - public Registrable<Add_Op<NUM>, std::string, std::unique_ptr<OperatorImpl>(const Add_Op<NUM>&)> { -public: + public Registrable<Add_Op, std::string, std::unique_ptr<OperatorImpl>(const Add_Op&)> { +private: // FIXME: change accessibility - std::array<std::shared_ptr<Tensor>, NUM> mInputs; + std::vector<std::shared_ptr<Tensor>> mInputs; const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + const IOIndex_t mNbInputs; public: static constexpr const char* Type = "Add"; - constexpr Add_Op() - : Operator(Type) + Add_Op(const IOIndex_t nbIn) + : Operator(Type), + mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())), + mNbInputs(nbIn) { - assert(NUM > 0 && "Add should have at least one input"); - for (std::size_t i = 0; i<NUM; ++i) { - mInputs[i] = std::make_shared<Tensor>(); - } + assert(nbIn > 0 && "Add should have at least one input"); setDatatype(DataType::Float32); } @@ -51,17 +50,16 @@ public: * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - Add_Op(const Add_Op<NUM>& op) + Add_Op(const Add_Op& op) : Operator(Type), + mInputs(op.mInputs), + mNbInputs(op.mNbInputs), mOutput(std::make_shared<Tensor>(*op.mOutput)) { // cpy-ctor - assert(NUM > 0 && "Add should have at least one input"); - for (std::size_t i = 0; i<NUM; ++i) { - mInputs[i] = std::make_shared<Tensor>(); - } + assert(mNbInputs > 0 && "Add should have at least one input"); setDatatype(op.mOutput->dataType()); - mImpl = op.mImpl ? Registrar<Add_Op<NUM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + mImpl = op.mImpl ? Registrar<Add_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; } /** @@ -82,7 +80,7 @@ public: // } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { - assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); @@ -92,10 +90,10 @@ public: if (!mInputs[0]->empty()) { const auto expectedDims = mInputs[0]->dims(); std::size_t nonEmptyInputTensor = 1; - for (; nonEmptyInputTensor<NUM && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) { + for (; nonEmptyInputTensor < mNbInputs && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) { assert(expectedDims == mInputs[nonEmptyInputTensor]->dims()); } - if (nonEmptyInputTensor == NUM) { + if (nonEmptyInputTensor == mNbInputs) { mOutput->resize(expectedDims); } } @@ -103,8 +101,8 @@ public: bool outputDimsForwarded() const override final { std::size_t forwarded = 0; - for (; forwarded < NUM && (!mInputs[forwarded]->empty()); ++forwarded) {} - return ((forwarded==NUM) && !(mOutput->empty())); + for (; forwarded < mNbInputs && (!mInputs[forwarded]->empty()); ++forwarded) {} + return ((forwarded==mNbInputs) && !(mOutput->empty())); } // void checkDims() const override final { @@ -114,13 +112,13 @@ public: // } // } inline Tensor& input(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); return *(mInputs[inputIdx].get()); } inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); return mInputs[inputIdx]; } inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { @@ -130,7 +128,7 @@ public: } std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); return std::static_pointer_cast<Data>(mInputs[inputIdx]); } std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { @@ -141,11 +139,11 @@ public: void setBackend(const std::string& name) override { - mImpl = Registrar<Add_Op<NUM>>::create(name)(*this); + mImpl = Registrar<Add_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround - for (std::size_t i = 0; i < NUM; ++i) { + for (std::size_t i = 0; i < mNbInputs; ++i) { mInputs[i]->setBackend(name); } } @@ -154,15 +152,16 @@ public: mOutput->setDatatype(datatype); // FIXME: temporary workaround - for (std::size_t i = 0; i < NUM; ++i) { + for (std::size_t i = 0; i < mNbInputs; ++i) { mInputs[i]->setDatatype(datatype); } } - inline IOIndex_t nbInputs() const noexcept override final { return NUM; } - inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; } + inline IOIndex_t nbInputs() const noexcept override final { return mNbInputs; } + inline IOIndex_t nbDataInputs() const noexcept override final { return mNbInputs; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } - static const std::vector<std::string> getInputsName(){ + + static const std::vector<std::string> getInputsName(){ return {"data_input_0", "data_input_n"}; } static const std::vector<std::string> getOutputsName(){ @@ -170,9 +169,8 @@ public: } }; -template <std::size_t NUM> -inline std::shared_ptr<Node> Add(const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Add_Op<NUM>>(), name); +inline std::shared_ptr<Node> Add(const IOIndex_t nbIn, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Add_Op>(nbIn), name); } } diff --git a/python_binding/operator/pybind_Add.cpp b/python_binding/operator/pybind_Add.cpp index 0b2323c5c..bff795a73 100644 --- a/python_binding/operator/pybind_Add.cpp +++ b/python_binding/operator/pybind_Add.cpp @@ -19,15 +19,15 @@ namespace py = pybind11; namespace Aidge { -template <std::size_t NUM> void declare_Add(py::module &m) { - py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "AddOp", py::multiple_inheritance()) - .def("get_inputs_name", &Add_Op<NUM>::getInputsName) - .def("get_outputs_name", &Add_Op<NUM>::getOutputsName); +void declare_Add(py::module &m) { + py::class_<Add_Op, std::shared_ptr<Add_Op>, Operator>(m, "AddOp", py::multiple_inheritance()) + .def("get_inputs_name", &Add_Op::getInputsName) + .def("get_outputs_name", &Add_Op::getOutputsName); - m.def("Add", &Add<NUM>, py::arg("name") = ""); + m.def("Add", &Add, py::arg("nbIn"), py::arg("name") = ""); } void init_Add(py::module &m) { - declare_Add<2>(m); + declare_Add(m); } } // namespace Aidge diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index 92b2b7c13..7e64f3ff5 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -25,9 +25,9 @@ namespace Aidge { TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // generate the original GraphView auto matmul0 = MatMul(5, "matmul0"); - auto add0 = Add<2>("add0"); + auto add0 = Add(2, "add0"); auto matmul1 = MatMul(5, "matmul1"); - auto add1 = Add<2>("add1"); + auto add1 = Add(2, "add1"); auto b0 = Producer({5}, "B0"); auto w0 = Producer({5, 5}, "W0"); -- GitLab