From 95d9888b683d2f7f879c543a7ec0c21e06e197a3 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Mon, 23 Sep 2024 16:30:26 +0200
Subject: [PATCH] Force BCE loss output to be on host (CPU)

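The BCE loss graph runs on whatever backend the prediction tensor uses
(e.g. "cuda"), so loss_op->getOutput(0)->clone() could return a tensor that
still lives in device memory and cannot be read directly from host code.
Returning refFrom(fallback, "cpu") instead is meant to hand back a
host-resident tensor, copying into the local fallback only when the output
is not already on CPU. The ReduceMean call is also adapted to the
boolean-flag signature (presumably keep_dims / noop_with_empty_axes).

Minimal sketch of the intended fallback pattern (illustrative only;
deviceTensor and hostLoss are placeholder names, not code from this patch,
and the refFrom semantics are assumed from aidge_core):

    // deviceTensor may live on a non-CPU backend after graph execution.
    std::shared_ptr<Aidge::Tensor> fallback;
    // refFrom is expected to return the tensor itself if it already
    // resides on "cpu", and otherwise to copy its data into `fallback`,
    // so the result is always readable from host code.
    Aidge::Tensor& hostLoss = deviceTensor->refFrom(fallback, "cpu");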
---
 src/loss/classification/BCE.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/loss/classification/BCE.cpp b/src/loss/classification/BCE.cpp
index d515607..296d466 100644
--- a/src/loss/classification/BCE.cpp
+++ b/src/loss/classification/BCE.cpp
@@ -122,7 +122,7 @@ Aidge::Tensor Aidge::loss::BCE(std::shared_ptr<Tensor>& prediction,
     // Define node: loss
     std::vector<int> axes_dims(prediction->nbDims());
     std::iota(std::begin(axes_dims), std::end(axes_dims), 0);
-    auto loss_node = ReduceMean(axes_dims, 1, "loss");
+    auto loss_node = ReduceMean(axes_dims, true, false, "loss");
     sub3_node->addChild(loss_node, 0, 0);
 
     // Define node: gradient
@@ -153,5 +153,8 @@ Aidge::Tensor Aidge::loss::BCE(std::shared_ptr<Tensor>& prediction,
     outputGrad->copyFrom(gradient_op->getOutput(0)->clone()); // Update gradient
 
     const std::shared_ptr<OperatorTensor> loss_op = std::dynamic_pointer_cast<OperatorTensor>(loss_node->getOperator());
-    return loss_op->getOutput(0)->clone(); // Return loss
+    // return loss_op->getOutput(0)->clone(); // Return loss
+    std::shared_ptr<Tensor> fallback;
+
+    return loss_op->getOutput(0)->refFrom(fallback, "cpu");
 }
-- 
GitLab