From 5948832b02a2adb75a9f0f0bfd037d01823a8a66 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Mon, 18 Dec 2023 09:50:08 +0000 Subject: [PATCH] fix batchnorm and softmax tests --- unit_tests/operator/Test_BatchNormImpl.cpp | 2 +- unit_tests/operator/Test_SoftmaxImpl.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/unit_tests/operator/Test_BatchNormImpl.cpp b/unit_tests/operator/Test_BatchNormImpl.cpp index e6b7c3c6..a1a749d8 100644 --- a/unit_tests/operator/Test_BatchNormImpl.cpp +++ b/unit_tests/operator/Test_BatchNormImpl.cpp @@ -20,7 +20,7 @@ using namespace Aidge; TEST_CASE("[cpu/operator] BatchNorm(forward)", "[BatchNorm][CPU]") { - std::shared_ptr<Node> myBatchNorm = BatchNorm<2>(0.00001F, 0.1F, "mybatchnorm"); + std::shared_ptr<Node> myBatchNorm = BatchNorm<2>(3, 0.00001F, 0.1F, "mybatchnorm"); auto op = std::static_pointer_cast<OperatorTensor>(myBatchNorm -> getOperator()); std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array1D<float,3> {{0.9044, 0.3028, 0.0218}}); std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float,3> {{0.1332, 0.7503, 0.0878}}); diff --git a/unit_tests/operator/Test_SoftmaxImpl.cpp b/unit_tests/operator/Test_SoftmaxImpl.cpp index 3d3c9fe4..360b7440 100644 --- a/unit_tests/operator/Test_SoftmaxImpl.cpp +++ b/unit_tests/operator/Test_SoftmaxImpl.cpp @@ -39,7 +39,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)", "[Softmax][CPU]") { } }); - std::shared_ptr<Node> mySoftmax = Softmax(); + std::shared_ptr<Node> mySoftmax = Softmax(1); auto op = std::static_pointer_cast<OperatorTensor>(mySoftmax -> getOperator()); mySoftmax->getOperator()->associateInput(0,input); mySoftmax->getOperator()->setDataType(DataType::Float32); @@ -108,7 +108,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)", "[Softmax][CPU]") { } }); - std::shared_ptr<Node> mySoftmax = Softmax(); + std::shared_ptr<Node> mySoftmax = Softmax(1); auto op = std::static_pointer_cast<OperatorTensor>(mySoftmax -> getOperator()); mySoftmax->getOperator()->associateInput(0,input); mySoftmax->getOperator()->setDataType(DataType::Float32); -- GitLab