/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <catch2/catch_test_macros.hpp>

#include <cmath>
#include <memory>

#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"

namespace Aidge {

TEST_CASE("[core/recipes] FuseBatchNorm", "[recipes][FuseBatchNorm]") {
    auto myProd = Producer({2, 3, 3, 3}, "dataProvider");
    auto myConv = Conv(3, 3, {1, 1}, "conv1");
    // nbFeatures matches the 3 output channels of the preceding Conv
    auto myBN   = BatchNorm<2>(3, 1.0e-5F, 0.1F, "batchnorm1");

    auto myProdOp = std::static_pointer_cast<Producer_Op>(myProd->getOperator());
    auto myConvOp = std::static_pointer_cast<Conv_Op<2>>(myConv->getOperator());
    auto myBNOp   = std::static_pointer_cast<BatchNorm_Op<2>>(myBN->getOperator());

    // Input data (NCHW)
    myProdOp->setOutput(0, std::make_shared<Tensor>(Array4D<float, 2, 3, 3, 3>{
        {
            {
                {{8.28257084e-01, 7.99335480e-01, 7.36702740e-01},
                 {2.36729562e-01, 8.61912668e-01, 9.93067741e-01},
                 {1.63514376e-01, 8.95773172e-02, 2.96533108e-01}},
                {{2.20776618e-01, 5.89067876e-01, 2.03930080e-01},
                 {1.31294072e-01, 7.10182846e-01, 1.08420849e-04},
                 {7.21750259e-01, 4.38212037e-01, 5.08823872e-01}},
                {{4.30953979e-01, 1.51903450e-01, 3.76343548e-01},
                 {8.07861805e-01, 7.79679358e-01, 5.01209974e-01},
                 {9.31280375e-01, 9.94207084e-01, 1.74868107e-03}}
            },
            {
                {{6.22058094e-01, 2.32256651e-02, 6.18222237e-01},
                 {9.58304763e-01, 2.11395025e-02, 4.95614648e-01},
                 {2.50825584e-01, 4.50860739e-01, 3.80362332e-01}},
                {{9.91703272e-02, 5.06073236e-01, 4.88969564e-01},
                 {1.12059772e-01, 7.64178872e-01, 7.60362148e-01},
                 {2.84135342e-02, 4.29610193e-01, 1.27862811e-01}},
                {{9.57209170e-01, 8.22797656e-01, 1.91352129e-01},
                 {9.52722490e-01, 6.35501027e-01, 5.67592978e-02},
                 {2.00799644e-01, 4.00822222e-01, 9.14380193e-01}}
            }
        }
    }));

    // Conv weights and bias
    myConvOp->setInput(1, std::make_shared<Tensor>(Array4D<float, 3, 3, 1, 1>{
        {
            {{{8.28257084e-01}}, {{7.99335480e-01}}, {{7.36702740e-01}}},
            {{{2.36729562e-01}}, {{8.61912668e-01}}, {{9.93067741e-01}}},
            {{{1.63514376e-01}}, {{8.95773172e-02}}, {{2.96533108e-01}}}
        }
    }));
    myConvOp->setInput(2, std::make_shared<Tensor>(Array1D<float, 3>{{0.4470, 0.3064, 0.7061}}));

    // BatchNorm scale, shift, running mean and running variance
    myBNOp->setInput(1, std::make_shared<Tensor>(Array1D<float, 3>{{0.9044, 0.3028, 0.0218}}));
    myBNOp->setInput(2, std::make_shared<Tensor>(Array1D<float, 3>{{0.1332, 0.7503, 0.0878}}));
    myBNOp->setInput(3, std::make_shared<Tensor>(Array1D<float, 3>{{0.9931, 0.8421, 0.9936}}));
    myBNOp->setInput(4, std::make_shared<Tensor>(Array1D<float, 3>{{0.4470, 0.3064, 0.7061}}));

    auto g1 = Sequential({myProd, myConv, myBN});
    g1->setName("fuseBNGraph");
    g1->compile("cpu", DataType::Float32);

    // Reference: run the original Conv + BatchNorm graph and keep a copy of its output
    auto s = SequentialScheduler(g1);
    s.forward();
    std::shared_ptr<Tensor> res1 = std::make_shared<Tensor>(*(myBNOp->getOutput(0)));
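    // fuseBatchNorm removes the BatchNorm node and folds its affine transform
    // into the Conv parameters. Conceptually (standard BN folding, written here
    // only as a sketch of what the recipe is expected to achieve per channel c):
    //     w'_c = w_c * scale_c / sqrt(var_c + eps)
    //     b'_c = (b_c - mean_c) * scale_c / sqrt(var_c + eps) + shift_c
    // so the fused Conv alone should reproduce res1 up to floating-point error.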
    // Apply the recipe, then re-run the modified graph
    fuseBatchNorm(g1);
    s.resetScheduling();
    s.forward();
    std::shared_ptr<Tensor> res2 = std::make_shared<Tensor>(*(myConvOp->getOutput(0)));

    // The fused graph still exposes a single output node and has no free inputs
    REQUIRE(g1->outputNodes().size() == 1);
    REQUIRE(g1->inputNodes().size() == 0);

    // Element-wise comparison of the two runs within a 1e-6 absolute tolerance
    bool eq = true;
    for (std::size_t i = 0; i < res1->size(); ++i) {
        eq &= std::abs(res1->get<float>(i) - res2->get<float>(i)) < 1.0e-06;
    }
    REQUIRE(eq);
}

}  // namespace Aidge