From 5362d0eedbd4ad92d6ecbcdf8680484d465d7fc3 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 13 Feb 2025 14:22:48 +0000 Subject: [PATCH] Add decorator for dynamic type check of functions parameters in Python. --- aidge_core/utils.py | 105 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/aidge_core/utils.py b/aidge_core/utils.py index b6890bc24..749836e9c 100644 --- a/aidge_core/utils.py +++ b/aidge_core/utils.py @@ -13,6 +13,9 @@ import threading import subprocess import pathlib from typing import List +from inspect import signature +from functools import wraps +from typing import Union, _SpecialForm, List, Mapping, Dict, Tuple, get_origin, get_args def template_docstring(template_keyword, text_to_replace): @@ -37,6 +40,108 @@ def template_docstring(template_keyword, text_to_replace): return dec +def is_instance_of(obj, typ) -> bool: + """Check if an object is an instance of a type. + With a special handling for subscripted types. + """ + origin = get_origin(typ) + args = get_args(typ) + + # If it's not a generic type, fallback to normal isinstance check + if origin is None: + return isinstance(obj, typ) + + # Check if the object is of the expected container type + if not isinstance(obj, origin): + return False + + # Handle specific cases for List, Dict, Tuple + if origin in (list, set): + return all(is_instance_of(item, args[0]) for item in obj) + if origin is dict: + return all(is_instance_of(k, args[0]) and is_instance_of(v, args[1]) for k, v in obj.items()) + if origin is tuple: + if len(args) == 2 and args[1] is ...: # Handles Tuple[X, ...] + return all(is_instance_of(item, args[0]) for item in obj) + return len(obj) == len(args) and all(is_instance_of(item, t) for item, t in zip(obj, args)) + + raise NotImplementedError(f"Type {origin} is not supported") + +def type_to_str(typ) -> str: + """Return a string describing the type given as an argument. + With a special handling for subscripted types. + This gives a more detail than the __name__ attribute of the type. + + Example: dict[str, list[list[int]]] instead of dict. + """ + origin = get_origin(typ) + args = get_args(typ) + + if origin is None: + return typ.__name__ + if origin in (list, set): + return f"{origin.__name__}[{type_to_str(args[0])}]" + if origin is dict: + return f"{origin.__name__}[{type_to_str(args[0])}, {type_to_str(args[1])}]" + if origin is tuple: + if len(args) == 2 and args[1] is ...: + return f"{origin.__name__}[{type_to_str(args[0])}, ...]" + return f"{origin.__name__}[{', '.join(type_to_str(t) for t in args)}]" + raise NotImplementedError(f"Type {origin} is not supported") + +def var_to_type_str(var) -> str: + """Return a string describing the type of a variable. + With a special handling for subscripted types. + """ + typ = type(var) + if typ is list and var: + return f"list[{var_to_type_str(var[0])}]" + if typ is set and var: + return f"set[{var_to_type_str(next(iter(var)))}]" + if typ is dict and var: + key_type = var_to_type_str(next(iter(var.keys()))) + value_type = var_to_type_str(next(iter(var.values()))) + return f"dict[{key_type}, {value_type}]" + if typ is tuple and var: + return f"tuple[{', '.join(var_to_type_str(v) for v in var)}]" + return typ.__name__ + +def check_types(f): + """Decorator used to automatically check type of functions/methods. + To do so, we use type annotation available since Python 3.5 https://docs.python.org/3/library/typing.html. + Typing check is done with an handling of subscripted types (List, Dict, Tuple). + """ + sig = signature(f) + + # Dictionary key : param name, value : annotation + args_types = {p.name: p.annotation \ + for p in sig.parameters.values()} + + @wraps(f) + def decorated(*args, **kwargs): + bind = sig.bind(*args, **kwargs) + obj_name = "" + + # Check if we are in a method ! + if "self" in sig.parameters: + obj_name = f"{bind.args[0].__class__.__name__}." + + for value, typ in zip(bind.args, args_types.items()): + annotation_type = typ[1] + if annotation_type == sig.empty: + pass + if type(annotation_type) is _SpecialForm and annotation_type._name == "Any": # check if Any + continue + if value is None: # None value is always accepted + continue + if hasattr(annotation_type, "__origin__") and annotation_type.__origin__ is Union: # check if Union + # Types are contained in the __args__ attribute which is a list + # isinstance only support type or tuple, so we convert to tuple + annotation_type = tuple(annotation_type.__args__) + if annotation_type != sig.empty and not is_instance_of(value, annotation_type): + raise TypeError(f'In {obj_name}{f.__name__}: \"{typ[0]}\" parameter must be of type {type_to_str(annotation_type)} but is of type {var_to_type_str(value)} instead.') + return f(*args, **kwargs) + return decorated def run_command(command: List[str], cwd: pathlib.Path = None): -- GitLab