From db487d8952444e5deff18e8b77738568181dc662 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Wed, 6 Dec 2023 16:56:37 +0100 Subject: [PATCH] remove matmul attrs --- include/aidge/operator/MatMul.hpp | 55 ++++++++--------------- python_binding/operator/pybind_Matmul.cpp | 10 ++--- 2 files changed, 22 insertions(+), 43 deletions(-) diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index 3d80193be..ffb01bc6a 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -27,34 +27,21 @@ #include "aidge/utils/Registrar.hpp" namespace Aidge { -enum class MatMulAttr { OutChannels }; class MatMul_Op : public OperatorTensor, public Registrable<MatMul_Op, std::string, - std::unique_ptr<OperatorImpl>(const MatMul_Op &)>, - public StaticAttributes<MatMulAttr, DimSize_t> { + std::unique_ptr<OperatorImpl>(const MatMul_Op &)> { public: static const std::string Type; - MatMul_Op() = delete; - - using Attributes_ = StaticAttributes<MatMulAttr, DimSize_t>; - template <MatMulAttr e> using attr = typename Attributes_::template attr<e>; - - MatMul_Op(DimSize_t out_channels) - : OperatorTensor(Type, 1, 1, 1), - Attributes_( - attr<MatMulAttr::OutChannels>(out_channels)) - {} + MatMul_Op(): OperatorTensor(Type, 2, 0, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - MatMul_Op(const MatMul_Op& op) - : OperatorTensor(op), - Attributes_(op) + MatMul_Op(const MatMul_Op& op) : OperatorTensor(op) { mImpl = op.mImpl ? Registrar<MatMul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; } @@ -69,16 +56,20 @@ public: void computeOutputDims() override final { - bool associated = true; - for (IOIndex_t i = 0; i < nbInputs(); ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); + if (!getInput(0)->empty() && !getInput(1)->empty()) + { + std::vector<std::size_t> outDims; + for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++) + { + outDims.push_back(getInput(0)->dims()[i]); } - associated &= !(getInput(i)->empty()); - } - if (associated) { - // <batch, OutChannels> - mOutputs[0]->resize({getInput(0)->dims()[0], this->template getAttr<MatMulAttr::OutChannels>()}); + size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0; + for (std::size_t i = 0; i < getInput(1)->nbDims(); i++) + { + if(i != secondToLastIdx) + outDims.push_back(getInput(1)->dims()[i]); + } + mOutputs[0]->resize(outDims); } } @@ -89,24 +80,16 @@ public: } static const std::vector<std::string> getInputsName(){ - return {"data_input", "weight"}; + return {"data_input1", "data_input2"}; } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; } }; -inline std::shared_ptr<Node> MatMul(DimSize_t inChannels, DimSize_t outChannels, const std::string& name = "") { - // FIXME: properly handle default w initialization in every cases - auto matmul = std::make_shared<Node>(std::make_shared<MatMul_Op>(outChannels), name); - addProducer(matmul, 1, {outChannels, inChannels}, "w"); - return matmul; +inline std::shared_ptr<Node> MatMul(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<MatMul_Op>(), name); } } // namespace Aidge -namespace { -template <> -const char *const EnumStrings<Aidge::MatMulAttr>::data[] = {"OutChannels"}; -} - #endif /* AIDGE_CORE_OPERATOR__MATMUL_H_ */ diff --git a/python_binding/operator/pybind_Matmul.cpp b/python_binding/operator/pybind_Matmul.cpp index 72bc0f817..92e4bc801 100644 --- a/python_binding/operator/pybind_Matmul.cpp +++ b/python_binding/operator/pybind_Matmul.cpp @@ -19,16 +19,12 @@ namespace py = pybind11; namespace Aidge { -void declare_MatMul(py::module &m) { - py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, Attributes, OperatorTensor>(m, "MatMulOp", py::multiple_inheritance()) +void init_MatMul(py::module &m) { + py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, OperatorTensor>(m, "MatMulOp", py::multiple_inheritance()) .def("get_inputs_name", &MatMul_Op::getInputsName) .def("get_outputs_name", &MatMul_Op::getOutputsName) .def("attributes_name", &MatMul_Op::staticGetAttrsName); - m.def("MatMul", &MatMul, py::arg("in_channels"), py::arg("out_channels"), py::arg("name") = ""); -} - -void init_MatMul(py::module &m) { - declare_MatMul(m); + m.def("MatMul", &MatMul, py::arg("name") = ""); } } // namespace Aidge -- GitLab