Skip to content
Snippets Groups Projects
Commit db487d89 authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

remove matmul attrs

parent 4fe5e82e
No related branches found
No related tags found
No related merge requests found
......@@ -27,34 +27,21 @@
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
enum class MatMulAttr { OutChannels };
class MatMul_Op : public OperatorTensor,
public Registrable<MatMul_Op,
std::string,
std::unique_ptr<OperatorImpl>(const MatMul_Op &)>,
public StaticAttributes<MatMulAttr, DimSize_t> {
std::unique_ptr<OperatorImpl>(const MatMul_Op &)> {
public:
static const std::string Type;
MatMul_Op() = delete;
using Attributes_ = StaticAttributes<MatMulAttr, DimSize_t>;
template <MatMulAttr e> using attr = typename Attributes_::template attr<e>;
MatMul_Op(DimSize_t out_channels)
: OperatorTensor(Type, 1, 1, 1),
Attributes_(
attr<MatMulAttr::OutChannels>(out_channels))
{}
MatMul_Op(): OperatorTensor(Type, 2, 0, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
MatMul_Op(const MatMul_Op& op)
: OperatorTensor(op),
Attributes_(op)
MatMul_Op(const MatMul_Op& op) : OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<MatMul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
}
......@@ -69,16 +56,20 @@ public:
void computeOutputDims() override final {
bool associated = true;
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
if (!getInput(0)->empty() && !getInput(1)->empty())
{
std::vector<std::size_t> outDims;
for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++)
{
outDims.push_back(getInput(0)->dims()[i]);
}
associated &= !(getInput(i)->empty());
}
if (associated) {
// <batch, OutChannels>
mOutputs[0]->resize({getInput(0)->dims()[0], this->template getAttr<MatMulAttr::OutChannels>()});
size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0;
for (std::size_t i = 0; i < getInput(1)->nbDims(); i++)
{
if(i != secondToLastIdx)
outDims.push_back(getInput(1)->dims()[i]);
}
mOutputs[0]->resize(outDims);
}
}
......@@ -89,24 +80,16 @@ public:
}
static const std::vector<std::string> getInputsName(){
return {"data_input", "weight"};
return {"data_input1", "data_input2"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> MatMul(DimSize_t inChannels, DimSize_t outChannels, const std::string& name = "") {
// FIXME: properly handle default w initialization in every cases
auto matmul = std::make_shared<Node>(std::make_shared<MatMul_Op>(outChannels), name);
addProducer(matmul, 1, {outChannels, inChannels}, "w");
return matmul;
inline std::shared_ptr<Node> MatMul(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<MatMul_Op>(), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::MatMulAttr>::data[] = {"OutChannels"};
}
#endif /* AIDGE_CORE_OPERATOR__MATMUL_H_ */
......@@ -19,16 +19,12 @@
namespace py = pybind11;
namespace Aidge {
void declare_MatMul(py::module &m) {
py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, Attributes, OperatorTensor>(m, "MatMulOp", py::multiple_inheritance())
void init_MatMul(py::module &m) {
py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, OperatorTensor>(m, "MatMulOp", py::multiple_inheritance())
.def("get_inputs_name", &MatMul_Op::getInputsName)
.def("get_outputs_name", &MatMul_Op::getOutputsName)
.def("attributes_name", &MatMul_Op::staticGetAttrsName);
m.def("MatMul", &MatMul, py::arg("in_channels"), py::arg("out_channels"), py::arg("name") = "");
}
void init_MatMul(py::module &m) {
declare_MatMul(m);
m.def("MatMul", &MatMul, py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment