From 95d9888b683d2f7f879c543a7ec0c21e06e197a3 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Mon, 23 Sep 2024 16:30:26 +0200 Subject: [PATCH] force BCE output to be on host --- src/loss/classification/BCE.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/loss/classification/BCE.cpp b/src/loss/classification/BCE.cpp index d515607..296d466 100644 --- a/src/loss/classification/BCE.cpp +++ b/src/loss/classification/BCE.cpp @@ -122,7 +122,7 @@ Aidge::Tensor Aidge::loss::BCE(std::shared_ptr<Tensor>& prediction, // Define node: loss std::vector<int> axes_dims(prediction->nbDims()); std::iota(std::begin(axes_dims), std::end(axes_dims), 0); - auto loss_node = ReduceMean(axes_dims, 1, "loss"); + auto loss_node = ReduceMean(axes_dims, true, false, "loss"); sub3_node->addChild(loss_node, 0, 0); // Define node: gradient @@ -153,5 +153,8 @@ Aidge::Tensor Aidge::loss::BCE(std::shared_ptr<Tensor>& prediction, outputGrad->copyFrom(gradient_op->getOutput(0)->clone()); // Update gradient const std::shared_ptr<OperatorTensor> loss_op = std::dynamic_pointer_cast<OperatorTensor>(loss_node->getOperator()); - return loss_op->getOutput(0)->clone(); // Return loss + // return loss_op->getOutput(0)->clone(); // Return loss + std::shared_ptr<Tensor> fallback; + + return loss_op->getOutput(0)->refFrom(fallback, "cpu"); } -- GitLab