From 179480936392d229791336bd1fc58506a06aaf29 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Mon, 3 Jun 2024 08:58:17 +0000
Subject: [PATCH] minor changes

---
 src/operator/SoftmaxImpl.cpp              | 3 +--
 unit_tests/operator/Test_MetaOperator.cpp | 6 +++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/operator/SoftmaxImpl.cpp b/src/operator/SoftmaxImpl.cpp
index 24026761..ed3d625d 100644
--- a/src/operator/SoftmaxImpl.cpp
+++ b/src/operator/SoftmaxImpl.cpp
@@ -37,10 +37,9 @@ void Aidge::SoftmaxImpl_cpu::forward() {
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
     Softmax_Op::Attrs attr = dynamic_cast<const Softmax_Op&>(mOp).getStaticAttributes();
-    const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
 
     // Call kernel
-    kernelFunc(axisIdx,
+    kernelFunc(std::get<0>(attr), // axisIdx
                std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
                std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
                std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp
index aa9a3909..78058eca 100644
--- a/unit_tests/operator/Test_MetaOperator.cpp
+++ b/unit_tests/operator/Test_MetaOperator.cpp
@@ -194,10 +194,10 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
     SECTION("LSTM(forward)") {
         auto pop = Pop();
         auto myLSTM = LSTM(32, 64, 0, true, "ltsm");
-        auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
+        auto op = std::dynamic_pointer_cast<MetaOperator_Op>(myLSTM->getOperator());
 
-        auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
-        microGraph->save("lstm", false, false);
+        auto microGraph = op->getMicroGraph();
+        microGraph->save("lstm", false, true);
 
         REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
         REQUIRE(myLSTM->nbData() == 1);
-- 
GitLab