From 5362d0eedbd4ad92d6ecbcdf8680484d465d7fc3 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Thu, 13 Feb 2025 14:22:48 +0000
Subject: [PATCH] Add decorator for dynamic type check of functions parameters
 in Python.

---
 aidge_core/utils.py | 105 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 105 insertions(+)

diff --git a/aidge_core/utils.py b/aidge_core/utils.py
index b6890bc24..749836e9c 100644
--- a/aidge_core/utils.py
+++ b/aidge_core/utils.py
@@ -13,6 +13,9 @@ import threading
 import subprocess
 import pathlib
 from typing import List
+from inspect import signature
+from functools import wraps
+from typing import Union, _SpecialForm, List, Mapping, Dict, Tuple, get_origin, get_args
 
 
 def template_docstring(template_keyword, text_to_replace):
@@ -37,6 +40,108 @@ def template_docstring(template_keyword, text_to_replace):
     return dec
 
 
+def is_instance_of(obj, typ) -> bool:
+    """Check if an object is an instance of a type.
+    With a special handling for subscripted types.
+    """
+    origin = get_origin(typ)
+    args = get_args(typ)
+
+    # If it's not a generic type, fallback to normal isinstance check
+    if origin is None:
+        return isinstance(obj, typ)
+
+    # Check if the object is of the expected container type
+    if not isinstance(obj, origin):
+        return False
+
+    # Handle specific cases for List, Dict, Tuple
+    if origin in (list, set):
+        return all(is_instance_of(item, args[0]) for item in obj)
+    if origin is dict:
+        return all(is_instance_of(k, args[0]) and is_instance_of(v, args[1]) for k, v in obj.items())
+    if origin is tuple:
+        if len(args) == 2 and args[1] is ...:  # Handles Tuple[X, ...]
+            return all(is_instance_of(item, args[0]) for item in obj)
+        return len(obj) == len(args) and all(is_instance_of(item, t) for item, t in zip(obj, args))
+
+    raise NotImplementedError(f"Type {origin} is not supported")
+
+def type_to_str(typ) -> str:
+    """Return a string describing the type given as an argument.
+    With a special handling for subscripted types.
+    This gives a more detail than the __name__ attribute of the type.
+
+    Example: dict[str, list[list[int]]] instead of dict.
+    """
+    origin = get_origin(typ)
+    args = get_args(typ)
+
+    if origin is None:
+        return typ.__name__
+    if origin in (list, set):
+        return f"{origin.__name__}[{type_to_str(args[0])}]"
+    if origin is dict:
+        return f"{origin.__name__}[{type_to_str(args[0])}, {type_to_str(args[1])}]"
+    if origin is tuple:
+        if len(args) == 2 and args[1] is ...:
+            return f"{origin.__name__}[{type_to_str(args[0])}, ...]"
+        return f"{origin.__name__}[{', '.join(type_to_str(t) for t in args)}]"
+    raise NotImplementedError(f"Type {origin} is not supported")
+
+def var_to_type_str(var) -> str:
+    """Return a string describing the type of a variable.
+    With a special handling for subscripted types.
+    """
+    typ = type(var)
+    if typ is list and var:
+        return f"list[{var_to_type_str(var[0])}]"
+    if typ is set and var:
+        return f"set[{var_to_type_str(next(iter(var)))}]"
+    if typ is dict and var:
+        key_type = var_to_type_str(next(iter(var.keys())))
+        value_type = var_to_type_str(next(iter(var.values())))
+        return f"dict[{key_type}, {value_type}]"
+    if typ is tuple and var:
+        return f"tuple[{', '.join(var_to_type_str(v) for v in var)}]"
+    return typ.__name__
+
+def check_types(f):
+    """Decorator used to automatically check type of functions/methods.
+    To do so, we use type annotation available since Python 3.5 https://docs.python.org/3/library/typing.html.
+    Typing check is done with an handling of subscripted types (List, Dict, Tuple).
+    """
+    sig = signature(f)
+
+    # Dictionary key : param name, value : annotation
+    args_types = {p.name: p.annotation \
+            for p in sig.parameters.values()}
+
+    @wraps(f)
+    def decorated(*args, **kwargs):
+        bind = sig.bind(*args, **kwargs)
+        obj_name = ""
+
+        # Check if we are in a method !
+        if "self" in sig.parameters:
+            obj_name = f"{bind.args[0].__class__.__name__}."
+
+        for value, typ in zip(bind.args, args_types.items()):
+            annotation_type = typ[1]
+            if annotation_type == sig.empty:
+                pass
+            if type(annotation_type) is _SpecialForm and annotation_type._name == "Any": # check if Any
+                continue
+            if value is None: # None value is always accepted
+                continue
+            if hasattr(annotation_type, "__origin__") and annotation_type.__origin__ is Union: # check if Union
+                    # Types are contained in the __args__ attribute which is a list
+                    # isinstance only support type or tuple, so we convert to tuple
+                    annotation_type = tuple(annotation_type.__args__)
+            if annotation_type != sig.empty and not is_instance_of(value, annotation_type):
+                raise TypeError(f'In {obj_name}{f.__name__}: \"{typ[0]}\" parameter must be of type {type_to_str(annotation_type)} but is of type {var_to_type_str(value)} instead.')
+        return f(*args, **kwargs)
+    return decorated
 
 
 def run_command(command: List[str], cwd: pathlib.Path = None):
-- 
GitLab