From 264e9b04525c5ba2fb3e64f50670460b7237d20c Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 25 Feb 2025 12:43:05 +0100
Subject: [PATCH] Fix And Op to Equal

---
 src/metrics/Accuracy.cpp             | 14 +++++++-------
 unit_tests/metrics/Test_Accuracy.cpp |  4 ++--
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/metrics/Accuracy.cpp b/src/metrics/Accuracy.cpp
index 67b9ab2..318d838 100644
--- a/src/metrics/Accuracy.cpp
+++ b/src/metrics/Accuracy.cpp
@@ -19,7 +19,7 @@
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/ArgMax.hpp"
 #include "aidge/operator/ReduceSum.hpp"
-#include "aidge/operator/And.hpp"
+#include "aidge/operator/Equal.hpp"
 #include "aidge/recipes/GraphViewHelper.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
@@ -34,7 +34,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
     The graph used is the following:
 
     pred->ArgMax
-                  ->And->ReduceSum
+                  ->Equal->ReduceSum
     label->ArgMax
     */
 
@@ -60,7 +60,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
     const std::shared_ptr<Node> argmax_perd_node = ArgMax(axis);
     const std::shared_ptr<Node> argmax_target_node = ArgMax(axis);
 
-    const std::shared_ptr<Node> and_node = And();
+    const std::shared_ptr<Node> equal_node = Equal();
     const std::shared_ptr<Node> rs_node = ReduceSum();
 
     const std::shared_ptr<Node> pred_node = Producer(prediction, "pred");
@@ -68,14 +68,14 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
     const std::shared_ptr<Node> label_node = Producer(target, "label");
     label_node->addChild(argmax_target_node);
 
-    argmax_perd_node->addChild(and_node,0,0);
-    argmax_target_node->addChild(and_node,0,1);
+    argmax_perd_node->addChild(equal_node,0,0);
+    argmax_target_node->addChild(equal_node,0,1);
 
-    // and_node->addChild(rs_node,0,0);
+    // equal_node->addChild(rs_node,0,0);
 
     // Create the graph
     std::shared_ptr<GraphView> gv_local =
-        Sequential({ and_node, rs_node});
+        Sequential({ equal_node, rs_node});
 
     gv_local->add({pred_node,argmax_perd_node, label_node,argmax_target_node});
     gv_local->compile(prediction->getImpl()->backend(), prediction->dataType());
diff --git a/unit_tests/metrics/Test_Accuracy.cpp b/unit_tests/metrics/Test_Accuracy.cpp
index f598255..d58076d 100644
--- a/unit_tests/metrics/Test_Accuracy.cpp
+++ b/unit_tests/metrics/Test_Accuracy.cpp
@@ -20,7 +20,7 @@
 #include <vector>
 
 #include "aidge/backend/cpu/operator/ArgMaxImpl.hpp"
-#include "aidge/backend/cpu/operator/AndImpl.hpp"
+#include "aidge/backend/cpu/operator/EqualImpl.hpp"
 #include "aidge/backend/cpu/operator/ReduceSumImpl.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/learning/metrics/Accuracy.hpp"
@@ -28,7 +28,7 @@
 
 #if USE_AIDGE_BACKEND_CUDA
 #include "aidge/backend/cuda/operator/ArgMaxImpl.hpp"
-#include "aidge/backend/cuda/operator/AndImpl.hpp"
+#include "aidge/backend/cuda/operator/EqualImpl.hpp"
 #include "aidge/backend/cuda/operator/ReduceSumImpl.hpp"
 #endif
 
-- 
GitLab