Skip to content
Snippets Groups Projects
Commit 193f46ef authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

Add support for the Cross-Entropy loss (CELoss)

parent 6e7cf75a
No related branches found
No related tags found
2 merge requests!44Update 0.2.3 -> 0.3.0,!32Add the Cross Entropy Loss
Pipeline #62642 failed
...@@ -31,9 +31,13 @@ namespace loss { ...@@ -31,9 +31,13 @@ namespace loss {
*/ */
Tensor MSE(std::shared_ptr<Tensor>& prediction, Tensor MSE(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Tensor>& target); const std::shared_ptr<Tensor>& target);
Tensor BCE(std::shared_ptr<Tensor>& prediction, Tensor BCE(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Tensor>& target); const std::shared_ptr<Tensor>& target);
Tensor CELoss(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Tensor>& target);
} // namespace loss } // namespace loss
} // namespace Aidge } // namespace Aidge
......
...@@ -19,10 +19,12 @@ namespace py = pybind11; ...@@ -19,10 +19,12 @@ namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Loss(py::module &m) { void init_Loss(py::module &m)
auto m_loss = {
m.def_submodule("loss", "Submodule dedicated to loss functions"); auto m_loss = m.def_submodule("loss", "Submodule dedicated to loss functions");
m_loss.def("MSE", &loss::MSE, py::arg("graph"), py::arg("target")); m_loss.def("MSE", &loss::MSE, py::arg("graph"), py::arg("target"));
m_loss.def("BCE", &loss::BCE, py::arg("graph"), py::arg("target")); m_loss.def("BCE", &loss::BCE, py::arg("graph"), py::arg("target"));
m_loss.def("CELoss", &loss::CELoss, py::arg("graph"), py::arg("target"));
} }
} // namespace Aidge } // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <numeric> // std::iota
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/loss/LossList.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/operator/Div.hpp"
#include "aidge/operator/Ln.hpp"
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/ReduceSum.hpp"
#include "aidge/operator/Softmax.hpp"
/**
 * @brief Cross-Entropy loss between a 2D prediction and a one-hot/probability target.
 *
 * Applies a Softmax over the channel axis (dim 1) of `prediction`, then computes
 *   loss = - mean_over_batch( sum_over_channels( target * ln(softmax(prediction)) ) ).
 *
 * Side effect: sets the gradient of `prediction` to (softmax - target) / batchSize,
 * i.e. the analytical gradient of the CE loss through the softmax.
 *
 * @param prediction 2D Tensor [BatchSize, NbChannels]; its grad is overwritten.
 * @param target     Tensor with the same dims, dtype and backend as `prediction`.
 * @return           Tensor holding the scalar loss value.
 */
Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
                                  const std::shared_ptr<Tensor>& target)
{
    // --- Input validation -------------------------------------------------
    // Fixed: the message previously said "Label" although the check is on
    // the prediction tensor.
    AIDGE_ASSERT(prediction->nbDims() == 2,
                 "Prediction must have two dims: [BatchSize, NbChannels]");
    AIDGE_ASSERT(prediction->backend() == target->backend(),
                 "'prediction' and 'target' Tensors must be on the "
                 "same backend. Found {} and {}.\n",
                 prediction->backend(), target->backend());
    AIDGE_ASSERT(prediction->dims() == target->dims(),
                 "'prediction' (shape {}) and 'target' (shape {}) Tensors must "
                 "have the same dimensions.\n",
                 prediction->dims(), target->dims());
    // Fixed: this assert was missing its terminating semicolon, which breaks
    // compilation with a do{...}while(0)-style assert macro.
    AIDGE_ASSERT(prediction->dataType() == target->dataType(),
                 "'prediction' (data type {}) and 'target' (data type {}) "
                 "Tensors must have the same data type.\n",
                 prediction->dataType(), target->dataType());

    const auto backend  = prediction->backend();
    const auto dataType = prediction->dataType();

    // --- Softmax of the prediction over the channel axis ------------------
    auto softmaxOp = Softmax_Op(1);
    softmaxOp.setDataType(dataType);
    softmaxOp.setBackend(backend);
    softmaxOp.associateInput(0, prediction);
    softmaxOp.forward();
    auto softmax = softmaxOp.getOutput(0);

    // --- Loss value, computed through a small GraphView -------------------
    //   mean_batch( sum_channels( target * ln(softmax) ) ), negated below.
    auto targetNode  = Producer(target);
    auto softmaxNode = Producer(softmax);
    auto logNode  = Ln();
    auto mulNode  = Mul();
    auto sumNode  = ReduceSum({1});   // reduce over channels
    auto meanNode = ReduceMean({0});  // average over the batch

    softmaxNode->addChild(logNode, 0, 0);
    logNode->addChild(mulNode, 0, 0);
    targetNode->addChild(mulNode, 0, 1);
    mulNode->addChild(sumNode);
    sumNode->addChild(meanNode);

    std::shared_ptr<GraphView> lossGraphView = std::make_shared<GraphView>("CELoss");
    lossGraphView->add({targetNode, softmaxNode, logNode, mulNode, sumNode, meanNode});
    lossGraphView->compile(backend, dataType);

    SequentialScheduler scheduler(lossGraphView);
    scheduler.forward(true);

    auto meanOp = std::static_pointer_cast<OperatorTensor>(meanNode->getOperator());
    auto lossTensor = meanOp->getOutput(0);
    // Negate: CE loss is the negative log-likelihood.
    // NOTE(review): the scaling constants below are hard-coded as float even
    // though `dataType` may differ — confirm the in-place ops convert as expected.
    (*lossTensor) *= (Aidge::Array1D<float, 1> {-1});

    // --- Error signal: d(loss)/d(prediction) = (softmax - target) / batch --
    auto subOp = Sub_Op();
    subOp.setDataType(dataType);
    subOp.setBackend(backend);
    subOp.associateInput(0, softmax);
    subOp.associateInput(1, target);
    subOp.forward();
    auto err = subOp.getOutput(0);

    const std::size_t batchSize = (target->dims())[0];
    (*err) /= (Aidge::Array1D<float, 1> {static_cast<float>(batchSize)});
    prediction->setGrad(err);

    // Return the loss value
    return (*lossTensor);
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment