Skip to content
Snippets Groups Projects
Commit d704c111 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

remove matmul attrs

parent 79cd2ca3
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!76Matmul rework
...@@ -27,34 +27,21 @@ ...@@ -27,34 +27,21 @@
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
namespace Aidge { namespace Aidge {
enum class MatMulAttr { OutChannels };
class MatMul_Op : public OperatorTensor, class MatMul_Op : public OperatorTensor,
public Registrable<MatMul_Op, public Registrable<MatMul_Op,
std::string, std::string,
std::unique_ptr<OperatorImpl>(const MatMul_Op &)>, std::unique_ptr<OperatorImpl>(const MatMul_Op &)> {
public StaticAttributes<MatMulAttr, DimSize_t> {
public: public:
static constexpr const char* Type = "MatMul"; static constexpr const char* Type = "MatMul";
MatMul_Op() = delete; MatMul_Op(): OperatorTensor(Type, 2, 0, 1) {}
using Attributes_ = StaticAttributes<MatMulAttr, DimSize_t>;
template <MatMulAttr e> using attr = typename Attributes_::template attr<e>;
MatMul_Op(DimSize_t out_channels)
: OperatorTensor(Type, 1, 1, 1),
Attributes_(
attr<MatMulAttr::OutChannels>(out_channels))
{}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy. * @param op Operator to copy.
*/ */
MatMul_Op(const MatMul_Op& op) MatMul_Op(const MatMul_Op& op) : OperatorTensor(op)
: OperatorTensor(op),
Attributes_(op)
{ {
mImpl = op.mImpl ? Registrar<MatMul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; mImpl = op.mImpl ? Registrar<MatMul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
} }
...@@ -69,16 +56,20 @@ public: ...@@ -69,16 +56,20 @@ public:
void computeOutputDims() override final { void computeOutputDims() override final {
bool associated = true; if (!getInput(0)->empty() && !getInput(1)->empty())
for (IOIndex_t i = 0; i < nbInputs(); ++i) { {
if (!getInput(i)) { std::vector<std::size_t> outDims;
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++)
{
outDims.push_back(getInput(0)->dims()[i]);
} }
associated &= !(getInput(i)->empty()); size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0;
} for (std::size_t i = 0; i < getInput(1)->nbDims(); i++)
if (associated) { {
// <batch, OutChannels> if(i != secondToLastIdx)
mOutputs[0]->resize({getInput(0)->dims()[0], this->template getAttr<MatMulAttr::OutChannels>()}); outDims.push_back(getInput(1)->dims()[i]);
}
mOutputs[0]->resize(outDims);
} }
} }
...@@ -93,24 +84,16 @@ public: ...@@ -93,24 +84,16 @@ public:
} }
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
return {"data_input", "weight"}; return {"data_input1", "data_input2"};
} }
static const std::vector<std::string> getOutputsName(){ static const std::vector<std::string> getOutputsName(){
return {"data_output"}; return {"data_output"};
} }
}; };
inline std::shared_ptr<Node> MatMul(DimSize_t inChannels, DimSize_t outChannels, const std::string& name = "") { inline std::shared_ptr<Node> MatMul(const std::string& name = "") {
// FIXME: properly handle default w initialization in every cases return std::make_shared<Node>(std::make_shared<MatMul_Op>(), name);
auto matmul = std::make_shared<Node>(std::make_shared<MatMul_Op>(outChannels), name);
addProducer(matmul, 1, std::array<DimSize_t, 2>({outChannels, inChannels}), "w");
return matmul;
} }
} // namespace Aidge } // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::MatMulAttr>::data[] = {"OutChannels"};
}
#endif /* AIDGE_CORE_OPERATOR__MATMUL_H_ */ #endif /* AIDGE_CORE_OPERATOR__MATMUL_H_ */
...@@ -19,15 +19,11 @@ ...@@ -19,15 +19,11 @@
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void declare_MatMul(py::module &m) { void init_MatMul(py::module &m) {
py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, OperatorTensor, Attributes>(m, "MatMulOp", py::multiple_inheritance()) py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, OperatorTensor>(m, "MatMulOp", py::multiple_inheritance())
.def("get_inputs_name", &MatMul_Op::getInputsName) .def("get_inputs_name", &MatMul_Op::getInputsName)
.def("get_outputs_name", &MatMul_Op::getOutputsName); .def("get_outputs_name", &MatMul_Op::getOutputsName);
m.def("MatMul", &MatMul, py::arg("in_channels"), py::arg("out_channels"), py::arg("name") = ""); m.def("MatMul", &MatMul, py::arg("name") = "");
}
void init_MatMul(py::module &m) {
declare_MatMul(m);
} }
} // namespace Aidge } // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment