diff --git a/unit_tests/operator/Test_MatMul_Op.cpp b/unit_tests/operator/Test_MatMul_Op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c810e675ad46cc5580bd24e57f7e7dbb84db38f --- /dev/null +++ b/unit_tests/operator/Test_MatMul_Op.cpp @@ -0,0 +1,196 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <cstddef> // std::size_t +#include <memory> +#include <random> // std::random_device, std::mt19937, std::uniform_int_distribution +#include <vector> + +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace Aidge { +TEST_CASE("[core/operator] MatMul_Op(computeOutputDims)", "[MatMul][computeOutputDims]") { + // Create a random number generator + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<std::size_t> dist(1, 10); + + // Create MatMul Operator + std::shared_ptr<Node> myMatMul = MatMul(); + auto op = std::static_pointer_cast<OperatorTensor>(myMatMul -> getOperator()); + + /** @todo Special case of scalar Tensor objects. + * Not handled yet. + */ + // SECTION("0-D / 0-D") { + // std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>(); + // T0->resize({}); + // op -> associateInput(0,T0); + + // // input_1 - right + // std::shared_ptr<Tensor> T1 = std::make_shared<Tensor>(); + // T1->resize({}); + // op -> associateInput(1,T1); + + // REQUIRE_NOTHROW(op->computeOutputDims()); + // REQUIRE((op->getOutput(0)->dims()).empty()); + + // // input_1 - wrong + // T1->resize({dist(gen)}); + + // REQUIRE_THROWS(op->computeOutputDims()); + // } + + SECTION("1-D / N-D") { + // input_0 + std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>(); + const std::size_t dim0 = dist(gen); + T0->resize({dim0}); + op -> associateInput(0,T0); + + std::shared_ptr<Tensor> T1 = std::make_shared<Tensor>(); + op -> associateInput(1,T1); + + SECTION("1-D / 1-D") { + // input_1 - right + T1->resize({dim0}); + + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE((op->getOutput(0)->dims()).empty()); + + // input_1 - wrong + T1->resize({dim0+1}); + + REQUIRE_THROWS(op -> computeOutputDims()); + } + SECTION("1-D / 2-D") { + // input_1 - right + const std::size_t dim1 = dist(gen); + T1->resize({dim0,dim1}); + + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE(op->getOutput(0)->dims() == std::vector<std::size_t>({dim1})); + + // input_1 - wrong + T1->resize({dim0+1,dim1}); + + REQUIRE_THROWS(op -> computeOutputDims()); + } + SECTION("1-D / +2-D") { + // input_1 - right + const std::size_t dim1 = dist(gen); + const std::size_t dim2 = dist(gen); + const std::size_t dim3 = dist(gen); + T1->resize({dim1,dim2,dim0,dim3}); + + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE(op->getOutput(0)->dims() == std::vector<std::size_t>({dim1,dim2,dim3})); + } + } + SECTION("2-D / N-D") { + // input_0 + std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>(); + const std::size_t dim0 = dist(gen); + const std::size_t dim1 = dist(gen); + T0->resize({dim0,dim1}); + op -> associateInput(0,T0); + + // input_1 + std::shared_ptr<Tensor> T1 = std::make_shared<Tensor>(); + op -> associateInput(1,T1); + + SECTION("2-D / 1-D") { + // input_1 - right + T1->resize({dim1}); + + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE(op->getOutput(0)->dims() == std::vector<std::size_t>({dim0})); + + // input_1 - wrong + T1->resize({dim1+1}); + + REQUIRE_THROWS(op -> computeOutputDims()); + } + SECTION("2-D / 2-D") { + // input_1 - right + const std::size_t dim2 = dist(gen); + T1->resize({dim1, dim2}); + + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE(op->getOutput(0)->dims() == std::vector<std::size_t>({dim0,dim2})); + + // input_1 - wrong + T1->resize({dim1+1,dim2}); + + REQUIRE_THROWS(op -> computeOutputDims()); + } + SECTION("2-D / +2-D") { + // input_1 - right + const std::size_t dim2 = dist(gen); + const std::size_t dim3 = dist(gen); + const std::size_t dim4 = dist(gen); + T1->resize({dim3,dim4,dim1, dim2}); + + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE(op->getOutput(0)->dims() == std::vector<std::size_t>({dim3,dim4,dim0,dim2})); + + // input_1 - wrong + T1->resize({dim3,dim4,dim1+1,dim2}); + + REQUIRE_THROWS(op -> computeOutputDims()); + } + } + SECTION("+2-D / +2-D") { + // input_0 + std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>(); + const std::size_t dim0 = dist(gen) + 1; + const std::size_t dim1 = 1; + const std::size_t dim2 = dist(gen); + const std::size_t dim3 = dist(gen); + T0->resize({dim0,dim1,dim2,dim3}); + op -> associateInput(0,T0); + + // input_1 + std::shared_ptr<Tensor> T1 = std::make_shared<Tensor>(); + op -> associateInput(1,T1); + + // input_1 - right + // 1 + const std::size_t dim5 = dist(gen); + T1->resize({dim0,dim1,dim3,dim5}); + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE(op->getOutput(0)->dims() == std::vector<std::size_t>({dim0,dim1,dim2,dim5})); + + // 2 - input_1 broadcast + T1->resize({1,dim1,dim3,dim5}); + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE(op->getOutput(0)->dims() == std::vector<std::size_t>({dim0,dim1,dim2,dim5})); + + // 3 - input_0 broadcast + const std::size_t dim1_bigger = dist(gen) + 1; + T1->resize({dim0,dim1_bigger,dim3,dim5}); + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE(op->getOutput(0)->dims() == std::vector<std::size_t>({dim0,dim1_bigger,dim2,dim5})); + + // 4 - input_0+input_1 broadcast + T1->resize({1,dim1_bigger,dim3,dim5}); + REQUIRE_NOTHROW(op -> computeOutputDims()); + REQUIRE(op->getOutput(0)->dims() == std::vector<std::size_t>({dim0,dim1_bigger,dim2,dim5})); + + // input_1 - wrong + T1->resize({dim0+1,dim1,dim3,dim5}); + REQUIRE_THROWS(op -> computeOutputDims()); + } +} +} // namespace Aidge \ No newline at end of file