/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <memory>

#include <catch2/catch_test_macros.hpp>

#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
#include "aidge/data/DataType.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Sqrt.hpp"
#include "aidge/utils/ArrayHelpers.hpp"
#include "aidge/utils/TensorUtils.hpp"

using namespace Aidge;

TEST_CASE("[cpu/operator] Sqrt(forward)", "[Sqrt][CPU]") {
    // Tolerances forwarded to approxEq: relative first, then absolute.
    constexpr float kRelTol = 1e-5f;
    constexpr float kAbsTol = 1e-8f;

    // Builds a Sqrt operator fed by `in`, configures it for Float32 on the
    // "cpu" backend and runs one forward pass. The operator itself is handed
    // back so the caller reads the output while its owner is still alive.
    const auto forwardSqrt = [](const std::shared_ptr<Tensor>& in) {
        auto sqrtOp = std::make_shared<Sqrt_Op>();
        sqrtOp->associateInput(0, in);
        sqrtOp->setDataType(DataType::Float32);
        sqrtOp->setBackend("cpu");
        sqrtOp->forward();
        return sqrtOp;
    };

    SECTION("2D Tensor") {
        // Includes 0.0 to check that the zero input maps to zero.
        const auto input = std::make_shared<Tensor>(Array2D<float,2,2> {
            {
                {16.00000000,  0.62226844},
                { 0.00000000,  1.84539008}
            }
        });
        const Tensor expected = Array2D<float,2,2> {
            {
                {4.00000000, 0.78883994},
                {0.00000000, 1.35845140}
            }
        };

        const auto sqrtOp = forwardSqrt(input);
        REQUIRE(approxEq<float>(*(sqrtOp->getOutput(0)), expected, kRelTol, kAbsTol));
    }

    SECTION("4D Tensor") {
        const auto input = std::make_shared<Tensor>(Array4D<float,2,3,3,3> {
            {
                {
                    {{0.06218481, 0.46850157, 0.60914326},
                     {0.57470602, 0.09943211, 0.59992820},
                     {0.99623793, 0.54931718, 0.89343822}},
                    {{0.75176072, 0.38237786, 0.84824580},
                     {0.10619396, 0.11959118, 0.93499404},
                     {0.65563291, 0.02913034, 0.17093092}},
                    {{0.36303985, 0.92073035, 0.79146117},
                     {0.88962847, 0.94561219, 0.92033130},
                     {0.52903181, 0.13397896, 0.76086712}}
                },
                {
                    {{0.31242222, 0.80526417, 0.48411584},
                     {0.84375203, 0.65408552, 0.55028963},
                     {0.77546734, 0.06203610, 0.83163154}},
                    {{0.46342927, 0.53631741, 0.39145601},
                     {0.14204198, 0.84214240, 0.94185621},
                     {0.05068624, 0.99889028, 0.38464361}},
                    {{0.37591159, 0.51769549, 0.30288595},
                     {0.96883464, 0.35154045, 0.55648762},
                     {0.13022375, 0.73467660, 0.02705121}}
                }
            }
        });

        // Element-wise square roots of `input`, listed in the same layout.
        const Tensor expected = Array4D<float,2,3,3,3> {
            {
                {
                    {{0.24936883, 0.6844717,  0.7804763},
                     {0.75809366, 0.31532857, 0.7745503},
                     {0.9981172,  0.7411593,  0.9452186}},
                    {{0.86704135, 0.6183671,  0.9210026},
                     {0.32587415, 0.34581956, 0.9669509},
                     {0.80971164, 0.17067613, 0.41343793}},
                    {{0.60252786, 0.9595469,  0.88964105},
                     {0.9432012,  0.97242594, 0.95933896},
                     {0.7273457,  0.36603138, 0.87227696}}
                },
                {
                    {{0.55894744, 0.89736515, 0.69578433},
                     {0.91855973, 0.8087555,  0.7418151},
                     {0.88060623, 0.24907047, 0.91193837}},
                    {{0.6807564,  0.73233694, 0.6256645},
                     {0.37688458, 0.9176832,  0.9704928},
                     {0.22513604, 0.99944496, 0.62019646}},
                    {{0.6131163,  0.7195106,  0.5503507},
                     {0.984294,   0.59290844, 0.745981},
                     {0.3608653,  0.8571328,  0.16447252}}
                }
            }
        };

        const auto sqrtOp = forwardSqrt(input);
        REQUIRE(approxEq<float>(*(sqrtOp->getOutput(0)), expected, kRelTol, kAbsTol));
    }
}