Skip to content
Snippets Groups Projects
Commit 7c755ad7 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'fix_accuracy' into 'dev'

Fix Accuracy metric

See merge request !36
parents aae638bb 264e9b04
No related branches found
No related tags found
2 merge requests!44Update 0.2.3 -> 0.3.0,!36Fix Accuracy metric
Pipeline #67294 passed
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/ArgMax.hpp" #include "aidge/operator/ArgMax.hpp"
#include "aidge/operator/ReduceSum.hpp" #include "aidge/operator/ReduceSum.hpp"
#include "aidge/operator/And.hpp" #include "aidge/operator/Equal.hpp"
#include "aidge/recipes/GraphViewHelper.hpp" #include "aidge/recipes/GraphViewHelper.hpp"
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/SequentialScheduler.hpp"
...@@ -34,7 +34,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction, ...@@ -34,7 +34,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
The graph used is the following: The graph used is the following:
pred->ArgMax pred->ArgMax
->And->ReduceSum ->Equal->ReduceSum
label->ArgMax label->ArgMax
*/ */
...@@ -60,7 +60,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction, ...@@ -60,7 +60,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Node> argmax_perd_node = ArgMax(axis); const std::shared_ptr<Node> argmax_perd_node = ArgMax(axis);
const std::shared_ptr<Node> argmax_target_node = ArgMax(axis); const std::shared_ptr<Node> argmax_target_node = ArgMax(axis);
const std::shared_ptr<Node> and_node = And(); const std::shared_ptr<Node> equal_node = Equal();
const std::shared_ptr<Node> rs_node = ReduceSum(); const std::shared_ptr<Node> rs_node = ReduceSum();
const std::shared_ptr<Node> pred_node = Producer(prediction, "pred"); const std::shared_ptr<Node> pred_node = Producer(prediction, "pred");
...@@ -68,14 +68,14 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction, ...@@ -68,14 +68,14 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Node> label_node = Producer(target, "label"); const std::shared_ptr<Node> label_node = Producer(target, "label");
label_node->addChild(argmax_target_node); label_node->addChild(argmax_target_node);
argmax_perd_node->addChild(and_node,0,0); argmax_perd_node->addChild(equal_node,0,0);
argmax_target_node->addChild(and_node,0,1); argmax_target_node->addChild(equal_node,0,1);
// and_node->addChild(rs_node,0,0); // equal_node->addChild(rs_node,0,0);
// Create the graph // Create the graph
std::shared_ptr<GraphView> gv_local = std::shared_ptr<GraphView> gv_local =
Sequential({ and_node, rs_node}); Sequential({ equal_node, rs_node});
gv_local->add({pred_node,argmax_perd_node, label_node,argmax_target_node}); gv_local->add({pred_node,argmax_perd_node, label_node,argmax_target_node});
gv_local->compile(prediction->getImpl()->backend(), prediction->dataType()); gv_local->compile(prediction->getImpl()->backend(), prediction->dataType());
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <vector> #include <vector>
#include "aidge/backend/cpu/operator/ArgMaxImpl.hpp" #include "aidge/backend/cpu/operator/ArgMaxImpl.hpp"
#include "aidge/backend/cpu/operator/AndImpl.hpp" #include "aidge/backend/cpu/operator/EqualImpl.hpp"
#include "aidge/backend/cpu/operator/ReduceSumImpl.hpp" #include "aidge/backend/cpu/operator/ReduceSumImpl.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/learning/metrics/Accuracy.hpp" #include "aidge/learning/metrics/Accuracy.hpp"
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#if USE_AIDGE_BACKEND_CUDA #if USE_AIDGE_BACKEND_CUDA
#include "aidge/backend/cuda/operator/ArgMaxImpl.hpp" #include "aidge/backend/cuda/operator/ArgMaxImpl.hpp"
#include "aidge/backend/cuda/operator/AndImpl.hpp" #include "aidge/backend/cuda/operator/EqualImpl.hpp"
#include "aidge/backend/cuda/operator/ReduceSumImpl.hpp" #include "aidge/backend/cuda/operator/ReduceSumImpl.hpp"
#endif #endif
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment