/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <catch2/catch_test_macros.hpp>

#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Slice.hpp"

#include "aidge/backend/cpu.hpp"

using namespace Aidge;

TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
    SECTION("1D Tensor") {
        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array1D<int,10> {
            {0, 1, 2,-3, 4,-5,-6, 7, 8, 9}
        });
        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,4> {
            {0, 1, 2,-3}
        });

        std::shared_ptr<Node> mySlice = Slice({0}, {3}, {0});
        auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
        mySlice->getOperator()->associateInput(0,input0);
        mySlice->getOperator()->setDataType(DataType::Int32);
        mySlice->getOperator()->setBackend("cpu");
        op->computeOutputDims();
        mySlice->forward();

        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
        REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
        REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
    }

    SECTION("2D Tensor") {
        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array2D<int,2,10> {
            {
                { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
                {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
            }
        });
        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<int,2,3> {
            {
                {-5,-6, 7},
                {-5,-6, 7}
            }
        });

        std::shared_ptr<Node> mySlice = Slice({0,5}, {1,7}, {0,1});
        auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
        mySlice->getOperator()->associateInput(0,input0);
        mySlice->getOperator()->setDataType(DataType::Int32);
        mySlice->getOperator()->setBackend("cpu");
        op->computeOutputDims();
        mySlice->forward();
        // mySlice->getOperator()->output(0).print();
        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
        REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
        REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
    }

    SECTION("3D Tensor") {
        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array3D<int,2,2,10> {
            {
                {
                    { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
                    {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                },
                {
                    { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
                    {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                }
            }
        });
        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,1,1,3> {
            {
                {
                    { 4,-5,-6}
                }
            }
        });

        std::shared_ptr<Node> mySlice = Slice({0,1,4}, {0,1,6}, {0,1,2});
        auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
        mySlice->getOperator()->associateInput(0,input0);
        mySlice->getOperator()->setDataType(DataType::Int32);
        mySlice->getOperator()->setBackend("cpu");
        op->computeOutputDims();
        mySlice->forward();
        // mySlice->getOperator()->output(0).print();
        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
        REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
        REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
    }

    SECTION("4D Tensor") {
        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<int,2,2,2,10> {
            {
                {
                    {
                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                    },
                    {
                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                    }
                },
                {
                    {
                        { 0, 1, 2,-3, 6,-5,-6, 7, 8, 9},
                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                    },
                    {
                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
                        {-5, 4, 2,-3,11,-5,-6, 7,-1,10}
                    }
                }
            }
        });
        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,2,2,2,10> {
            {
                {
                    {
                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                    },
                    {
                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                    }
                },
                {
                    {
                        { 0, 1, 2,-3, 6,-5,-6, 7, 8, 9},
                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                    },
                    {
                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
                        {-5, 4, 2,-3,11,-5,-6, 7,-1,10}
                    }
                }
            }
        });

        std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,9}, {0,1,2,3});
        auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
        mySlice->getOperator()->associateInput(0,input0);
        mySlice->getOperator()->setDataType(DataType::Int32);
        mySlice->getOperator()->setBackend("cpu");
        op->computeOutputDims();
        mySlice->forward();
        // mySlice->getOperator()->output(0).print();
        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
        REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
        REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
    }
}