Draft: Add InstanceNorm backend kernel for 3DUnet support

matthieu marchal requested to merge feat/3d-unet-instance-norm into dev

This MR is linked to aidge#308. It adds the missing InstanceNormalization (InstanceNorm) operator, which is needed for 3D-Unet support.

InstanceNorm - CPU Backend Implementation

Backend Implementation Files

  • InstanceNormImpl.hpp (49 lines) - Interface definition and backend registration
  • InstanceNormImpl.cpp (71 lines) - Forward/backward dispatch with tensor management
  • InstanceNormImpl_kernels.hpp (197 lines) - Optimized kernels with Welford's algorithm
  • Test_InstanceNormImpl.cpp (269 lines) - Forward/backward validation tests

Mathematical Implementation

Forward Pass

For an input tensor x ∈ ℝ^(N×C×H×W), the forward pass normalizes each (n, c) slice over its spatial dimensions:

Mean and Variance (computed with Welford's algorithm):

\begin{align}
\mu_{n,c} &= \frac{1}{HW} \sum_{h,w} x_{n,c,h,w} \\
\sigma^2_{n,c} &= \frac{1}{HW} \sum_{h,w} (x_{n,c,h,w} - \mu_{n,c})^2
\end{align}
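
The mean and biased variance above can be obtained in a single pass with Welford's update. The following is a minimal sketch, not the code from InstanceNormImpl_kernels.hpp: the names `MeanVar` and `welfordMeanVar` are hypothetical, and it assumes that the H×W values of one (n, c) slice are stored contiguously.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper type, introduced only for this illustration.
template <class T>
struct MeanVar {
    T mean;
    T variance;  // biased variance: divides by the element count (HW)
};

// One-pass Welford update over `count` contiguous values of one (n, c) slice.
template <class T>
MeanVar<T> welfordMeanVar(const T* data, std::size_t count) {
    assert(count > 0 && "a slice must contain at least one element");
    T mean = T(0);
    T m2 = T(0);  // running sum of squared deviations from the current mean
    for (std::size_t i = 0; i < count; ++i) {
        const T delta = data[i] - mean;         // deviation from the old mean
        mean += delta / static_cast<T>(i + 1);  // update the running mean
        m2 += delta * (data[i] - mean);         // uses old and new mean
    }
    return {mean, m2 / static_cast<T>(count)};
}
```

For the slice {1, 2, 3, 4}, this yields a mean of 2.5 and a variance of 1.25, matching the two formulas above. Unlike the naive sum-of-squares approach, the update never subtracts two large, nearly equal quantities.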

Output:

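\begin{align}
y_{n,c,h,w} &= \frac{x_{n,c,h,w} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \epsilon}} \gamma_c + \beta_c
\end{align}

Here γ_c and β_c are the per-channel scale and shift parameters, and ε is a small constant that keeps the denominator away from zero.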
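
As a reference for the forward formulas, here is a minimal sketch of the whole pass. It is an illustration under stated assumptions, not the kernel from this MR: the function name `instanceNormForward` and its signature are hypothetical, the layout is assumed to be contiguous NCHW, and `spatial` stands for H×W.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference implementation (names are illustrative only).
// Assumes contiguous NCHW storage, so each (n, c) slice is `spatial` = H*W
// consecutive values.
template <class T>
std::vector<T> instanceNormForward(const std::vector<T>& x,
                                   const std::vector<T>& gamma,
                                   const std::vector<T>& beta,
                                   std::size_t N, std::size_t C,
                                   std::size_t spatial, T epsilon) {
    assert(spatial > 0);
    assert(x.size() == N * C * spatial);
    assert(gamma.size() == C && beta.size() == C);
    std::vector<T> y(x.size());
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t c = 0; c < C; ++c) {
            const std::size_t base = (n * C + c) * spatial;
            // Welford's one-pass mean and sum of squared deviations.
            T mean = T(0);
            T m2 = T(0);
            for (std::size_t i = 0; i < spatial; ++i) {
                const T delta = x[base + i] - mean;
                mean += delta / static_cast<T>(i + 1);
                m2 += delta * (x[base + i] - mean);
            }
            // Biased variance (divide by HW), then 1 / sqrt(var + eps).
            const T invStd =
                T(1) / std::sqrt(m2 / static_cast<T>(spatial) + epsilon);
            // Normalize, then apply the per-channel scale and shift.
            for (std::size_t i = 0; i < spatial; ++i) {
                y[base + i] = (x[base + i] - mean) * invStd * gamma[c] + beta[c];
            }
        }
    }
    return y;
}
```

A quick sanity check for any implementation: within each (n, c) slice the outputs average to β_c, and a constant slice maps to β_c everywhere because x − μ is zero.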

Backward Pass

Intermediate Variables:

\begin{align}
\hat{x}_{n,c,h,w} &= \frac{x_{n,c,h,w} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \epsilon}} \\
\overline{g}_{n,c} &= \frac{1}{HW} \sum_{h,w} \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} \\
\overline{g \hat{x}}_{n,c} &= \frac{1}{HW} \sum_{h,w} \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} \hat{x}_{n,c,h,w}
\end{align}

Gradients:

\begin{align}
\frac{\partial \mathcal{L}}{\partial \gamma_c} &= \sum_{n} \sum_{h,w} \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} \hat{x}_{n,c,h,w} \quad \text{(accumulated over full batch)} \\
\frac{\partial \mathcal{L}}{\partial \beta_c} &= \sum_{n} \sum_{h,w} \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} \quad \text{(accumulated over full batch)} \\
\frac{\partial \mathcal{L}}{\partial x_{n,c,h,w}} &= \frac{\gamma_c}{\sqrt{\sigma^2_{n,c} + \epsilon}} \left( \frac{\partial \mathcal{L}}{\partial y_{n,c,h,w}} - \overline{g}_{n,c} - \hat{x}_{n,c,h,w} \overline{g \hat{x}}_{n,c} \right)
\end{align}
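
The gradient formulas can be checked against a direct transcription. The sketch below is hypothetical and is not the MR's kernel: `InstanceNormGrads` and `instanceNormBackward` are illustrative names, the layout is assumed to be contiguous NCHW with `spatial` = H×W, and the statistics are recomputed with a plain two-pass loop for brevity rather than with Welford's algorithm.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result type, introduced only for this illustration.
template <class T>
struct InstanceNormGrads {
    std::vector<T> gradInput;  // dL/dx, same size as x
    std::vector<T> gradGamma;  // dL/dgamma, one value per channel
    std::vector<T> gradBeta;   // dL/dbeta, one value per channel
};

template <class T>
InstanceNormGrads<T> instanceNormBackward(const std::vector<T>& x,
                                          const std::vector<T>& gamma,
                                          const std::vector<T>& gradOutput,
                                          std::size_t N, std::size_t C,
                                          std::size_t spatial, T epsilon) {
    assert(spatial > 0);
    assert(x.size() == N * C * spatial && gradOutput.size() == x.size());
    assert(gamma.size() == C);
    InstanceNormGrads<T> g;
    g.gradInput.assign(x.size(), T(0));
    g.gradGamma.assign(C, T(0));
    g.gradBeta.assign(C, T(0));
    const T count = static_cast<T>(spatial);
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t c = 0; c < C; ++c) {
            const std::size_t base = (n * C + c) * spatial;
            // Two-pass mean and biased variance of the (n, c) slice.
            T mean = T(0);
            for (std::size_t i = 0; i < spatial; ++i) mean += x[base + i];
            mean /= count;
            T var = T(0);
            for (std::size_t i = 0; i < spatial; ++i) {
                const T d = x[base + i] - mean;
                var += d * d;
            }
            var /= count;
            const T invStd = T(1) / std::sqrt(var + epsilon);
            // Slice sums of g and g * x_hat.
            T sumG = T(0);
            T sumGX = T(0);
            for (std::size_t i = 0; i < spatial; ++i) {
                const T xhat = (x[base + i] - mean) * invStd;
                sumG += gradOutput[base + i];
                sumGX += gradOutput[base + i] * xhat;
            }
            // Parameter gradients accumulate over the full batch.
            g.gradBeta[c] += sumG;
            g.gradGamma[c] += sumGX;
            const T meanG = sumG / count;
            const T meanGX = sumGX / count;
            for (std::size_t i = 0; i < spatial; ++i) {
                const T xhat = (x[base + i] - mean) * invStd;
                g.gradInput[base + i] = gamma[c] * invStd *
                    (gradOutput[base + i] - meanG - xhat * meanGX);
            }
        }
    }
    return g;
}
```

Two properties follow from the formulas and are useful in tests: the input gradient sums to zero over each (n, c) slice, and a constant upstream gradient produces a zero input gradient (up to rounding).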

Backend Architecture

Core Components

  • Interface: InstanceNormImpl_cpu type alias with forward/backward function signatures
  • Dispatch: Tensor extraction, registrar-based kernel selection, and raw pointer management
  • Kernels: Template-based implementations with Welford's algorithm and multi-type support
  • Testing: Forward and backward passes validated against reference implementations

Key Features

  • Welford's Algorithm: Numerically stable online mean/variance computation
  • Template System: Kernels are templated on their data types, with parameters <I, G, B, O> for forward and <I, GI, G, GG, B, GB, GO> for backward
  • Data Types: Float32/Float64 automatic registration and type mapping
  • Memory Optimization: Stride-based access patterns and cache-friendly implementations

Edited by matthieu marchal