Commit 3309f38d authored 9 months ago by Lucas RAKOTOARIVONY

Add knowledge distillation (KD) loss

Parent: 035f43e0
No related branches, tags, or merge requests found.
Pipeline #55583 failed 9 months ago (stages: build, test, coverage)
Changes: 3 · Pipelines: 3
Showing 3 changed files with 168 additions and 0 deletions:

include/aidge/loss/LossList.hpp (+11 −0)
python_binding/learning/loss/pybind_Loss.cpp (+1 −0)
src/loss/distillation/KD.cpp (+156 −0)
include/aidge/loss/LossList.hpp (+11 −0)
@@ -33,6 +33,17 @@ Tensor MSE(std::shared_ptr<Tensor>& prediction,
            const std::shared_ptr<Tensor>& target);
 
 Tensor BCE(std::shared_ptr<Tensor>& prediction,
            const std::shared_ptr<Tensor>& target);
 
+/**
+ * @brief Compute the Knowledge Distillation loss.
+ * This function returns the loss and sets the ``grad()`` of the prediction
+ * input.
+ * @param student_prediction Tensor returned by the Aidge Graph of the student model;
+ * it is important that this tensor is not a copy, as otherwise the backward
+ * function will not have a gradient to start from.
+ * @param teacher_prediction Tensor returned by the Aidge Graph of the teacher model.
+ */
+Tensor KD(std::shared_ptr<Tensor>& student_prediction,
+          const std::shared_ptr<Tensor>& teacher_prediction, float temperature = 2.0f);
 } // namespace loss
 } // namespace Aidge
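For orientation, here is a minimal, hypothetical sketch of how the new declaration might be called from C++; it is not part of the commit. The tensor values, the main() wrapper and the aggregate "aidge/backend/cpu.hpp" include are illustrative assumptions. In a real training loop, student_pred would be the student graph's actual output tensor (not a copy), so that KD can write the starting gradient into its grad().

#include <memory>

#include "aidge/data/Tensor.hpp"
#include "aidge/loss/LossList.hpp"
#include "aidge/backend/cpu.hpp"  // assumption: aggregate header registering the CPU operator implementations

int main() {
    // Illustrative logits: batch of 2 samples, 3 classes (values are made up).
    auto student_pred = std::make_shared<Aidge::Tensor>(
        Aidge::Array2D<float, 2, 3>{{{1.0f, 2.0f, 0.5f}, {0.1f, 0.3f, 1.2f}}});
    auto teacher_pred = std::make_shared<Aidge::Tensor>(
        Aidge::Array2D<float, 2, 3>{{{1.5f, 1.8f, 0.2f}, {0.0f, 0.4f, 1.5f}}});
    student_pred->setBackend("cpu");
    teacher_pred->setBackend("cpu");

    // Returns the KD loss and writes softmax(student/T) - softmax(teacher/T) into
    // student_pred->grad(). Assumption: the gradient tensor of student_pred is
    // available (as it is for the output of a compiled graph).
    const Aidge::Tensor kd_loss =
        Aidge::loss::KD(student_pred, teacher_pred, /*temperature=*/2.0f);
    return 0;
}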
python_binding/learning/loss/pybind_Loss.cpp (+1 −0)
@@ -24,5 +24,6 @@ void init_Loss(py::module &m) {
     m.def_submodule("loss", "Submodule dedicated to loss functions");
     m_loss.def("MSE", &loss::MSE, py::arg("graph"), py::arg("target"));
     m_loss.def("BCE", &loss::BCE, py::arg("graph"), py::arg("target"));
+    m_loss.def("KD", &loss::KD, py::arg("student_prediction"), py::arg("teacher_prediction"), py::arg("temperature") = 2.0f);
 }
 } // namespace Aidge
src/loss/distillation/KD.cpp (new file, mode 100644, +156 −0)
/********************************************************************************
 * Copyright (c) 2024 Thales
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
 * Date: 12.09.2024
 *
 ********************************************************************************/

#include <memory>
#include <numeric>  // std::iota

#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/loss/LossList.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Pow.hpp"
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/Softmax.hpp"
#include "aidge/operator/Ln.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/backend/cpu/operator/PowImpl.hpp"
#include "aidge/backend/cpu/operator/ReduceMeanImpl.hpp"
#include "aidge/backend/cpu/operator/SoftmaxImpl.hpp"
#include "aidge/backend/cpu/operator/LnImpl.hpp"
#include "aidge/backend/cpu/operator/SubImpl.hpp"
#include "aidge/backend/cpu/operator/MulImpl.hpp"

Aidge::Tensor Aidge::loss::KD(std::shared_ptr<Tensor>& student_prediction,
                              const std::shared_ptr<Tensor>& teacher_prediction,
                              float temperature) {
    /*
    Implementation note:
    Knowledge distillation (KD) loss function

    KD is computed using a graph in order not to be backend dependent.
    The graph used is the following:

    student_predictions->Mul_student
    (1/temperature)->Mul_student
    teacher_predictions->Mul_teacher
    (1/temperature)->Mul_teacher
    Mul_student->Softmax1->Ln->Mul
    Mul_teacher->Softmax2->Mul
    Mul->Mul2
    (-nbChan)->Mul2
    Mul2->ReduceMean->Loss
    Softmax1->Sub
    Softmax2->Sub
    Sub->Gradient
    */

    AIDGE_ASSERT(student_prediction->backend() == teacher_prediction->backend(),
                 "'prediction' and 'target' Tensors must be on the "
                 "same backend. Found {} and {}.\n",
                 student_prediction->backend(), teacher_prediction->backend());
    AIDGE_ASSERT(student_prediction->dims() == teacher_prediction->dims(),
                 "'prediction' (shape {}) and 'target' (shape {}) Tensors must "
                 "have the same dimensions.\n",
                 student_prediction->dims(), teacher_prediction->dims());
    AIDGE_ASSERT(student_prediction->dataType() == teacher_prediction->dataType(),
                 "'prediction' (data type {}) and 'target' (data type {}) "
                 "Tensors must have the same data type.\n",
                 student_prediction->dataType(), teacher_prediction->dataType());

    // Define nodes: inputs
    const std::shared_ptr<Node> student_node = Producer(student_prediction, "stud_pred");
    const std::shared_ptr<Node> teacher_node = Producer(teacher_prediction, "tchr_pred");

    // Define node: mul_student = student_predictions * (1/temperature)
    const std::shared_ptr<Node> mul_student_node = Mul("temperature_student");
    // Note: this assumes the target is [nbBatch, nbChan]
    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{1 / temperature}}))
        ->addChild(mul_student_node, 0, 1);
    student_node->addChild(mul_student_node, 0, 0);

    // Define node: mul_teacher = teacher_predictions * (1/temperature)
    const std::shared_ptr<Node> mul_teacher_node = Mul("temperature_teacher");
    // Note: this assumes the target is [nbBatch, nbChan]
    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{1 / temperature}}))
        ->addChild(mul_teacher_node, 0, 1);
    teacher_node->addChild(mul_teacher_node, 0, 0);

    // Define node: soft_student = softmax(mul_student)
    const std::shared_ptr<Node> soft_student_node = Softmax(1, "softmax_student");
    mul_student_node->addChild(soft_student_node, 0, 0);

    // Define node: ln_soft_student = ln(soft_student)
    const std::shared_ptr<Node> ln_soft_student_node = Ln("ln_softmax_student");
    soft_student_node->addChild(ln_soft_student_node, 0, 0);

    // Define node: soft_teacher = softmax(mul_teacher)
    const std::shared_ptr<Node> soft_teacher_node = Softmax(1, "softmax_teacher");
    mul_teacher_node->addChild(soft_teacher_node, 0, 0);

    // Define node: mul = ln(soft_student) * soft_teacher
    const std::shared_ptr<Node> mul_node = Mul("softmax_multiplication");
    ln_soft_student_node->addChild(mul_node, 0, 0);  // log_soft_stud_node
    soft_teacher_node->addChild(mul_node, 0, 1);

    const std::vector<DimSize_t> mDims = teacher_prediction->dims();
    float value = -1.0 * mDims[1];

    // Define node: mul2 = mul * (-n)
    const std::shared_ptr<Node> mul2_node = Mul("softmax_negative");
    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{value}}))
        ->addChild(mul2_node, 0, 1);
    mul_node->addChild(mul2_node, 0, 0);

    // Define node: loss
    std::vector<int> axes_dims(student_prediction->nbDims());
    std::iota(std::begin(axes_dims), std::end(axes_dims), 0);
    auto rm_node = ReduceMean(axes_dims, 1, "loss");
    mul2_node->addChild(rm_node, 0, 0);

    // Define node: gradient
    const std::shared_ptr<Node> sub_node = Sub("gradient");
    soft_student_node->addChild(sub_node, 0, 0);  // log_soft_stud_node
    soft_teacher_node->addChild(sub_node, 0, 1);

    // Create GraphView
    std::shared_ptr<GraphView> gv_loss = std::make_shared<GraphView>("KD");
    gv_loss->add({student_node, teacher_node,
                  mul_student_node->getParent(1), mul_student_node,
                  mul_teacher_node->getParent(1), mul_teacher_node,
                  soft_student_node, ln_soft_student_node, soft_teacher_node,
                  mul_node, mul2_node->getParent(1), mul2_node,
                  rm_node, sub_node});
    gv_loss->compile(student_prediction->getImpl()->backend(), student_prediction->dataType());

    SequentialScheduler ss_loss{gv_loss};
    ss_loss.forward(false);

    std::shared_ptr<Tensor> outputGrad = student_prediction->grad();
    const std::shared_ptr<OperatorTensor> gradient_op =
        std::dynamic_pointer_cast<OperatorTensor>(sub_node->getOperator());
    outputGrad->copyFrom(gradient_op->getOutput(0)->clone());  // Update gradient

    const std::shared_ptr<OperatorTensor> loss_op =
        std::dynamic_pointer_cast<OperatorTensor>(rm_node->getOperator());
    return loss_op->getOutput(0)->clone();  // Return loss
}
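For reference, here is one reading of what this graph computes; it is an interpretation of the code above, not a statement from the commit. Writing σ for the softmax over the class axis, B = nbBatch, C = nbChan (mDims[1]), s for the student logits, t for the teacher logits and T for the temperature, the ReduceMean averages the elementwise product, scaled by −C, over all B·C entries:

$$\mathcal{L}_{\mathrm{KD}} = \frac{1}{B\,C}\sum_{b=1}^{B}\sum_{c=1}^{C}\left(-C\right)\,\sigma\!\left(\tfrac{t_b}{T}\right)_{c}\,\log\sigma\!\left(\tfrac{s_b}{T}\right)_{c} = \frac{1}{B}\sum_{b=1}^{B}\left(-\sum_{c=1}^{C}\sigma\!\left(\tfrac{t_b}{T}\right)_{c}\,\log\sigma\!\left(\tfrac{s_b}{T}\right)_{c}\right),$$

i.e. the batch-averaged cross-entropy between the temperature-softened teacher and student distributions. The tensor copied into student_prediction->grad() is σ(s/T) − σ(t/T), the difference between the softened student and teacher distributions.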