Skip to content
Snippets Groups Projects

Fix Accuracy metric

Merged Houssem ROUIS requested to merge fix_accuracy into dev
2 files
+ 9
9
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 7
7
@@ -19,7 +19,7 @@
@@ -19,7 +19,7 @@
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/operator/ReduceSum.hpp"
#include "aidge/operator/ReduceSum.hpp"
#include "aidge/operator/And.hpp"
#include "aidge/operator/Equal.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
@@ -34,7 +34,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
@@ -34,7 +34,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
The graph used is the following:
The graph used is the following:
pred->ArgMax
pred->ArgMax
->And->ReduceSum
->Equal->ReduceSum
label->ArgMax
label->ArgMax
*/
*/
@@ -60,7 +60,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
@@ -60,7 +60,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Node> argmax_perd_node = ArgMax(axis);
const std::shared_ptr<Node> argmax_perd_node = ArgMax(axis);
const std::shared_ptr<Node> argmax_target_node = ArgMax(axis);
const std::shared_ptr<Node> argmax_target_node = ArgMax(axis);
const std::shared_ptr<Node> and_node = And();
const std::shared_ptr<Node> equal_node = Equal();
const std::shared_ptr<Node> rs_node = ReduceSum();
const std::shared_ptr<Node> rs_node = ReduceSum();
const std::shared_ptr<Node> pred_node = Producer(prediction, "pred");
const std::shared_ptr<Node> pred_node = Producer(prediction, "pred");
@@ -68,14 +68,14 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
@@ -68,14 +68,14 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Node> label_node = Producer(target, "label");
const std::shared_ptr<Node> label_node = Producer(target, "label");
label_node->addChild(argmax_target_node);
label_node->addChild(argmax_target_node);
argmax_perd_node->addChild(and_node,0,0);
argmax_perd_node->addChild(equal_node,0,0);
argmax_target_node->addChild(and_node,0,1);
argmax_target_node->addChild(equal_node,0,1);
// and_node->addChild(rs_node,0,0);
// equal_node->addChild(rs_node,0,0);
// Create the graph
// Create the graph
std::shared_ptr<GraphView> gv_local =
std::shared_ptr<GraphView> gv_local =
Sequential({ and_node, rs_node});
Sequential({ equal_node, rs_node});
gv_local->add({pred_node,argmax_perd_node, label_node,argmax_target_node});
gv_local->add({pred_node,argmax_perd_node, label_node,argmax_target_node});
gv_local->compile(prediction->getImpl()->backend(), prediction->dataType());
gv_local->compile(prediction->getImpl()->backend(), prediction->dataType());
Loading