Skip to content
Snippets Groups Projects
Commit d704c111 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

remove matmul attrs

parent 79cd2ca3
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!76Matmul rework
......@@ -27,34 +27,21 @@
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
enum class MatMulAttr { OutChannels };
class MatMul_Op : public OperatorTensor,
public Registrable<MatMul_Op,
std::string,
std::unique_ptr<OperatorImpl>(const MatMul_Op &)>,
public StaticAttributes<MatMulAttr, DimSize_t> {
std::unique_ptr<OperatorImpl>(const MatMul_Op &)> {
public:
static constexpr const char* Type = "MatMul";
MatMul_Op() = delete;
using Attributes_ = StaticAttributes<MatMulAttr, DimSize_t>;
template <MatMulAttr e> using attr = typename Attributes_::template attr<e>;
MatMul_Op(DimSize_t out_channels)
: OperatorTensor(Type, 1, 1, 1),
Attributes_(
attr<MatMulAttr::OutChannels>(out_channels))
{}
MatMul_Op(): OperatorTensor(Type, 2, 0, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
MatMul_Op(const MatMul_Op& op)
: OperatorTensor(op),
Attributes_(op)
MatMul_Op(const MatMul_Op& op) : OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<MatMul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
}
......@@ -69,16 +56,20 @@ public:
void computeOutputDims() override final {
bool associated = true;
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
if (!getInput(0)->empty() && !getInput(1)->empty())
{
std::vector<std::size_t> outDims;
for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++)
{
outDims.push_back(getInput(0)->dims()[i]);
}
associated &= !(getInput(i)->empty());
}
if (associated) {
// <batch, OutChannels>
mOutputs[0]->resize({getInput(0)->dims()[0], this->template getAttr<MatMulAttr::OutChannels>()});
size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0;
for (std::size_t i = 0; i < getInput(1)->nbDims(); i++)
{
if(i != secondToLastIdx)
outDims.push_back(getInput(1)->dims()[i]);
}
mOutputs[0]->resize(outDims);
}
}
......@@ -93,24 +84,16 @@ public:
}
static const std::vector<std::string> getInputsName(){
return {"data_input", "weight"};
return {"data_input1", "data_input2"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> MatMul(DimSize_t inChannels, DimSize_t outChannels, const std::string& name = "") {
// FIXME: properly handle default w initialization in every cases
auto matmul = std::make_shared<Node>(std::make_shared<MatMul_Op>(outChannels), name);
addProducer(matmul, 1, std::array<DimSize_t, 2>({outChannels, inChannels}), "w");
return matmul;
inline std::shared_ptr<Node> MatMul(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<MatMul_Op>(), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::MatMulAttr>::data[] = {"OutChannels"};
}
#endif /* AIDGE_CORE_OPERATOR__MATMUL_H_ */
......@@ -19,15 +19,11 @@
namespace py = pybind11;
namespace Aidge {
void declare_MatMul(py::module &m) {
py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, OperatorTensor, Attributes>(m, "MatMulOp", py::multiple_inheritance())
void init_MatMul(py::module &m) {
py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, OperatorTensor>(m, "MatMulOp", py::multiple_inheritance())
.def("get_inputs_name", &MatMul_Op::getInputsName)
.def("get_outputs_name", &MatMul_Op::getOutputsName);
m.def("MatMul", &MatMul, py::arg("in_channels"), py::arg("out_channels"), py::arg("name") = "");
}
void init_MatMul(py::module &m) {
declare_MatMul(m);
m.def("MatMul", &MatMul, py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment