/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <catch2/catch_test_macros.hpp>

#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Add.hpp"

#include "aidge/backend/cpu.hpp"

using namespace Aidge;

// Forward-pass tests for the Add operator on the "cpu" backend with Int32 data.
TEST_CASE("[cpu/operator] Add(forward)", "[Add][CPU]") {
    // Shared 3x3x3x2 input used by the "Two inputs" section.
    std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,3,3,3,2> {
        {                                       //
            {                                   //
                {{20, 47},{21, 48},{22, 49}},   //
                {{23, 50},{24, 51},{25, 52}},   //
                {{26, 53},{27, 54},{28, 55}}    //
            },                                  //
            {                                   //
                {{29, 56},{30, 57},{31, 58}},   //
                {{32, 59},{33, 60},{34, 61}},   //
                {{35, 62},{36, 63},{37, 64}}    //
            },                                  //
            {                                   //
                {{38, 65},{39, 66},{40, 67}},   //
                {{41, 68},{42, 69},{43, 70}},   //
                {{44, 71},{45, 72},{46, 73}}    //
            }                                   //
        }                                       //
    });                                         //

    SECTION("Two inputs") {
        // The same tensor is fed to both inputs, so the expected result is
        // every element of input1 doubled (e.g. 20 -> 40, 47 -> 94).
        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> {
            {
                {
                    {{40,  94},{42,  96},{44,  98}},
                    {{46, 100},{48, 102},{50, 104}},
                    {{52, 106},{54, 108},{56, 110}}
                },
                {
                    {{58, 112},{60, 114},{62, 116}},
                    {{64, 118},{66, 120},{68, 122}},
                    {{70, 124},{72, 126},{74, 128}}
                },
                {
                    {{76, 130},{78, 132},{80, 134}},
                    {{82, 136},{84, 138},{86, 140}},
                    {{88, 142},{90, 144},{92, 146}}
                }
            }
        });

        std::shared_ptr<Node> myAdd = Add();
        auto op = std::static_pointer_cast<OperatorTensor>(myAdd->getOperator());
        op->associateInput(0, input1);
        op->associateInput(1, input1);
        op->setBackend("cpu");
        op->setDataType(DataType::Int32);
        myAdd->forward();

        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
    }

    SECTION("Broadcasting") {
        // Three operands with different shapes, summed by two chained Add nodes:
        //   input_0: 3x1x3x2, input_1: 1x3x3x2, input_2: 2.
        // The expected 3x3x3x2 result is
        //   out[i][j][k][l] = input_0[i][0][k][l] + input_1[0][j][k][l] + input_2[l].
        std::shared_ptr<Tensor> input_0 = std::make_shared<Tensor>(Array4D<int,3,1,3,2> {
        {                                       //
            {                                   //
                {{0, 1},{2, 3},{4, 5}}          //
            },                                  //
            {                                   //
                {{6, 7},{8, 9},{10, 11}}        //
            },                                  //
            {                                   //
                {{12, 13},{14, 15},{16, 17}}    //
            }                                   //
        }                                       //
        });                                     //
        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
        {                                       //
            {                                   //
                {{20, 21},{22, 23},{24, 25}},   //
                {{26, 27},{28, 29},{30, 31}},   //
                {{32, 33},{34, 35},{36, 37}}    //
            }                                   //
        }                                       //
        });                                     //

        std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array1D<int,2> {{100,200}});
        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> {
            {                                               //
                {                                           //
                    {{ 120, 222},{ 124, 226},{ 128, 230}},  //
                    {{ 126, 228},{ 130, 232},{ 134, 236}},  //
                    {{ 132, 234},{ 136, 238},{ 140, 242}}   //
                },                                          //
                {                                           //
                    {{ 126, 228},{ 130, 232},{ 134, 236}},  //
                    {{ 132, 234},{ 136, 238},{ 140, 242}},  //
                    {{ 138, 240},{ 142, 244},{ 146, 248}}   //
                },                                          //
                {                                           //
                    {{ 132, 234},{ 136, 238},{140, 242}},   //
                    {{ 138, 240},{ 142, 244},{146, 248}},   //
                    {{ 144, 246},{ 148, 250},{152, 254}}    //
                }                                           //
            }                                               //
        });                                                 //

        // myAdd_0 computes input_0 + input_1; myAdd_1 then adds input_2 to
        // that intermediate result (its output tensor is wired as input 1).
        std::shared_ptr<Node> myAdd_0 = Add();
        std::shared_ptr<Node> myAdd_1 = Add();
        auto op_0 = std::static_pointer_cast<OperatorTensor>(myAdd_0->getOperator());
        auto op_1 = std::static_pointer_cast<OperatorTensor>(myAdd_1->getOperator());
        op_0->associateInput(0, input_0);
        op_0->associateInput(1, input_1);

        op_1->associateInput(0, input_2);
        op_1->associateInput(1, op_0->getOutput(0));
        op_0->setDataType(DataType::Int32);
        op_1->setDataType(DataType::Int32);
        op_0->setBackend("cpu");
        op_1->setBackend("cpu");

        // Run in dependency order: myAdd_1 reads myAdd_0's output.
        myAdd_0->forward();
        myAdd_1->forward();

        // Leftover debug print() calls removed: they dumped both full tensors
        // to stdout on every run without contributing to the assertion.
        REQUIRE(*op_1->getOutput(0) == *expectedOutput);
    }
}