Skip to content
Snippets Groups Projects
Commit d704c111 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

remove matmul attrs

parent 79cd2ca3
No related branches found
No related tags found
No related merge requests found
...@@ -27,34 +27,21 @@ ...@@ -27,34 +27,21 @@
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
namespace Aidge { namespace Aidge {
enum class MatMulAttr { OutChannels };
class MatMul_Op : public OperatorTensor, class MatMul_Op : public OperatorTensor,
public Registrable<MatMul_Op, public Registrable<MatMul_Op,
std::string, std::string,
std::unique_ptr<OperatorImpl>(const MatMul_Op &)>, std::unique_ptr<OperatorImpl>(const MatMul_Op &)> {
public StaticAttributes<MatMulAttr, DimSize_t> {
public: public:
static constexpr const char* Type = "MatMul"; static constexpr const char* Type = "MatMul";
MatMul_Op() = delete; MatMul_Op(): OperatorTensor(Type, 2, 0, 1) {}
using Attributes_ = StaticAttributes<MatMulAttr, DimSize_t>;
template <MatMulAttr e> using attr = typename Attributes_::template attr<e>;
MatMul_Op(DimSize_t out_channels)
: OperatorTensor(Type, 1, 1, 1),
Attributes_(
attr<MatMulAttr::OutChannels>(out_channels))
{}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy. * @param op Operator to copy.
*/ */
MatMul_Op(const MatMul_Op& op) MatMul_Op(const MatMul_Op& op) : OperatorTensor(op)
: OperatorTensor(op),
Attributes_(op)
{ {
mImpl = op.mImpl ? Registrar<MatMul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; mImpl = op.mImpl ? Registrar<MatMul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
} }
...@@ -69,16 +56,20 @@ public: ...@@ -69,16 +56,20 @@ public:
void computeOutputDims() override final { void computeOutputDims() override final {
bool associated = true; if (!getInput(0)->empty() && !getInput(1)->empty())
for (IOIndex_t i = 0; i < nbInputs(); ++i) { {
if (!getInput(i)) { std::vector<std::size_t> outDims;
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++)
{
outDims.push_back(getInput(0)->dims()[i]);
} }
associated &= !(getInput(i)->empty()); size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0;
} for (std::size_t i = 0; i < getInput(1)->nbDims(); i++)
if (associated) { {
// <batch, OutChannels> if(i != secondToLastIdx)
mOutputs[0]->resize({getInput(0)->dims()[0], this->template getAttr<MatMulAttr::OutChannels>()}); outDims.push_back(getInput(1)->dims()[i]);
}
mOutputs[0]->resize(outDims);
} }
} }
...@@ -93,24 +84,16 @@ public: ...@@ -93,24 +84,16 @@ public:
} }
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
return {"data_input", "weight"}; return {"data_input1", "data_input2"};
} }
static const std::vector<std::string> getOutputsName(){ static const std::vector<std::string> getOutputsName(){
return {"data_output"}; return {"data_output"};
} }
}; };
inline std::shared_ptr<Node> MatMul(DimSize_t inChannels, DimSize_t outChannels, const std::string& name = "") { inline std::shared_ptr<Node> MatMul(const std::string& name = "") {
// FIXME: properly handle default w initialization in every cases return std::make_shared<Node>(std::make_shared<MatMul_Op>(), name);
auto matmul = std::make_shared<Node>(std::make_shared<MatMul_Op>(outChannels), name);
addProducer(matmul, 1, std::array<DimSize_t, 2>({outChannels, inChannels}), "w");
return matmul;
} }
} // namespace Aidge } // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::MatMulAttr>::data[] = {"OutChannels"};
}
#endif /* AIDGE_CORE_OPERATOR__MATMUL_H_ */ #endif /* AIDGE_CORE_OPERATOR__MATMUL_H_ */
...@@ -19,15 +19,11 @@ ...@@ -19,15 +19,11 @@
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void declare_MatMul(py::module &m) { void init_MatMul(py::module &m) {
py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, OperatorTensor, Attributes>(m, "MatMulOp", py::multiple_inheritance()) py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, OperatorTensor>(m, "MatMulOp", py::multiple_inheritance())
.def("get_inputs_name", &MatMul_Op::getInputsName) .def("get_inputs_name", &MatMul_Op::getInputsName)
.def("get_outputs_name", &MatMul_Op::getOutputsName); .def("get_outputs_name", &MatMul_Op::getOutputsName);
m.def("MatMul", &MatMul, py::arg("in_channels"), py::arg("out_channels"), py::arg("name") = ""); m.def("MatMul", &MatMul, py::arg("name") = "");
}
void init_MatMul(py::module &m) {
declare_MatMul(m);
} }
} // namespace Aidge } // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment