/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <catch2/catch_test_macros.hpp>

#include "aidge/data/Tensor.hpp"

#include "aidge/backend/opencv/data/TensorImpl.hpp"


using namespace Aidge;

// Exercises creation of Aidge Tensors on the "opencv" backend, both when the
// backend is selected before the data is assigned and when it is selected
// afterwards.
//
// Note on Catch2 semantics: each leaf SECTION runs in its own pass through
// the TEST_CASE, re-executing the enclosing setup code. Tensors built in the
// parent SECTION are therefore fresh for every leaf. A mutation made inside
// one leaf (e.g. the `set` in "get function") is not visible in the others.
TEST_CASE("Tensor creation opencv", "[Tensor][OpenCV]") {
    SECTION("from const array") {
        // Backend and data type are set first, then the data is assigned.
        Tensor x;
        x.setDataType(Aidge::DataType::Int32);
        x.setBackend("opencv");
        x = Array3D<int,2,2,2>{
        {
            {
                {1, 2},
                {3, 4}
            },
            {
                {5, 6},
                {7, 8}
            }
        }};

        // Independent tensor with identical content, used for equality checks.
        Tensor xCopy;
        xCopy.setDataType(Aidge::DataType::Int32);
        xCopy.setBackend("opencv");
        xCopy = Array3D<int,2,2,2>{
        {
            {
                {1, 2},
                {3, 4}
            },
            {
                {5, 6},
                {7, 8}
            }
        }};

        // Same values but a float element type; must not compare equal to `x`.
        Tensor xFloat;
        xFloat.setBackend("opencv");
        xFloat = Array3D<float,2,2,2>{
        {
            {
                {1., 2.},
                {3., 4.}
            },
            {
                {5., 6.},
                {7., 8.}
            }
        }};

        SECTION("Tensor features") {
            REQUIRE(x.nbDims() == 3);
            REQUIRE(x.dims()[0] == 2);
            REQUIRE(x.dims()[1] == 2);
            REQUIRE(x.dims()[2] == 2);
            REQUIRE(x.size() == 8);
        }

        SECTION("OpenCV tensor features") {
            // Checked downcast: if the impl is not an int OpenCV impl, this
            // fails the test cleanly rather than dereferencing a mis-cast
            // pointer (undefined behavior with the previous static_cast).
            auto* cvImpl = dynamic_cast<TensorImpl_opencv<int>*>(x.getImpl().get());
            REQUIRE(cvImpl != nullptr);

            // The 2x2x2 tensor is expected to be stored as a 2-D 2x2 cv::Mat
            // with 2 channels (4 elements of 2 channels each).
            const auto& mat = cvImpl->data();
            REQUIRE(mat.rows == 2);
            REQUIRE(mat.cols == 2);
            REQUIRE(mat.dims == 2);
            REQUIRE(mat.total() == 4);
            REQUIRE(mat.channels() == 2);
        }

        SECTION("Access to array") {
            // The raw buffer is expected to hold the values in the order given
            // in the initializer: the first element is 1 and the last is 8.
            REQUIRE(static_cast<int*>(x.getImpl()->rawPtr())[0] == 1);
            REQUIRE(static_cast<int*>(x.getImpl()->rawPtr())[7] == 8);
        }

        SECTION("get function") {
            REQUIRE(x.get<int>({0,0,0}) == 1);
            REQUIRE(x.get<int>({0,0,1}) == 2);
            REQUIRE(x.get<int>({0,1,1}) == 4);
            REQUIRE(x.get<int>({1,1,0}) == 7);
            // Write-then-read round trip through the OpenCV impl.
            x.set<int>({1, 1, 1}, 36);
            REQUIRE(x.get<int>({1,1,1}) == 36);
        }

        SECTION("Pretty printing for debug") {
            REQUIRE_NOTHROW(x.print());
        }

        SECTION("Tensor (in)equality") {
            REQUIRE(x == xCopy);
            REQUIRE_FALSE(x == xFloat);
        }
    }

    SECTION("from const array before backend") {
        // Data is assigned first; switching to "opencv" afterwards must
        // preserve the shape and the values.
        Tensor x = Array3D<int,2,2,2>{
        {
            {
                {1, 2},
                {3, 4}
            },
            {
                {5, 6},
                {7, 8}
            }
        }};
        x.setBackend("opencv");

        REQUIRE(x.nbDims() == 3);
        REQUIRE(x.dims()[0] == 2);
        REQUIRE(x.dims()[1] == 2);
        REQUIRE(x.dims()[2] == 2);
        REQUIRE(x.size() == 8);

        REQUIRE(x.get<int>({0,0,0}) == 1);
        REQUIRE(x.get<int>({0,0,1}) == 2);
        REQUIRE(x.get<int>({0,1,1}) == 4);
        REQUIRE(x.get<int>({1,1,1}) == 8);
    }

}