Commit 3309f38d authored by Lucas RAKOTOARIVONY

Add knowledge distillation (KD) loss

parent 035f43e0
@@ -33,6 +33,17 @@
Tensor MSE(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Tensor>& target);
Tensor BCE(std::shared_ptr<Tensor>& prediction,
const std::shared_ptr<Tensor>& target);
/**
 * @brief Compute the Knowledge Distillation (KD) loss.
 * This function returns the loss and sets the ``grad()`` of the student
 * prediction input.
 * @param student_prediction Tensor returned by the Aidge Graph of the student
 * model. This tensor must not be a copy, otherwise the backward pass has no
 * gradient to start from.
 * @param teacher_prediction Tensor returned by the Aidge Graph of the teacher model.
 * @param temperature Temperature used to soften both distributions: the logits
 * are divided by this value before the softmax (default: 2.0).
 */
Tensor KD(std::shared_ptr<Tensor>& student_prediction,
const std::shared_ptr<Tensor>& teacher_prediction, float temperature = 2.0f);
} // namespace loss
} // namespace Aidge
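For illustration only (not part of this commit), here is a minimal sketch of how the new C++ entry point might be called; `studentLogits` and `teacherLogits` are hypothetical names for the output tensors of the student and teacher graphs after a forward pass.

#include <memory>
#include "aidge/data/Tensor.hpp"
#include "aidge/loss/LossList.hpp"

// Hypothetical helper: one distillation step on already-computed predictions.
void distillationStep(std::shared_ptr<Aidge::Tensor>& studentLogits,
                      const std::shared_ptr<Aidge::Tensor>& teacherLogits) {
    // studentLogits must be the tensor actually owned by the student graph's output
    // (not a clone), so that KD() can write the starting gradient into its grad().
    const Aidge::Tensor loss = Aidge::loss::KD(studentLogits, teacherLogits, /*temperature=*/2.0f);
    // A backward pass on the student graph and an optimizer update would follow here.
    (void)loss;
}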
@@ -24,5 +24,6 @@ void init_Loss(py::module &m) {
auto m_loss = m.def_submodule("loss", "Submodule dedicated to loss functions");
m_loss.def("MSE", &loss::MSE, py::arg("graph"), py::arg("target"));
m_loss.def("BCE", &loss::BCE, py::arg("graph"), py::arg("target"));
m_loss.def("KD", &loss::KD, py::arg("student_prediction"), py::arg("teacher_prediction"), py::arg("temperature") = 2.0f);
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2024 Thales
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
* Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
* Date: 12.09.2024
*
********************************************************************************/
#include <memory>
#include <numeric> // std::iota
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/loss/LossList.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Pow.hpp"
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/Softmax.hpp"
#include "aidge/operator/Ln.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/backend/cpu/operator/PowImpl.hpp"
#include "aidge/backend/cpu/operator/ReduceMeanImpl.hpp"
#include "aidge/backend/cpu/operator/SoftmaxImpl.hpp"
#include "aidge/backend/cpu/operator/LnImpl.hpp"
#include "aidge/backend/cpu/operator/SubImpl.hpp"
#include "aidge/backend/cpu/operator/MulImpl.hpp"
Aidge::Tensor Aidge::loss::KD(std::shared_ptr<Tensor>& student_prediction,
const std::shared_ptr<Tensor>& teacher_prediction,
float temperature) {
/*
Implementation note:
Knowledge distillation (KD) loss function.
KD is computed using a graph in order not to be backend dependent.
The graph used is the following:
student_predictions->Mul_student
(1/temperature)->Mul_student
teacher_predictions->Mul_teacher
(1/temperature)->Mul_teacher
Mul_student->Softmax1->Ln->Mul
Mul_teacher->Softmax2->Mul
Mul->Mul2
(-nbChan)->Mul2
Mul2->ReduceMean->Loss
Softmax1->Sub
Softmax2->Sub
Sub->Gradient
*/
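/*
In formula terms, the graph above computes (with T = temperature and nbChan = number of classes):
p_S = softmax(student_prediction / T), p_T = softmax(teacher_prediction / T)
loss = mean over all elements of (-nbChan * p_T * ln(p_S))
     = (1 / nbBatch) * sum_n sum_c ( -p_T[n, c] * ln(p_S[n, c]) )
i.e. the batch-averaged cross-entropy between the softened teacher and student
distributions. The gradient written into student_prediction->grad() is p_S - p_T.
*/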
AIDGE_ASSERT(student_prediction->backend() == teacher_prediction->backend(),
"'student_prediction' and 'teacher_prediction' Tensors must be on the "
"same backend. Found {} and {}.\n",
student_prediction->backend(), teacher_prediction->backend());
AIDGE_ASSERT(student_prediction->dims() == teacher_prediction->dims(),
"'student_prediction' (shape {}) and 'teacher_prediction' (shape {}) Tensors must "
"have the same dimensions.\n",
student_prediction->dims(), teacher_prediction->dims());
AIDGE_ASSERT(student_prediction->dataType() == teacher_prediction->dataType(),
"'student_prediction' (data type {}) and 'teacher_prediction' (data type {}) "
"Tensors must have the same data type.\n",
student_prediction->dataType(), teacher_prediction->dataType());
// Define nodes: inputs
const std::shared_ptr<Node> student_node = Producer(student_prediction, "stud_pred");
const std::shared_ptr<Node> teacher_node = Producer(teacher_prediction, "tchr_pred");
// Define node: mul_student = student_prediction * (1/temperature)
const std::shared_ptr<Node> mul_student_node = Mul("temperature_student");
// Note: this assumes the predictions are [nbBatch, nbChan]
Producer(std::make_shared<Tensor>(
Array1D<float, 1>{{1 / temperature}}))
->addChild(mul_student_node, 0, 1);
student_node->addChild(mul_student_node, 0, 0);
// Define node: mul_teacher = teacher_prediction * (1/temperature)
const std::shared_ptr<Node> mul_teacher_node = Mul("temperature_teacher");
// Note: this assumes the predictions are [nbBatch, nbChan]
Producer(std::make_shared<Tensor>(
Array1D<float, 1>{{1 / temperature}}))
->addChild(mul_teacher_node, 0, 1);
teacher_node->addChild(mul_teacher_node, 0, 0);
// Define node: soft_student = softmax(mul_student)
const std::shared_ptr<Node> soft_student_node = Softmax(1, "softmax_student");
mul_student_node->addChild(soft_student_node, 0, 0);
// Define node: ln_soft_student = ln(soft_student)
const std::shared_ptr<Node> ln_soft_student_node = Ln("ln_softmax_student");
soft_student_node->addChild(ln_soft_student_node, 0, 0);
// Define node: soft_teacher = softmax(mul_teacher)
const std::shared_ptr<Node> soft_teacher_node = Softmax(1, "softmax_teacher");
mul_teacher_node->addChild(soft_teacher_node, 0, 0);
// Define node: mul = ln_soft_student * soft_teacher
const std::shared_ptr<Node> mul_node = Mul("softmax_multiplication");
ln_soft_student_node->addChild(mul_node, 0, 0);
soft_teacher_node->addChild(mul_node, 0, 1);
const std::vector<DimSize_t> mDims = teacher_prediction->dims();
// Scale by -nbChan so that the final ReduceMean over all elements yields the
// (negated) batch-averaged sum over classes, i.e. a cross-entropy.
float value = -1.0f * mDims[1];
// Define node: mul2 = mul * (-nbChan)
const std::shared_ptr<Node> mul2_node = Mul("softmax_negative");
Producer(std::make_shared<Tensor>(
Array1D<float, 1>{{value}}))
->addChild(mul2_node, 0, 1);
mul_node->addChild(mul2_node, 0, 0);
// Define node: loss
std::vector<int> axes_dims(student_prediction->nbDims());
std::iota(std::begin(axes_dims), std::end(axes_dims), 0);
auto rm_node = ReduceMean(axes_dims, 1, "loss");
mul2_node->addChild(rm_node, 0, 0);
// Define node: gradient = soft_student - soft_teacher
const std::shared_ptr<Node> sub_node = Sub("gradient");
soft_student_node->addChild(sub_node, 0, 0);
soft_teacher_node->addChild(sub_node, 0, 1);
// Create GraphView
std::shared_ptr<GraphView> gv_loss = std::make_shared<GraphView>("KD");
gv_loss->add({student_node, teacher_node,
mul_student_node->getParent(1), mul_student_node,
mul_teacher_node->getParent(1), mul_teacher_node,
soft_student_node, ln_soft_student_node,
soft_teacher_node, mul_node,
mul2_node->getParent(1), mul2_node,
rm_node, sub_node});
gv_loss->compile(student_prediction->getImpl()->backend(), student_prediction->dataType());
SequentialScheduler ss_loss{gv_loss};
ss_loss.forward(false);
std::shared_ptr<Tensor> outputGrad = student_prediction->grad();
const std::shared_ptr<OperatorTensor> gradient_op = std::dynamic_pointer_cast<OperatorTensor>(sub_node->getOperator());
outputGrad->copyFrom(gradient_op->getOutput(0)->clone()); // Update gradient
const std::shared_ptr<OperatorTensor> loss_op = std::dynamic_pointer_cast<OperatorTensor>(rm_node->getOperator());
return loss_op->getOutput(0)->clone(); // Return loss
}