From 264e9b04525c5ba2fb3e64f50670460b7237d20c Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Tue, 25 Feb 2025 12:43:05 +0100 Subject: [PATCH] Fix And Op to Equal --- src/metrics/Accuracy.cpp | 14 +++++++------- unit_tests/metrics/Test_Accuracy.cpp | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/metrics/Accuracy.cpp b/src/metrics/Accuracy.cpp index 67b9ab2..318d838 100644 --- a/src/metrics/Accuracy.cpp +++ b/src/metrics/Accuracy.cpp @@ -19,7 +19,7 @@ #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/ArgMax.hpp" #include "aidge/operator/ReduceSum.hpp" -#include "aidge/operator/And.hpp" +#include "aidge/operator/Equal.hpp" #include "aidge/recipes/GraphViewHelper.hpp" #include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" @@ -34,7 +34,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction, The graph used is the following: pred->ArgMax - ->And->ReduceSum + ->Equal->ReduceSum label->ArgMax */ @@ -60,7 +60,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction, const std::shared_ptr<Node> argmax_perd_node = ArgMax(axis); const std::shared_ptr<Node> argmax_target_node = ArgMax(axis); - const std::shared_ptr<Node> and_node = And(); + const std::shared_ptr<Node> equal_node = Equal(); const std::shared_ptr<Node> rs_node = ReduceSum(); const std::shared_ptr<Node> pred_node = Producer(prediction, "pred"); @@ -68,14 +68,14 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction, const std::shared_ptr<Node> label_node = Producer(target, "label"); label_node->addChild(argmax_target_node); - argmax_perd_node->addChild(and_node,0,0); - argmax_target_node->addChild(and_node,0,1); + argmax_perd_node->addChild(equal_node,0,0); + argmax_target_node->addChild(equal_node,0,1); - // and_node->addChild(rs_node,0,0); + // equal_node->addChild(rs_node,0,0); // Create the graph std::shared_ptr<GraphView> gv_local = - Sequential({ and_node, rs_node}); + Sequential({ equal_node, rs_node}); gv_local->add({pred_node,argmax_perd_node, label_node,argmax_target_node}); gv_local->compile(prediction->getImpl()->backend(), prediction->dataType()); diff --git a/unit_tests/metrics/Test_Accuracy.cpp b/unit_tests/metrics/Test_Accuracy.cpp index f598255..d58076d 100644 --- a/unit_tests/metrics/Test_Accuracy.cpp +++ b/unit_tests/metrics/Test_Accuracy.cpp @@ -20,7 +20,7 @@ #include <vector> #include "aidge/backend/cpu/operator/ArgMaxImpl.hpp" -#include "aidge/backend/cpu/operator/AndImpl.hpp" +#include "aidge/backend/cpu/operator/EqualImpl.hpp" #include "aidge/backend/cpu/operator/ReduceSumImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/learning/metrics/Accuracy.hpp" @@ -28,7 +28,7 @@ #if USE_AIDGE_BACKEND_CUDA #include "aidge/backend/cuda/operator/ArgMaxImpl.hpp" -#include "aidge/backend/cuda/operator/AndImpl.hpp" +#include "aidge/backend/cuda/operator/EqualImpl.hpp" #include "aidge/backend/cuda/operator/ReduceSumImpl.hpp" #endif -- GitLab