Commit 10462478 authored by Maxence Naud

Update BatchNorm calls and add fuseBatchNorm test

parent 9676eac9
1 merge request: !27 Update BatchNorm and add fuseBatchNorm test
Pipeline #35613 passed
@@ -31,7 +31,7 @@ class test_recipies(unittest.TestCase):
         input_node = aidge_core.Producer(input_tensor, "X")
         conv = aidge_core.Conv2D(1, 1, [3, 3], name="Conv0")
-        bn = aidge_core.BatchNorm2D(name="Add0")
+        bn = aidge_core.BatchNorm2D(1, name="Add0")
         graph_view = aidge_core.sequential([conv, bn])
@@ -20,7 +20,7 @@
 using namespace Aidge;

 TEST_CASE("[cpu/operator] BatchNorm(forward)", "[BatchNorm][CPU]") {
-    std::shared_ptr<Node> myBatchNorm = BatchNorm<2>(0.00001F, 0.1F, "mybatchnorm");
+    std::shared_ptr<Node> myBatchNorm = BatchNorm<2>(3, 0.00001F, 0.1F, "mybatchnorm");
     auto op = std::static_pointer_cast<OperatorTensor>(myBatchNorm -> getOperator());
     std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array1D<float,3> {{0.9044, 0.3028, 0.0218}});
     std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float,3> {{0.1332, 0.7503, 0.0878}});
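Both hunks reflect the same interface change: the BatchNorm factory now takes the number of features (channels) as its first argument, ahead of what appear to be the epsilon and momentum values. A minimal sketch of the updated C++ call, with illustrative values (3 features, the usual defaults 1e-5 and 0.1, and an arbitrary name):

std::shared_ptr<Node> myBatchNorm = BatchNorm<2>(3, 1.0e-5F, 0.1F, "mybatchnorm"); // nbFeatures, epsilon, momentum, name

The new test file below exercises the fuseBatchNorm recipe against this updated operator.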
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <memory>
#include <cmath>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/data/Tensor.hpp"
namespace Aidge {

TEST_CASE("[core/recipies] FuseBatchNorm", "[recipies][FuseBatchNorm]") {
// Build a small Producer -> Conv -> BatchNorm graph with fixed parameters.
auto myProd = Producer({2, 3, 3, 3}, "dataProvider");
auto myConv = Conv(3, 3, {1, 1}, "conv1");
auto myBN = BatchNorm<2>(3, 1.0e-5F, 0.1F, "batchnorm1"); // 3 features, matching the 3 Conv output channels
auto myProdOp = std::static_pointer_cast<Producer_Op>(myProd->getOperator());
auto myConvOp = std::static_pointer_cast<Conv_Op<2>>(myConv->getOperator());
auto myBNOp = std::static_pointer_cast<BatchNorm_Op<2>>(myBN->getOperator());
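// Populate the Producer output (the graph input, NCHW) and the Conv / BatchNorm parameters with fixed values.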
myProdOp->setOutput(0, std::make_shared<Tensor>(Array4D<float,2,3,3,3> { //NCHW
{
{
{{8.28257084e-01, 7.99335480e-01, 7.36702740e-01},
{2.36729562e-01, 8.61912668e-01, 9.93067741e-01},
{1.63514376e-01, 8.95773172e-02, 2.96533108e-01}},
{{2.20776618e-01, 5.89067876e-01, 2.03930080e-01},
{1.31294072e-01, 7.10182846e-01, 1.08420849e-04},
{7.21750259e-01, 4.38212037e-01, 5.08823872e-01}},
{{4.30953979e-01, 1.51903450e-01, 3.76343548e-01},
{8.07861805e-01, 7.79679358e-01, 5.01209974e-01},
{9.31280375e-01, 9.94207084e-01, 1.74868107e-03}}
},
{
{{6.22058094e-01, 2.32256651e-02, 6.18222237e-01},
{9.58304763e-01, 2.11395025e-02, 4.95614648e-01},
{2.50825584e-01, 4.50860739e-01, 3.80362332e-01}},
{{9.91703272e-02, 5.06073236e-01, 4.88969564e-01},
{1.12059772e-01, 7.64178872e-01, 7.60362148e-01},
{2.84135342e-02, 4.29610193e-01, 1.27862811e-01}},
{{9.57209170e-01, 8.22797656e-01, 1.91352129e-01},
{9.52722490e-01, 6.35501027e-01, 5.67592978e-02},
{2.00799644e-01, 4.00822222e-01, 9.14380193e-01}}
}
}
}));
myConvOp -> setInput(1, std::make_shared<Tensor>(Array4D<float,3,3,1,1> { // Conv weights: [outCh, inCh, kH, kW], 1x1 kernels
{
{
{{8.28257084e-01}},
{{7.99335480e-01}},
{{7.36702740e-01}}
},
{
{{2.36729562e-01}},
{{8.61912668e-01}},
{{9.93067741e-01}}
},
{
{{1.63514376e-01}},
{{8.95773172e-02}},
{{2.96533108e-01}}
}
}
}));
myConvOp -> setInput(2, std::make_shared<Tensor>(Array1D<float,3> {{0.4470, 0.3064, 0.7061}})); // Conv bias
myBNOp -> setInput(1, std::make_shared<Tensor>(Array1D<float,3> {{0.9044, 0.3028, 0.0218}})); // BatchNorm scale (gamma)
myBNOp -> setInput(2, std::make_shared<Tensor>(Array1D<float,3> {{0.1332, 0.7503, 0.0878}})); // BatchNorm shift (beta)
myBNOp -> setInput(3, std::make_shared<Tensor>(Array1D<float,3> {{0.9931, 0.8421, 0.9936}})); // BatchNorm running mean
myBNOp -> setInput(4, std::make_shared<Tensor>(Array1D<float,3> {{0.4470, 0.3064, 0.7061}})); // BatchNorm running variance
auto g1 = Sequential({
myConv,
myBN
});
g1 -> setName("fuseBNGraph");
myProd -> addChild(myConv); // set graph input
myProdOp -> setDataType(DataType::Float32);
myProdOp -> setBackend("cpu");
g1 -> compile("cpu", DataType::Float32);
auto s = SequentialScheduler(g1);
s.forward();
// Reference output: the original Conv -> BatchNorm sequence.
std::shared_ptr<Tensor> res1 = std::make_shared<Tensor>(*(myBNOp -> getOutput(0)));
// Fold the BatchNorm parameters into the Conv weights and bias, then run the graph again.
fuseBatchNorm(g1);
s.resetScheduling();
s.forward();
// After fusion, the Conv output alone should match the reference.
std::shared_ptr<Tensor> res2 = std::make_shared<Tensor>(*(myConvOp -> getOutput(0)));
REQUIRE(g1 -> outputNodes().size() == 1);
REQUIRE(g1 -> inputNodes().size() == 1);
// Element-wise comparison of the two outputs with a 1e-6 absolute tolerance.
bool eq = true;
for (std::size_t i = 0; i < res1->size(); ++i) {
eq &= std::abs(res1->get<float>(i) - res2->get<float>(i)) < 1.0e-06;
}
REQUIRE(eq);
}
} // namespace Aidge
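For reference, the identity the new test checks: in inference mode BatchNorm is an affine per-channel transform, so it can be folded into the preceding convolution's weights and bias, and the fused Conv alone reproduces BatchNorm(Conv(x)). The standalone sketch below shows that per-channel fold; it is not Aidge's fuseBatchNorm implementation, and every name in it is illustrative.

#include <cmath>
#include <cstddef>
#include <vector>

// Fused convolution parameters: weights flattened per output channel, one bias per channel.
struct FusedConvParams {
    std::vector<float> weights;
    std::vector<float> bias;
};

// Fold BatchNorm (scale gamma, shift beta, running mean mu, running variance var, epsilon eps)
// into the preceding convolution:
//   W'[c] = W[c] * gamma[c] / sqrt(var[c] + eps)
//   b'[c] = (b[c] - mu[c]) * gamma[c] / sqrt(var[c] + eps) + beta[c]
FusedConvParams foldBatchNormIntoConv(const std::vector<float>& weights,   // nbChannels * weightsPerChannel values
                                      const std::vector<float>& bias,      // nbChannels values
                                      const std::vector<float>& gamma,
                                      const std::vector<float>& beta,
                                      const std::vector<float>& mean,
                                      const std::vector<float>& variance,
                                      float eps,
                                      std::size_t weightsPerChannel) {
    FusedConvParams fused{weights, bias};
    for (std::size_t c = 0; c < bias.size(); ++c) {
        const float scale = gamma[c] / std::sqrt(variance[c] + eps);
        for (std::size_t k = 0; k < weightsPerChannel; ++k) {
            fused.weights[c * weightsPerChannel + k] *= scale;
        }
        fused.bias[c] = (bias[c] - mean[c]) * scale + beta[c];
    }
    return fused;
}

Under these assumptions the test's comparison loop is exactly a check of this identity on concrete values: res1 comes from the unfused Conv -> BatchNorm graph, res2 from the Conv whose parameters were rewritten by the recipe.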