Draft: Add InstanceNorm backend kernel for 3D-Unet support
This MR is linked to aidge#308 and aims to add the missing operator InstanceNormalization (InstanceNorm) for 3D-Unet support.
InstanceNorm - CPU Backend Implementation
Backend Implementation Files
- InstanceNormImpl.hpp (49 lines) - Interface definition and backend registration
- InstanceNormImpl.cpp (71 lines) - Forward/backward dispatch with tensor management
- InstanceNormImpl_kernels.hpp (197 lines) - Optimized kernels with Welford's algorithm
- Test_InstanceNormImpl.cpp (269 lines) - Forward/backward validation tests
Mathematical Implementation
Forward Pass
For input tensor x ∈ ℝ^(N×C×H×W), the forward computation is:
Mean and Variance (per instance n and channel c, computed with Welford's algorithm):
\begin{align}
\mu_{n,c} &= \frac{1}{HW} \sum_{h,w} x_{n,c,h,w} \\
\sigma^2_{n,c} &= \frac{1}{HW} \sum_{h,w} (x_{n,c,h,w} - \mu_{n,c})^2
\end{align}
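The mean and biased variance above can be obtained in a single pass with Welford's update. The following is a minimal sketch, not the kernel from this MR: the names `MeanVar` and `welfordMeanVar` are hypothetical, and it assumes the HW values of one (n, c) slice are contiguous in memory.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical result type: statistics of one (n, c) slice.
template <class T>
struct MeanVar {
    T mean;
    T var;  // biased variance, i.e. divided by the element count HW
};

// Hypothetical helper (not part of Aidge): one-pass Welford update over
// `count` contiguous values belonging to a single instance and channel.
template <class T>
MeanVar<T> welfordMeanVar(const T* data, std::size_t count) {
    assert(data != nullptr && count > 0);
    T mean = T(0);
    T m2 = T(0);  // running sum of squared deviations from the running mean
    for (std::size_t i = 0; i < count; ++i) {
        const T delta = data[i] - mean;          // deviation from the old mean
        mean += delta / static_cast<T>(i + 1);   // update the running mean
        m2 += delta * (data[i] - mean);          // uses the updated mean
    }
    return {mean, m2 / static_cast<T>(count)};
}
```

Compared with accumulating the sum and the sum of squares, this form avoids subtracting two large, nearly equal numbers, which is where the numerical stability comes from.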
Output:
\begin{align}
y_{n,c,h,w} &= \frac{x_{n,c,h,w} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \epsilon}} \gamma_c + \beta_c
\end{align}
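As an illustration of the forward formula, here is a small sketch of a complete forward pass. It is an assumption-laden example rather than the MR's kernel: `instanceNormForward` is a hypothetical free function, the buffers are assumed to be contiguous in N, C, H, W order, and a single template type is used where the actual kernels are templated on several types.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical sketch (not the Aidge kernel signature).
// x, y: N*C*HW contiguous elements; gamma, beta: C elements.
template <class T>
void instanceNormForward(const T* x, const T* gamma, const T* beta, T* y,
                         std::size_t N, std::size_t C, std::size_t HW, T eps) {
    assert(x && gamma && beta && y && HW > 0);
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t c = 0; c < C; ++c) {
            const T* xs = x + (n * C + c) * HW;  // slice for instance n, channel c
            T* ys = y + (n * C + c) * HW;

            // Per-slice mean and sum of squared deviations (Welford update).
            T mean = T(0);
            T m2 = T(0);
            for (std::size_t i = 0; i < HW; ++i) {
                const T delta = xs[i] - mean;
                mean += delta / static_cast<T>(i + 1);
                m2 += delta * (xs[i] - mean);
            }
            // 1 / sqrt(sigma^2 + eps), with the biased variance m2 / HW.
            const T invStd = T(1) / std::sqrt(m2 / static_cast<T>(HW) + eps);

            // Normalize, then apply the per-channel scale and shift.
            for (std::size_t i = 0; i < HW; ++i) {
                ys[i] = (xs[i] - mean) * invStd * gamma[c] + beta[c];
            }
        }
    }
}
```

Note that the statistics depend on both n and c, while gamma and beta are indexed by the channel only.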
Backward Pass
Intermediate Variables:
\begin{align}
\hat{x}_{n,c,h,w} &= \frac{x_{n,c,h,w} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \epsilon}} \\
\overline{g}_{n,c} &= \frac{1}{HW} \sum_{h,w} \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} \\
\overline{g \hat{x}}_{n,c} &= \frac{1}{HW} \sum_{h,w} \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} \hat{x}_{n,c,h,w}
\end{align}
Gradients:
\begin{align}
\frac{\partial \mathcal{L}}{\partial \gamma_c} &= \sum_{n} \sum_{h,w} \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} \hat{x}_{n,c,h,w} \quad \text{(accumulated over full batch)} \\
\frac{\partial \mathcal{L}}{\partial \beta_c} &= \sum_{n} \sum_{h,w} \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} \quad \text{(accumulated over full batch)} \\
\frac{\partial \mathcal{L}}{\partial x_{n,c,h,w}} &= \frac{\gamma_c}{\sqrt{\sigma^2_{n,c} + \epsilon}} \left( \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} - \overline{g}_{n,c} - \hat{x}_{n,c,h,w} \overline{g \hat{x}}_{n,c} \right)
\end{align}
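The gradient formulas above can be sketched as follows. This is an illustrative example under stated assumptions, not the kernel from this MR: `instanceNormBackward` is a hypothetical function, the buffers are assumed contiguous in N, C, H, W order, a single template type replaces the multi-type template parameters, and the sketch overwrites the gamma/beta gradient buffers instead of adding to previously stored gradients.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical sketch (not the Aidge kernel signature).
// x, gradY, gradX: N*C*HW elements; gamma, gradGamma, gradBeta: C elements.
template <class T>
void instanceNormBackward(const T* x, const T* gamma, const T* gradY,
                          T* gradX, T* gradGamma, T* gradBeta,
                          std::size_t N, std::size_t C, std::size_t HW, T eps) {
    assert(x && gamma && gradY && gradX && gradGamma && gradBeta && HW > 0);
    // Assumption: parameter gradients are overwritten, then summed over the batch.
    for (std::size_t c = 0; c < C; ++c) {
        gradGamma[c] = T(0);
        gradBeta[c] = T(0);
    }
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t c = 0; c < C; ++c) {
            const std::size_t offset = (n * C + c) * HW;
            const T* xs = x + offset;
            const T* gs = gradY + offset;
            T* dxs = gradX + offset;

            // Recompute the slice statistics (Welford update).
            T mean = T(0);
            T m2 = T(0);
            for (std::size_t i = 0; i < HW; ++i) {
                const T delta = xs[i] - mean;
                mean += delta / static_cast<T>(i + 1);
                m2 += delta * (xs[i] - mean);
            }
            const T invStd = T(1) / std::sqrt(m2 / static_cast<T>(HW) + eps);

            // Sums of g and g * x_hat over the slice.
            T sumG = T(0);
            T sumGX = T(0);
            for (std::size_t i = 0; i < HW; ++i) {
                const T xhat = (xs[i] - mean) * invStd;
                sumG += gs[i];
                sumGX += gs[i] * xhat;
            }
            gradBeta[c] += sumG;    // dL/dbeta_c, summed over n and (h, w)
            gradGamma[c] += sumGX;  // dL/dgamma_c, summed over n and (h, w)

            const T meanG = sumG / static_cast<T>(HW);    // g-bar
            const T meanGX = sumGX / static_cast<T>(HW);  // (g x_hat)-bar
            const T scale = gamma[c] * invStd;
            for (std::size_t i = 0; i < HW; ++i) {
                const T xhat = (xs[i] - mean) * invStd;
                dxs[i] = scale * (gs[i] - meanG - xhat * meanGX);
            }
        }
    }
}
```

A quick sanity check for such a kernel is that, for each (n, c) slice, the input gradient sums to zero, since shifting all inputs of a slice by a constant does not change the output.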
Backend Architecture
Core Components
- Interface: InstanceNormImpl_cpu type alias with forward/backward function signatures
- Dispatch: Tensor extraction, registrar-based kernel selection, and raw pointer management
- Kernels: Template-based implementations with Welford's algorithm and multi-type support
- Testing: Comprehensive validation of forward/backward passes with reference implementations
Key Features
- Welford's Algorithm: Numerically stable online mean/variance computation
- Template System: Multi-type support <I, G, B, O> for forward, <I, GI, G, GG, B, GB, GO> for backward
- Data Types: Float32/Float64 automatic registration and type mapping
- Memory Optimization: Stride-based access patterns and cache-friendly implementations
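To make the idea of type-keyed kernel selection with raw pointers more concrete, here is a generic sketch of the pattern. Everything in it is hypothetical and deliberately does not use Aidge's registrar API: `DType`, `Kernel`, `scaleByTwoKernel`, `kernelRegistry`, and `dispatch` are illustration-only names, and the toy kernel merely doubles its input in place of a real InstanceNorm kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <map>

// Hypothetical stand-ins; none of these names come from Aidge.
enum class DType { Float32, Float64 };

// Type-erased kernel signature: raw input/output pointers plus element count.
using Kernel = void (*)(const void*, void*, std::size_t);

// Typed kernel template; a toy operation standing in for a real kernel.
template <class I, class O>
void scaleByTwoKernel(const void* in, void* out, std::size_t count) {
    const I* src = static_cast<const I*>(in);  // recover the typed pointers
    O* dst = static_cast<O*>(out);
    for (std::size_t i = 0; i < count; ++i) {
        dst[i] = static_cast<O>(src[i]) * O(2);
    }
}

// Registry mapping a data type to the matching template instantiation.
inline const std::map<DType, Kernel>& kernelRegistry() {
    static const std::map<DType, Kernel> registry = {
        {DType::Float32, &scaleByTwoKernel<float, float>},
        {DType::Float64, &scaleByTwoKernel<double, double>},
    };
    return registry;
}

// Look up the kernel for the tensor's data type and call it on raw buffers.
inline void dispatch(DType dtype, const void* in, void* out, std::size_t count) {
    const auto it = kernelRegistry().find(dtype);
    assert(it != kernelRegistry().end());
    it->second(in, out, count);
}
```

The caller only needs the runtime data type of the tensor; the cast back to typed pointers happens inside the selected instantiation, so the numerical code itself stays fully typed.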
External References
Edited by matthieu marchal