Skip to content
Snippets Groups Projects
  • Maxence Naud's avatar
    d8d7ca5c
    [Fix] small fixes · d8d7ca5c
    Maxence Naud authored
    - remove warning on void pointer arithmetic in TensorImpl
    - remove unused variable warning in AddImpl
    - update cpu Tensor test
    d8d7ca5c
    History
    [Fix] small fixes
    Maxence Naud authored
    - remove warning on void pointer arithmetic in TensorImpl
    - remove unused variable warning in AddImpl
    - update cpu Tensor test
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Test_TensorImpl.cpp 1.71 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <array>

#include <catch2/catch_test_macros.hpp>

#include "aidge/data/Tensor.hpp"
#include "aidge/backend/cpu/data/TensorImpl.hpp"

using namespace Aidge;

/**
 * Tests construction of a Tensor from a compile-time Array3D literal and
 * its basic accessors (dims/size, raw buffer, element get/set, printing,
 * equality).
 *
 * Note: Catch2 re-runs the enclosing TEST_CASE/SECTION code once for every
 * leaf SECTION, so `x`, `xCopy`, `xDifferent` and `xFloat` are rebuilt from
 * scratch for each SECTION below. The write performed in "get function"
 * therefore does not leak into the other sections.
 */
TEST_CASE("Tensor creation") {
  SECTION("from const array") {
    // 2x2x2 int tensor holding the values 1..8 in literal order.
    Tensor x = Array3D<int, 2, 2, 2>{{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}};

    // Same dtype, shape and values as `x`: must compare equal.
    Tensor xCopy = Array3D<int, 2, 2, 2>{{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}};

    // Same dtype and shape as `x`, but the last value differs (9 vs 8):
    // must compare unequal, so value comparison is actually exercised.
    Tensor xDifferent =
        Array3D<int, 2, 2, 2>{{{{1, 2}, {3, 4}}, {{5, 6}, {7, 9}}}};

    // Same shape and numeric values as `x`, but float instead of int:
    // must compare unequal.
    Tensor xFloat =
        Array3D<float, 2, 2, 2>{{{{1., 2.}, {3., 4.}}, {{5., 6.}, {7., 8.}}}};

    SECTION("Tensor features") {
      REQUIRE(x.nbDims() == 3);
      REQUIRE(x.dims()[0] == 2);
      REQUIRE(x.dims()[1] == 2);
      REQUIRE(x.dims()[2] == 2);
      REQUIRE(x.size() == 8);
    }

    SECTION("Access to array") {
      // rawPtr() exposes the backing buffer untyped; cast to the tensor's
      // element type. Indices 0 and 7 are the first and last of the 8
      // elements (assumes contiguous storage in literal order — the
      // expected values 1 and 8 encode that assumption).
      REQUIRE(static_cast<int *>(x.getImpl()->rawPtr())[0] == 1);
      REQUIRE(static_cast<int *>(x.getImpl()->rawPtr())[7] == 8);
    }

    SECTION("get function") {
      REQUIRE(x.get<int>({0, 0, 0}) == 1);
      REQUIRE(x.get<int>({0, 0, 1}) == 2);
      REQUIRE(x.get<int>({0, 1, 1}) == 4);
      REQUIRE(x.get<int>({1, 1, 0}) == 7);
      // The result of get<int>() is assignable: write through it.
      x.get<int>({1, 1, 1}) = 36;
      REQUIRE(x.get<int>({1, 1, 1}) == 36);
      // The write must only affect the addressed element.
      REQUIRE(x.get<int>({1, 1, 0}) == 7);
      REQUIRE(x.get<int>({0, 0, 0}) == 1);
    }

    SECTION("Pretty printing for debug") { REQUIRE_NOTHROW(x.print()); }

    SECTION("Tensor (in)equality") {
      REQUIRE(x == xCopy);
      // Same dtype and shape, different values.
      REQUIRE_FALSE(x == xDifferent);
      // Same values, different dtype.
      REQUIRE_FALSE(x == xFloat);
    }
  }
}