Commit 0a0c4904 authored by Maxence Naud

Standardize code for MatMul.hpp and Test_TensorImpl.cpp

parent 1b40120d
2 merge requests: !105 version 0.2.0, !76 Matmul rework
@@ -12,18 +12,14 @@
#ifndef AIDGE_CORE_OPERATOR_MATMUL_H_
#define AIDGE_CORE_OPERATOR_MATMUL_H_
#include <array>
#include <cmath>
#include <numeric>
#include <memory>
#include <string>
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
@@ -35,7 +31,7 @@ class MatMul_Op : public OperatorTensor,
public:
static const std::string Type;
MatMul_Op(): OperatorTensor(Type, 2, 0, 1) {}
MatMul_Op() : OperatorTensor(Type, 2, 0, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -50,23 +46,33 @@ public:
* @brief Clone the operator using its copy-constructor.
* @see Operator::MatMul_Op
*/
std::shared_ptr<Operator> clone() const override {
std::shared_ptr<Operator> clone() const override final {
return std::make_shared<MatMul_Op>(*this);
}
/**
* @brief Compute dimensions for the output Tensor following the same rules as
* numpy.matmul.
* @note - Both inputs are 2-D Tensors: classic matrix multiplication
* @note - Either input is N-D with N > 2: it is treated as a stack of matrices residing
* in the last two indexes and broadcast accordingly.
* @note - First input is 1-D: it is promoted to a matrix by prepending a 1 to its
* dimensions (D) -> (1,D). The prepended 1 is removed after computation.
* @note - Second input is 1-D: it is promoted to a matrix by appending a 1 to its
* dimensions (D) -> (D,1). The appended 1 is removed after computation.
*/
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final {
mImpl = Registrar<MatMul_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
static const std::vector<std::string> getInputsName() {
return {"data_input1", "data_input2"};
}
static const std::vector<std::string> getOutputsName(){
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
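
As a side note on the doc comment added in this hunk: the shape rules it describes are those of numpy.matmul. Below is a minimal standalone sketch of those rules, assuming plain std::size_t shape vectors; the helper name matmulOutputDims is hypothetical and this is not the actual MatMul_Op::computeOutputDims() implementation.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the numpy.matmul output-shape rules (illustrative only).
std::vector<std::size_t> matmulOutputDims(std::vector<std::size_t> a,
                                          std::vector<std::size_t> b) {
    const bool aWas1D = (a.size() == 1);
    const bool bWas1D = (b.size() == 1);
    if (aWas1D) { a.insert(a.begin(), 1); }  // (D) -> (1, D)
    if (bWas1D) { b.push_back(1); }          // (D) -> (D, 1)

    // Pad the shorter shape with leading 1s so both have the same rank.
    const std::size_t rank = std::max(a.size(), b.size());
    a.insert(a.begin(), rank - a.size(), 1);
    b.insert(b.begin(), rank - b.size(), 1);

    // Contracted dimensions must agree: (..., n, k) x (..., k, m) -> (..., n, m).
    assert(a[rank - 1] == b[rank - 2]);

    std::vector<std::size_t> out(rank);
    for (std::size_t i = 0; i + 2 < rank; ++i) {
        // Batch dimensions broadcast as in numpy: equal, or one of them is 1.
        assert(a[i] == b[i] || a[i] == 1 || b[i] == 1);
        out[i] = std::max(a[i], b[i]);
    }
    out[rank - 2] = a[rank - 2];
    out[rank - 1] = b[rank - 1];

    // Remove the dimensions that were artificially added for 1-D inputs.
    if (bWas1D) { out.pop_back(); }
    if (aWas1D) { out.erase(out.end() - (bWas1D ? 1 : 2)); }
    return out;
}

For instance, shapes (7, 2, 3) x (3, 4) give (7, 2, 4), and (5) x (5) contract to a scalar (empty shape), matching the promotion and removal behaviour described in the notes above.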
@@ -19,7 +19,7 @@
using namespace Aidge;
TEST_CASE("Tensor creation") {
TEST_CASE("[core/data] Tensor creation") {
SECTION("from const array") {
Tensor x = Array3D<int, 2, 2, 2>{{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}};
@@ -59,7 +59,7 @@ TEST_CASE("Tensor creation") {
}
}
TEST_CASE("Tensor methods") {
TEST_CASE("[core/data] Tensor methods","[Tensor]") {
Tensor x = Array3D<int, 2, 2, 2>{{
{{1, 2},
{3, 4}},
@@ -89,7 +89,7 @@ TEST_CASE("Tensor methods") {
REQUIRE(y.getImpl() == x.getImpl());
REQUIRE(approxEq<int>(y, Array1D<int, 2>{{3, 4}}));
REQUIRE(y.isContiguous());
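// Presumably extracting a (2, 1, 1) sub-view of x anchored at coordinates (0, 1, 1):
// the assertions below check that it shares x's implementation but, unlike y, is not contiguous.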
Tensor y2 = x.extract({0, 1, 1}, {2, 1, 1});
REQUIRE(y2.getImpl() == x.getImpl());
REQUIRE(!y2.isContiguous());