Draft: Add InstanceNorm operator for 3DUnet support
This MR is linked to aidge#308 and aims to add the missing operator InstanceNormalization (InstanceNorm) for 3D-Unet support.
InstanceNorm - Core Framework Implementation
Core Framework Files
-
InstanceNorm.hpp
(202 lines) - Main operator header with epsilon attribute and comprehensive mathematical documentation -
InstanceNorm.cpp
(81 lines) - Operator implementation with constructor, dimension forwarding, and factory function -
pybind_InstanceNorm.cpp
(128 lines) - Python bindings with LaTeX documentation and parameter defaults -
Test_InstanceNorm_Op.cpp
(169 lines) - Unit tests for attributes, dimension handling, and functionality
Core Operator Architecture
InstanceNorm.hpp
)
Header Implementation (-
Attribute System:
StaticAttributes<InstanceNormAttr>
with single Epsilon parameter -
Class Definition:
InstanceNorm_Op
inherits fromOperatorTensorWithImpl<InstanceNorm_Op>
- Documentation: Extensive LaTeX formulations for forward/backward passes
- API Design: Epsilon getter and factory function with learnable parameters
InstanceNorm.cpp
)
Source Implementation (- Constructor: Epsilon initialization with three input categories (Data, Param, Param)
- Dimension Forwarding: Multi-dimensional tensor shape inference and parameter validation
-
Factory Function: Creates nodes with attached scale/bias parameters via
addProducer()
pybind_InstanceNorm.cpp
)
Python Bindings (- Class Binding: Python class with attribute access and comprehensive docstrings
-
Function Binding: Factory with defaults (
nb_features
,epsilon=1e-5
,name=""
) - Registration: Automatic operator registration for Python interface
Test_InstanceNorm_Op.cpp
)
Unit Testing (- Coverage: 3D/4D tensors, constructors, attributes, factory function, and I/O names
- Validation: Epsilon handling, dimension forwarding, and parameter attachment
Key Implementation Details
- Single Attribute: Epsilon (float, default 1e-5) for numerical stability
- Input Structure: data_input, scale (γ), bias (β) - latter two are learnable per-channel parameters
- Normalization: Per-instance, per-channel across spatial dimensions
- Mathematical Support: Complete forward/backward formulations with gradient computation
External References
- InstanceNorm paper: The Missing Ingredient for Fast Stylization https://arxiv.org/pdf/1607.08022
- Pytorch documentation: https://docs.pytorch.org/docs/stable/generated/torch.nn.InstanceNorm3d.html
- ONNX documentation: https://onnx.ai/onnx/operators/onnx__InstanceNormalization.html
Edited by matthieu marchal