Skip to content
Snippets Groups Projects

Matmul rework

Merged Houssem ROUIS requested to merge hrouis/aidge_core:matmul_rework into dev
1 unresolved thread
2 files
+ 20
14
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -12,18 +12,14 @@
#ifndef AIDGE_CORE_OPERATOR_MATMUL_H_
#define AIDGE_CORE_OPERATOR_MATMUL_H_
#include <array>
#include <cmath>
#include <numeric>
#include <memory>
#include <string>
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
@@ -35,7 +31,7 @@ class MatMul_Op : public OperatorTensor,
public:
static const std::string Type;
MatMul_Op(): OperatorTensor(Type, 2, 0, 1) {}
MatMul_Op() : OperatorTensor(Type, 2, 0, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -50,23 +46,33 @@ public:
* @brief Clone the operator using its copy-constructor.
* @see Operator::MatMul_Op
*/
std::shared_ptr<Operator> clone() const override {
std::shared_ptr<Operator> clone() const override final {
return std::make_shared<MatMul_Op>(*this);
}
/**
* @brief Compute dimensions for the output Tensor following the same rules as
* numpy.matmul.
* @note - Both inputs are 2-D Tensors: classic matrix multiplication
* @note - Either input is N-D with N > 2: it is treated as a stack of matrices residing
* in the last two indexes and broadcast accordingly.
* @note - First input is 1-D: it is promoted to a matrix by prepending a 1 to its
* dimensions (D) -> (1,D). The prepended 1 is removed after computation.
* @note - Second input is 1-D: it is promoted to a matrix by appending a 1 to its
* dimensions (D) -> (D,1). The appended 1 is removed after computation.
*/
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final {
mImpl = Registrar<MatMul_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
static const std::vector<std::string> getInputsName() {
return {"data_input1", "data_input2"};
}
static const std::vector<std::string> getOutputsName(){
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
Loading