Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Test_TensorImpl.cpp 6.21 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <array>

#include <iostream>

#include <catch2/catch_test_macros.hpp>

#include "aidge/backend/cpu/data/TensorImpl.hpp"
#include "aidge/data/Tensor.hpp"

using namespace Aidge;

/**
 * @brief Fill a tensor with a "rainbow" pattern: the element at flat index k gets value k.
 *
 * The tensor's raw storage is reinterpreted as a contiguous array of Data_T and
 * written in flat-index order.
 *
 * @tparam Data_T element type used to view the raw storage. It is NOT checked
 *         against the tensor's data type here — callers must keep them in sync
 *         (e.g. DataType::UInt16 <-> std::uint16_t).
 * @param i_Tensor tensor filled in place. Presumably its storage must already be
 *         allocated; callers in this file set dims, data type and backend first.
 * @return always true, so the call can be wrapped in REQUIRE().
 */
template<typename Data_T> bool MakeRainbow(Tensor &i_Tensor)
{
    const NbElts_t N = i_Tensor.size();
    Data_T *data = reinterpret_cast<Data_T *>(i_Tensor.getImpl().rawPtr());
    for (NbElts_t i = 0; i < N; ++i)
    {
        // Explicit cast: the index -> Data_T conversion is intentional and may
        // narrow (values are converted per the standard C++ conversion rules).
        data[i] = static_cast<Data_T>(i);
    }
    return true;
}

TEST_CASE("Tensor creation")
{
    SECTION("from const array")
    {
        // clang-format off
        Tensor y = Array3D<int, 1, 2, 3>{
            {
            {
                {1,2,3},
                {4,5,6}
            }
        }
        };
        // clang-format on
        Tensor x = Array3D<int, 2, 2, 2>{{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}};

        // Same shape and values as x: used to check equality.
        Tensor xCopy = Array3D<int, 2, 2, 2>{{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}};

        // Same shape and numeric values as x but a different data type (float):
        // used to check that tensors of different types compare unequal.
        Tensor xFloat
            = Array3D<float, 2, 2, 2>{{{{1., 2.}, {3., 4.}}, {{5., 6.}, {7., 8.}}}};

        // Catch2 re-runs this enclosing section from the top for every leaf
        // SECTION below, so each leaf starts from freshly built tensors (the write
        // performed in "get function" does not leak into the other sections).

        SECTION("Tensor features")
        {
            REQUIRE(y.nbDims() == 3);
            REQUIRE(y.dims()[0] == 1);
            REQUIRE(y.dims()[1] == 2);
            REQUIRE(y.dims()[2] == 3);
            // Expected layout of a 1x2x3 int tensor: row-major strides expressed
            // in bytes (multiples of the scalar size).
            auto const &impl = y.getImpl();
            REQUIRE(impl.getMemoryLayout()[0] == impl.getScalarSize() * 3 * 2);
            REQUIRE(impl.getMemoryLayout()[1] == impl.getScalarSize() * 3);
            REQUIRE(impl.getMemoryLayout()[2] == impl.getScalarSize());
            REQUIRE(impl.getScalarSize() == sizeof(int));
            REQUIRE(x.nbDims() == 3);
            REQUIRE(x.dims()[0] == 2);
            REQUIRE(x.dims()[1] == 2);
            REQUIRE(x.dims()[2] == 2);
            REQUIRE(x.size() == 8);
        }

        SECTION("Access to array")
        {
            // First and last elements of x through the raw data address.
            REQUIRE(reinterpret_cast<int const *>(x.getImpl().getDataAddress())[0] == 1);
            REQUIRE(reinterpret_cast<int const *>(x.getImpl().getDataAddress())[7] == 8);
        }

        SECTION("get function")
        {
            REQUIRE(x.get<int>({0, 0, 0}) == 1);
            REQUIRE(x.get<int>({0, 0, 1}) == 2);
            REQUIRE(x.get<int>({0, 1, 1}) == 4);
            REQUIRE(x.get<int>({1, 1, 0}) == 7);
            // get<>() returns a writable reference into the tensor storage.
            x.get<int>({1, 1, 1}) = 36;
            REQUIRE(x.get<int>({1, 1, 1}) == 36);
        }

        SECTION("Pretty printing for debug")
        {
            REQUIRE_NOTHROW(x.print());
        }

        SECTION("Tensor (in)equality")
        {
            REQUIRE(x == xCopy);
            REQUIRE_FALSE(x == xFloat);
        }
    }
    SECTION("rainbow")
    {
        // 2x4x5 uint16 tensor on the cpu backend, filled so that element i holds i.
        Tensor Rainbow;
        Rainbow.resize({2, 4, 5});
        Rainbow.setDatatype(DataType::UInt16);
        Rainbow.setBackend("cpu");
        REQUIRE(MakeRainbow<std::uint16_t>(Rainbow));
        // One assertion per element. On failure, Catch2 shows the offending value
        // instead of an opaque aggregated boolean. Both sides are compared as NbElts_t.
        for (NbElts_t i = 0; i != Rainbow.size(); ++i)
        {
            REQUIRE(static_cast<NbElts_t>(Rainbow.get<std::uint16_t>(i)) == i);
        }
    }
}
TEST_CASE("Tensor copy")
{
    SECTION("deep copy")
    {
        // Source: 2x4x5 uint16 tensor on the cpu backend, element i holds i.
        Tensor Rainbow;
        Rainbow.resize({2, 4, 5});
        Rainbow.setDatatype(DataType::UInt16);
        Rainbow.setBackend("cpu");
        REQUIRE(MakeRainbow<std::uint16_t>(Rainbow));
        Tensor clone(Rainbow);

        // Same metadata: data type, rank and dims.
        REQUIRE(clone.dataType() == Rainbow.dataType());
        REQUIRE(clone.nbDims() == Rainbow.nbDims());
        for (std::size_t i = 0; i < clone.nbDims(); ++i)
        {
            REQUIRE(clone.dims()[i] == Rainbow.dims()[i]);
        }

        // Same dimensions as seen by the implementation. The loop is bounded by
        // the impl vector's own size (checked equal just above), not by nbDims(),
        // so indexing can never run past either vector.
        REQUIRE(
            clone.getImpl().getDimensions().size()
            == Rainbow.getImpl().getDimensions().size());
        for (std::size_t i = 0; i < clone.getImpl().getDimensions().size(); ++i)
        {
            REQUIRE(
                clone.getImpl().getDimensions()[i]
                == Rainbow.getImpl().getDimensions()[i]);
        }

        // Same values, element by element.
        for (Coord_t a = 0; a < clone.dims()[0]; ++a)
        {
            for (Coord_t b = 0; b < clone.dims()[1]; ++b)
            {
                for (Coord_t c = 0; c < clone.dims()[2]; ++c)
                {
                    REQUIRE(
                        clone.get<std::uint16_t>({a, b, c})
                        == Rainbow.get<std::uint16_t>({a, b, c}));
                }
            }
        }

        // A deep copy owns its storage. Writing through the clone must leave the
        // source untouched (a shallow, shared-storage copy would fail here).
        constexpr std::uint16_t marker = 1234;
        clone.get<std::uint16_t>({0, 0, 0}) = marker;
        REQUIRE(clone.get<std::uint16_t>({0, 0, 0}) == marker);
        REQUIRE(Rainbow.get<std::uint16_t>({0, 0, 0}) == 0);
    }
}
TEST_CASE("Tensor access")
{
    SECTION("coordinates manipulations")
    {
        // 1x2x3 int tensor. Only its shape matters for the round-trip checks below.
        Tensor y = Array3D<int, 1, 2, 3>{{{{1, 2, 3}, {4, 5, 6}}}};
        const auto &shape = y.dims();

        // Visit every coordinate in row-major order. The expected flat index is
        // then just a running counter, advanced after each innermost iteration.
        NbElts_t expectedIdx = 0;
        for (Coord_t i0 = 0; i0 < shape[0]; ++i0)
        {
            for (Coord_t i1 = 0; i1 < shape[1]; ++i1)
            {
                for (Coord_t i2 = 0; i2 < shape[2]; ++i2, ++expectedIdx)
                {
                    // coordinates -> flat index
                    REQUIRE(y.getIdx(std::vector<Coord_t>{i0, i1, i2}) == expectedIdx);

                    // flat index -> coordinates
                    const std::vector<Coord_t> expected{i0, i1, i2};
                    std::vector<Coord_t> actual(expected.size());
                    y.getCoord(expectedIdx, actual);
                    REQUIRE(actual == expected);
                }
            }
        }
    }
}
TEST_CASE("Tensor extract")
{
    SECTION("shallow extract")
    {
        // Source: 2x4x5 uint16 tensor on the cpu backend, element i holds i.
        // Checking the fill gives this section at least one assertion while the
        // extract check below is disabled.
        Tensor Rainbow;
        Rainbow.resize({2, 4, 5});
        Rainbow.setDatatype(DataType::UInt16);
        Rainbow.setBackend("cpu");
        REQUIRE(MakeRainbow<std::uint16_t>(Rainbow));

        // NOTE(review): the extract check is disabled, presumably pending a view
        // constructor taking (source, dims, offsets) — TODO confirm the API and
        // re-enable. The view is a 2x2x3 window of Rainbow at offset {0, 1, 1}, so
        // view coordinate {a, b, c} maps to Rainbow coordinate {a, b + 1, c + 1}.
        // The view itself must be indexed with {a, b, c}. Indexing it with
        // {a, b + 1, c + 1} would exceed its own 2x2x3 bounds.
        // Tensor view(Rainbow, {2, 2, 3}, {0, 1, 1});
        // for (Coord_t a = 0; a < view.dims()[0]; ++a)
        // {
        //     for (Coord_t b = 0; b < view.dims()[1]; ++b)
        //     {
        //         for (Coord_t c = 0; c < view.dims()[2]; ++c)
        //         {
        //             REQUIRE(
        //                 view.get<std::uint16_t>({a, b, c})
        //                 == Rainbow.get<std::uint16_t>({a, b + 1, c + 1}));
        //         }
        //     }
        // }
    }
}