
Add knowledge distillation (KD) loss to aidge_learning

Merged Lucas RAKOTOARIVONY requested to merge lrakotoarivony/aidge_learning:main into dev
@@ -34,6 +34,17 @@ Tensor MSE(std::shared_ptr<Tensor>& prediction,
Tensor BCE(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Tensor>& target);
/**
 * @brief Compute the Knowledge Distillation (KD) loss.
 * This function returns the loss and sets the ``grad()`` of the student
 * prediction input.
 * @param student_prediction Tensor returned by the Aidge Graph of the student
 * model; it must not be a copy, otherwise the backward pass will have no
 * gradient to start from.
 * @param teacher_prediction Tensor returned by the Aidge Graph of the teacher model.
 * @param temperature Temperature used to soften the probability distributions
 * (default 2.0).
 */
Tensor KD(std::shared_ptr<Tensor>& student_prediction,
const std::shared_ptr<Tensor>& teacher_prediction, float temperature = 2.0f);
Tensor CELoss(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Tensor>& target);
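
For reference, the standard knowledge-distillation loss (Hinton et al.) softens both the teacher and the student logits with the temperature, then takes the KL divergence between the two distributions, scaled by T². The sketch below only illustrates that formula on plain vectors; the helper names (softmaxT, kdLoss) are hypothetical and not part of the Aidge API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax of logits divided by the temperature (numerically stabilised).
std::vector<float> softmaxT(const std::vector<float>& logits, float temperature) {
    std::vector<float> probs(logits.size());
    const float maxLogit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp((logits[i] - maxLogit) / temperature);
        sum += probs[i];
    }
    for (float& p : probs) p /= sum;
    return probs;
}

// KD loss for one sample: T^2 * KL(teacher || student) on the softened distributions.
float kdLoss(const std::vector<float>& studentLogits,
             const std::vector<float>& teacherLogits,
             float temperature = 2.0f) {
    const std::vector<float> p = softmaxT(teacherLogits, temperature); // soft targets
    const std::vector<float> q = softmaxT(studentLogits, temperature); // soft predictions
    float kl = 0.0f;
    for (std::size_t i = 0; i < p.size(); ++i) {
        if (p[i] > 0.0f) {
            kl += p[i] * std::log(p[i] / q[i]);
        }
    }
    // The T^2 factor compensates for the 1/T^2 scaling of the soft-target gradients,
    // so the loss magnitude stays comparable across temperatures.
    return temperature * temperature * kl;
}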