diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..651acd6fe1a58edc0b6f2c446e48e4bc4e4a8750 --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,18 @@ +############################################################################### +# Aidge Continious Integration and Continious Deployment # +# # +############################################################################### + +stages: + # Analyse code + - static_analysis + # Build Aidge + - build + # Unit test stage + - test + +include: + - local: '/.gitlab/ci/_global.gitlab-ci.yml' + - local: '/.gitlab/ci/static_analysis.gitlab-ci.yml' + - local: '/.gitlab/ci/build.gitlab-ci.yml' + - local: '/.gitlab/ci/test.gitlab-ci.yml' diff --git a/.gitlab/ci/_global.gitlab-ci.yml b/.gitlab/ci/_global.gitlab-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..6f34fe701df035e68ce49825fde0ff88449a9637 --- /dev/null +++ b/.gitlab/ci/_global.gitlab-ci.yml @@ -0,0 +1,13 @@ +################################################################################ +# Centralized definitions of common job parameter values. # +# Parameters with many optional configurations may be in separate files. # +# # +################################################################################ +variables: + GIT_SUBMODULE_STRATEGY: recursive + OMP_NUM_THREADS: 4 + GIT_SSL_NO_VERIFY: 1 + DEBIAN_FRONTEND: noninteractive + + +image: n2d2-ci/ubuntu20.04/cpu:latest \ No newline at end of file diff --git a/.gitlab/ci/build.gitlab-ci.yml b/.gitlab/ci/build.gitlab-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..0c004a9f2dc9f4425a8e962d0c9d0e9fec83146e --- /dev/null +++ b/.gitlab/ci/build.gitlab-ci.yml @@ -0,0 +1,33 @@ +build:ubuntu_cpp: + stage: build + tags: + - docker + image: n2d2-ci/ubuntu20.04/cpu:latest + + script: + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON .. + - make -j4 all install + + artifacts: + paths: + - build_cpp/ + - install_cpp/ + +build:ubuntu_python: + stage: build + tags: + - docker + image: n2d2-ci/ubuntu20.04/cpu:latest + + script: + - python3 -m pip install virtualenv + - virtualenv venv + - source venv/bin/activate + - export AIDGE_INSTALL=`pwd`/install + - python3 -m pip install . + artifacts: + paths: + - venv/ \ No newline at end of file diff --git a/.gitlab/ci/static_analysis.gitlab-ci.yml b/.gitlab/ci/static_analysis.gitlab-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..f7c09a33a65801fb25b1f20f76eac5a7a7952917 --- /dev/null +++ b/.gitlab/ci/static_analysis.gitlab-ci.yml @@ -0,0 +1,37 @@ +static_analysis:cpp: + stage: static_analysis + tags: + - static_analysis + allow_failure: true + script: + - mkdir -p $CI_COMMIT_REF_NAME + - cppcheck -j 4 --enable=all --inconclusive --force --xml --xml-version=2 . 2> cppcheck-result.xml + - python -m pip install Pygments + - cppcheck-htmlreport --file=cppcheck-result.xml --report-dir=$CI_COMMIT_REF_NAME --source-dir=. + - python3 -m pip install -U cppcheck_codequality + - cppcheck-codequality --input-file=cppcheck-result.xml --output-file=cppcheck.json + - mkdir -p public/cpp + - mv $CI_COMMIT_REF_NAME public/cpp/ + artifacts: + paths: + - public + reports: + codequality: cppcheck.json + +static_analysis:python: + stage: static_analysis + tags: + - static_analysis + allow_failure: true + script: + - pip install pylint + - pip install pylint-gitlab + - pylint --rcfile=.pylintrc --exit-zero --output-format=pylint_gitlab.GitlabCodeClimateReporter aidge_core/ > codeclimate.json + - pylint --rcfile=.pylintrc --exit-zero --output-format=pylint_gitlab.GitlabPagesHtmlReporter aidge_core/ > pylint.html + - mkdir -p public/python/$CI_COMMIT_REF_NAME + - mv pylint.html public/python/$CI_COMMIT_REF_NAME/ + artifacts: + paths: + - public + reports: + codequality: codeclimate.json \ No newline at end of file diff --git a/.gitlab/ci/test.gitlab-ci.yml b/.gitlab/ci/test.gitlab-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..c85683f889fbc3b5d6af366642d0932d5e2ce9a1 --- /dev/null +++ b/.gitlab/ci/test.gitlab-ci.yml @@ -0,0 +1,23 @@ +test:ubuntu_cpp: + stage: test + needs: ["build:ubuntu_cpp"] + tags: + - docker + image: n2d2-ci/ubuntu20.04/cpu:latest + script: + - cd build_cpp + - ctest --output-on-failure + +test:ubuntu_python: + stage: test + needs: ["build:ubuntu_python"] + tags: + - docker + image: n2d2-ci/ubuntu20.04/cpu:latest + script: + - source venv/bin/activate + - cd aidge_core + - python3 -m pip list + # Run on discovery all tests located in core/unit_tests/python and discard the stdout + # only to show the errors/warnings and the results of the tests + - python3 -m unittest discover -s unit_tests/ -v -b 1> /dev/null diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000000000000000000000000000000000000..03c0cf31f3e63bcae09a45e9a8e6694a78d2f4b1 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,644 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= aidge_core, torch, tensorflow + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold to be exceeded before program exits with error. +fail-under=0.0 + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the ignore-list. The +# regex matches against paths. +ignore-paths= + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=print-statement, + parameter-unpacking, + unpacking-in-except, + old-raise-syntax, + backtick, + long-suffix, + old-ne-operator, + old-octal-literal, + import-star-module-level, + non-ascii-bytes-literal, + raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + apply-builtin, + basestring-builtin, + buffer-builtin, + cmp-builtin, + coerce-builtin, + execfile-builtin, + file-builtin, + long-builtin, + raw_input-builtin, + reduce-builtin, + standarderror-builtin, + unicode-builtin, + xrange-builtin, + coerce-method, + delslice-method, + getslice-method, + setslice-method, + no-absolute-import, + old-division, + dict-iter-method, + dict-view-method, + next-method-called, + metaclass-assignment, + indexing-exception, + raising-string, + reload-builtin, + oct-method, + hex-method, + nonzero-method, + cmp-method, + input-builtin, + round-builtin, + intern-builtin, + unichr-builtin, + map-builtin-not-iterating, + zip-builtin-not-iterating, + range-builtin-not-iterating, + filter-builtin-not-iterating, + using-cmp-argument, + eq-without-hash, + div-method, + idiv-method, + rdiv-method, + exception-message-attribute, + invalid-str-codec, + sys-max-int, + bad-python3-import, + deprecated-string-function, + deprecated-str-translate-call, + deprecated-itertools-function, + deprecated-types-field, + next-method-defined, + dict-items-not-iterating, + dict-keys-not-iterating, + dict-values-not-iterating, + deprecated-operator-function, + deprecated-urllib-function, + xreadlines-attribute, + deprecated-sys-function, + exception-escape, + comprehension-escape, + c-extension-no-member, + too-many-locals, + missing-class-docstring, + missing-function-docstring, + too-many-ancestor, + too-many-arguments, + protected-access, + too-many-branches, + too-many-ancestors, + wrong-import-order, + wrong-import-position, + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'error', 'warning', 'refactor', and 'convention' +# which contain the number of messages in each category, as well as 'statement' +# which is the total number of statements analyzed. This score is used by the +# global evaluation report (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _, + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. +#variable-rgx= + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )?<?https?://\S+>?$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=200 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +#notes-rgx= + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=no + +# Signatures are removed from the similarity computation +ignore-signatures=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear and the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values, + thread._local, + _thread._local, + aidge.global_variables, + aidge.cells.abstract_cell.Trainable, + torch, + tensorflow, + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= aidge_core + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[DESIGN] + +# List of qualified class names to ignore when countint class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". +overgeneral-exceptions=BaseException, + Exception \ No newline at end of file diff --git a/aidge_core/__init__.py b/aidge_core/__init__.py index 8e0bef3a02be05199c0092461b920fe7e5d839dd..ad18a8ef1b23625dcb52951f52c43adc4222c997 100644 --- a/aidge_core/__init__.py +++ b/aidge_core/__init__.py @@ -1 +1,10 @@ -from aidge_core.aidge_core import * # import so generated by PyBind \ No newline at end of file +""" +Copyright (c) 2023 CEA-List + +This program and the accompanying materials are made available under the +terms of the Eclipse Public License 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. + +SPDX-License-Identifier: EPL-2.0 +""" +from aidge_core.aidge_core import * # import so generated by PyBind diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py new file mode 100644 index 0000000000000000000000000000000000000000..b326e0748c2c77612dd79122fe891a6207d945dc --- /dev/null +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -0,0 +1,65 @@ +""" +Copyright (c) 2023 CEA-List + +This program and the accompanying materials are made available under the +terms of the Eclipse Public License 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. + +SPDX-License-Identifier: EPL-2.0 +""" + +import unittest +import aidge_core + +class test_operator_binding(unittest.TestCase): + """Very basic test to make sure the python APi is not broken. + Can be remove in later stage of the developpement. + """ + def setUp(self): + self.generic_operator = aidge_core.GenericOperator("FakeConv", 1, 1, 1).get_operator() + + def tearDown(self): + pass + + def test_default_name(self): + op_type = "Conv" + gop = aidge_core.GenericOperator(op_type, 1, 1, 1, "FictiveName") + # check node name is not operator type + self.assertNotEqual(gop.name(), "Conv") + # check node name is not default + self.assertNotEqual(gop.name(), "") + + def test_param_bool(self): + self.generic_operator.add_parameter("bool", True) + self.assertEqual(self.generic_operator.get_parameter("bool"), True) + + def test_param_int(self): + self.generic_operator.add_parameter("int", 1) + self.assertEqual(self.generic_operator.get_parameter("int"), 1) + + def test_param_float(self): + self.generic_operator.add_parameter("float", 2.0) + self.assertEqual(self.generic_operator.get_parameter("float"), 2.0) + + def test_param_str(self): + self.generic_operator.add_parameter("str", "value") + self.assertEqual(self.generic_operator.get_parameter("str"), "value") + + def test_param_l_int(self): + self.generic_operator.add_parameter("l_int", [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]) + self.assertEqual(self.generic_operator.get_parameter("l_int"), [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]) + + def test_param_l_bool(self): + self.generic_operator.add_parameter("l_bool", [True, False, False, True]) + self.assertEqual(self.generic_operator.get_parameter("l_bool"), [True, False, False, True]) + + def test_param_l_float(self): + self.generic_operator.add_parameter("l_float", [2.0, 1.0]) + self.assertEqual(self.generic_operator.get_parameter("l_float"), [2.0, 1.0]) + + def test_param_l_str(self): + self.generic_operator.add_parameter("l_str", ["ok"]) + self.assertEqual(self.generic_operator.get_parameter("l_str"), ["ok"]) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/aidge_core/unit_tests/test_parameters.py b/aidge_core/unit_tests/test_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..02c7598820d2429bc49ff9a2f02c8ee841783173 --- /dev/null +++ b/aidge_core/unit_tests/test_parameters.py @@ -0,0 +1,77 @@ +""" +Copyright (c) 2023 CEA-List + +This program and the accompanying materials are made available under the +terms of the Eclipse Public License 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. + +SPDX-License-Identifier: EPL-2.0 +""" + +import unittest +import aidge_core + +class test_parameters(unittest.TestCase): + """Very basic test to make sure the python APi is not broken. + Can be remove in later stage of the developpement. + """ + def setUp(self): + pass + + def tearDown(self): + pass + + def test_conv(self): + # TODO : test StrideDims & DilationDims when supported in ctor + in_channels = 4 + out_channels = 8 + k_dims = [2, 2] + conv_op = aidge_core.Conv2D(in_channels , out_channels, k_dims).get_operator() + self.assertEqual(conv_op.get("InChannels"), in_channels) + self.assertEqual(conv_op.get("OutChannels"), out_channels) + self.assertEqual(conv_op.get("KernelDims"), k_dims) + + def test_fc(self): + out_channels = 8 + nb_bias = True + fc_op = aidge_core.FC(out_channels, nb_bias).get_operator() + self.assertEqual(fc_op.get("OutChannels"), out_channels) + self.assertEqual(fc_op.get("NoBias"), nb_bias) + + def test_matmul(self): + out_channels = 8 + matmul_op = aidge_core.Matmul(out_channels).get_operator() + self.assertEqual(matmul_op.get("OutChannels"), out_channels) + + def test_producer_1D(self): + dims = [5] + producer_op = aidge_core.Producer(dims).get_operator() + self.assertEqual(producer_op.dims(), dims) + + def test_producer_2D(self): + dims = [10,5] + producer_op = aidge_core.Producer(dims).get_operator() + self.assertEqual(producer_op.dims(), dims) + + def test_producer_3D(self): + dims = [1,10,5] + producer_op = aidge_core.Producer(dims).get_operator() + self.assertEqual(producer_op.dims(), dims) + + def test_producer_4D(self): + dims = [12,1,10,5] + producer_op = aidge_core.Producer(dims).get_operator() + self.assertEqual(producer_op.dims(), dims) + + def test_producer_5D(self): + dims = [2,12,1,10,5] + producer_op = aidge_core.Producer(dims).get_operator() + self.assertEqual(producer_op.dims(), dims) + + def test_leaky_relu(self): + negative_slope = 0.25 + leakyrelu_op = aidge_core.LeakyReLU(negative_slope).get_operator() + self.assertEqual(leakyrelu_op.get("NegativeSlope"), negative_slope) + +if __name__ == '__main__': + unittest.main() diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index fabbe584530d37930e18bf070383b8d75f732d24..0780ce9a24da0ceb0c42b32944021f5df2fa9726 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -34,13 +34,23 @@ class GraphView; */ class Node : public std::enable_shared_from_this<Node> { private: + struct weakCompare { + bool operator()(const std::weak_ptr<Aidge::GraphView>& a, const std::weak_ptr<Aidge::GraphView>& b) const { + // Compare the content of the weak_ptrs + auto sharedA = a.lock(); + auto sharedB = b.lock(); + if (!sharedB) return false; // nothing after expired pointer + if (!sharedA) return true; + return sharedA < sharedB; // shared_ptr has a valid comparison operator + } + }; std::string mName; /** Name of the Node. Should be unique. */ - std::set<std::shared_ptr<GraphView>> mViews = std::set<std::shared_ptr<GraphView>>(); /** Set of pointers to GraphView instances including this Node instance. */ + std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of pointers to GraphView instances including this Node instance. */ const std::shared_ptr<Operator> mOperator; // Pointer to the associated Operator std::vector<NodePtr> mParents; /** List of parent node for each input (Parent --> Node --> Child) */ - std::vector<std::vector<NodePtr>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */ + std::vector<std::vector<std::weak_ptr<Node>>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */ std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */ std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */ @@ -70,7 +80,7 @@ public: * @param ctors Ordered Connectors linking their associated Node to the input of the current Node with the same index. * @return Connector */ - Connector operator()(const std::vector<Connector> ctors); + Connector operator()(const std::vector<Connector> &ctors); public: /////////////////////////////////////////////////////// @@ -131,14 +141,14 @@ public: /** * @brief List of pair <Parent, ID of the data intput>. When an input is not * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. - * @return std::vector<std::pair<NodePtr, IOIndex_t>> + * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; /** * @brief List of pair <Parent, ID of the parent output>. When an input is not linked * to any Parent, the pair is <nullptr, gk_IODefaultIndex>. - * @return std::vector<std::pair<NodePtr, IOIndex_t>> + * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; @@ -146,7 +156,7 @@ public: * @brief Parent and its output Tensor ID linked to the inID-th input Tensor. * If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. * @param inID - * @return std::pair<NodePtr, IOIndex_t> + * @return std::pair<std::shared_ptr<Node>, IOIndex_t> */ inline std::pair<NodePtr, IOIndex_t> input(const IOIndex_t inID) const { assert((inID != gk_IODefaultIndex) && (inID < nbInputs()) && "Input index out of bound."); @@ -178,19 +188,19 @@ public: /** * @brief List input ids of children liked to outputs of the node - * @return std::vector<std::vector<std::pair<NodePtr, + * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * IOIndex_t>>> */ std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; /** - * @brief Children and their input Tensor ID linked to the outID-th output + * @brief Children and their input Tensor ID linked to the outId-th output * Tensor. - * @param outID - * @return std::vector<std::pair<NodePtr, IOIndex_t>> + * @param outId + * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> - output(IOIndex_t outID) const; + output(IOIndex_t outId) const; /** * @brief Number of inputs, including both data and learnable parameters. @@ -231,7 +241,11 @@ public: * @return std::vector<GraphView> */ inline std::set<std::shared_ptr<GraphView>> views() const noexcept { - return mViews; + std::set<std::shared_ptr<GraphView>> res; + for (const auto &v : mViews) { + res.insert(v.lock()); + } + return res; } /** @@ -239,14 +253,14 @@ public: * the current Node. This feature allows transparent GraphViews. * @param graphPtr Pointer to GraphView to add to the list. */ - inline void addView(const std::shared_ptr<GraphView> graphPtr) { - mViews.insert(graphPtr); + inline void addView(const std::shared_ptr<GraphView> &graphPtr) { + mViews.insert(std::weak_ptr<GraphView>(graphPtr)); } - inline void removeView(const std::shared_ptr<GraphView> graphPtr) { - if (mViews.find(graphPtr) != mViews.end()) { - mViews.erase(graphPtr); - } + inline void removeView(const std::shared_ptr<GraphView> &graphPtr) { + std::set<std::weak_ptr<GraphView>, weakCompare>::const_iterator viewIt = mViews.cbegin(); + for (; (viewIt != mViews.cend()) && ((*viewIt).lock() != graphPtr) ; ++viewIt) {} + mViews.erase(*viewIt); } /** @@ -280,14 +294,14 @@ public: /** * @brief Get the list of parent Nodes. As an input is linked to a unique Node, * if none is linked then the parent is a nullptr. - * @return std::vector<NodePtr> + * @return std::vector<std::shared_ptr<Node>> */ std::vector<NodePtr> getParents() const; /** * @brief Get the pointer to parent of the specified input index. This pointer is nullptr if no parent is linked. * @param inId Input index. - * @return NodePtr& + * @return std::shared_ptr<Node>& */ inline NodePtr &getParents(const IOIndex_t inId) { assert(inId != gk_IODefaultIndex); @@ -298,7 +312,7 @@ public: * @brief Unlink the parent Node at the specified input index and return its pointer. * Return a nullptr is no parent was linked. * @param inId Input index. - * @return NodePtr + * @return std::shared_ptr<Node> */ NodePtr popParent(const IOIndex_t inId); @@ -308,7 +322,7 @@ public: * @brief Get the set of pointers to children Nodes linked to the current Node.object. * @details The returned set does not include any nullptr as an output maybe linked to * an undifined number of Nodes. It does not change the computation of its associated Operator. - * @return std::set<NodePtr>> + * @return std::set<std::shared_ptr<Node>>> */ std::set<NodePtr> getChildren() const; @@ -317,14 +331,14 @@ public: /** * @brief Get the list of children Nodes linked to the output at specified index. * @param outId Output index. - * @return std::vector<NodePtr> + * @return std::vector<std::shared_ptr<Node>> */ - std::vector<NodePtr> getChildren(const IOIndex_t outID) const; + std::vector<NodePtr> getChildren(const IOIndex_t outId) const; /** * @brief Remove registered child from children list of specified output if possible. * If so, also remove current Node from child Node from parent. - * @param nodePtr Node to remove. + * @param std::shared_ptr<Node> Node to remove. * @param outId Output index. Default 0. * @return true Child found and removed for given output index. * @return false Child not found at given index. Nothing removed. @@ -388,4 +402,4 @@ private: }; } // namespace Aidge -#endif /* __AIDGE_CORE_GRAPH_NODE_H__ */ \ No newline at end of file +#endif /* __AIDGE_CORE_GRAPH_NODE_H__ */ diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 254d62c6bdff89dd28079245adf0b2559cca66f8..86b96bfaa8bf0eb5ab52fa542f169708ff8d09ca 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -79,6 +79,7 @@ class GenericOperator_Op mParams.Add<T>(key, value); } + std::string getParameterType(std::string const &key) { return mParams.getParamType(key); } std::vector<std::string> getParametersName() { return mParams.getParametersName(); } @@ -88,7 +89,7 @@ class GenericOperator_Op printf("Info: using associateInput() on a GenericOperator.\n"); } - void computeOutputDims() override final { + void computeOutputDims() override final { assert(false && "Cannot compute output dim of a GenericOperator"); } @@ -115,7 +116,7 @@ class GenericOperator_Op printf("Info: using getInput() on a GenericOperator.\n"); return mInputs[inputIdx]; } - inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { assert((outputIdx < mNbOut) && "output index out of range for this instance of GenericOperator"); printf("Info: using getOutput() on a GenericOperator.\n"); return mOutputs[outputIdx]; diff --git a/include/aidge/utils/CParameter.hpp b/include/aidge/utils/CParameter.hpp index c7d0ea23d6899b51f76ccdcfcfd8db9de6607165..64943ff58eae9a06fe50afb1b81deea1b66e90ea 100644 --- a/include/aidge/utils/CParameter.hpp +++ b/include/aidge/utils/CParameter.hpp @@ -15,7 +15,6 @@ #include <assert.h> #include <map> #include <vector> -#include <numeric> namespace Aidge { @@ -23,6 +22,13 @@ namespace Aidge { ///\todo managing complex types or excluding non-trivial, non-aggregate types class CParameter { +private: + template <typename T> + struct is_vector : std::false_type {}; + + template <typename T, typename Alloc> + struct is_vector<std::vector<T, Alloc>> : std::true_type {}; + public: // not copyable, not movable CParameter(CParameter const &) = delete; @@ -30,6 +36,7 @@ public: CParameter &operator=(CParameter const &) = delete; CParameter &operator=(CParameter &&) = delete; CParameter() : m_Params({}){}; + ~CParameter() = default; /** * \brief Returning a parameter identified by its name @@ -41,7 +48,7 @@ public: * param buffer that will get invalid after the CParam death. * \note at() throws if the parameter does not exist, using find to test for parameter existance */ - template<class T> T Get(std::string const &i_ParamName) const + template<class T> T Get(std::string const i_ParamName) const { assert(m_Params.find(i_ParamName) != m_Params.end()); assert(m_Types.find(i_ParamName) != m_Types.end()); @@ -65,9 +72,11 @@ public: = i_Value; // Black-magic used to add anytype into the vector m_Params[i_ParamName] = m_OffSet; // Copy pointer offset m_OffSet += sizeof(T); // Increment offset + m_Types[i_ParamName] = typeid(i_Value).name(); } - + + std::string getParamType(std::string const &i_ParamName){ return m_Types[i_ParamName]; } @@ -79,18 +88,14 @@ public: return parametersName; } - - ~CParameter() = default; - private: - // Note for Cyril: of course storing offset and not address! Good idea std::map<std::string, std::size_t> m_Params; // { Param name : offset } ///\brief Map to check type error - /* Note : i tried this : `std::map<std::string, std::type_info const *> m_Types;` + /* Note : i tried this : `std::map<std::string, std::type_info const *> mTypes;` but looks like the type_ingo object was destroyed. I am not a hugde fan of storing a string and making string comparison. - Maybe we can use a custom enum type (or is there a standard solution ?) + Maybe we can use a custom enum type (or is there a standard solution ?) */ std::map<std::string, std::string> m_Types; @@ -102,9 +107,9 @@ private: ///\brief Offset, in number of uint8_t, of the next parameter to write std::size_t m_OffSet = 0; + }; } - #endif /* __AIDGE_CPARAMETER_H__ */ diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp index 550f040210c219b2956313e0c3ea1cdce63429aa..82eaa0062b7db0e57da3d78d56e503e3a4beb19f 100644 --- a/python_binding/operator/pybind_FC.cpp +++ b/python_binding/operator/pybind_FC.cpp @@ -21,7 +21,7 @@ namespace py = pybind11; namespace Aidge { void declare_FC(py::module &m) { - py::class_<FC_Op, std::shared_ptr<FC_Op>, Operator>(m, "FC_Op", py::multiple_inheritance()); + py::class_<FC_Op, std::shared_ptr<FC_Op>, Operator, PyAbstractParametrizable>(m, "FC_Op", py::multiple_inheritance()); m.def("FC", &FC, py::arg("out_channels"), py::arg("nobias") = false, py::arg("name") = nullptr); } diff --git a/python_binding/operator/pybind_Matmul.cpp b/python_binding/operator/pybind_Matmul.cpp index a26caeba326aba28490c15aa7264f4895640c3b3..c81845ca5e5ba3674356d16db660f4e3550e9004 100644 --- a/python_binding/operator/pybind_Matmul.cpp +++ b/python_binding/operator/pybind_Matmul.cpp @@ -21,7 +21,7 @@ namespace py = pybind11; namespace Aidge { void declare_Matmul(py::module &m) { - py::class_<Matmul_Op, std::shared_ptr<Matmul_Op>, Operator>(m, "Matmul_Op", py::multiple_inheritance()); + py::class_<Matmul_Op, std::shared_ptr<Matmul_Op>, Operator, PyAbstractParametrizable>(m, "Matmul_Op", py::multiple_inheritance()); m.def("Matmul", &Matmul, py::arg("out_channels"), py::arg("name") = nullptr); } diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 5568e4b599195f50450bcb715c6e03e034c1ceb2..286ed7136a369e63f567b35135f89afcc266e0e1 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -21,8 +21,8 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name) : mName((name == nullptr) ? std::string() : std::string(name)), mOperator(op), mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), - mChildren(std::vector<std::vector<std::shared_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), - std::vector<std::shared_ptr<Node>>())), + mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), + std::vector<std::weak_ptr<Node>>())), mIdInChildren( std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { @@ -33,7 +33,7 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name) // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// -Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> ctors) { +Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { assert((ctors.size() == nbDataInputs()) && "Wrong number of arguments.\n"); for (__attribute__((unused)) std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) { assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); @@ -134,12 +134,12 @@ Aidge::Node::outputs() const { } std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::output(Aidge::IOIndex_t outID) const { +Aidge::Node::output(Aidge::IOIndex_t outId) const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outID].size()); - for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) { + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size()); + for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) { listOutputs[i] = - std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outID][i], mIdInChildren[outID][i]); + std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]); } return listOutputs; } @@ -161,7 +161,7 @@ Aidge::IOIndex_t Aidge::Node::nbValidOutputs() const { return counter; } -void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeOutID) { +void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) { assert(inId != gk_IODefaultIndex && (inId < nbInputs()) && "Must be a valid index"); if (mIdOutParents[inId] != gk_IODefaultIndex) { std::printf("Warning: filling a Tensor already attributed\n"); @@ -171,7 +171,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeOutID) // find first occurence of child in the output's children originalParent.first->removeChild(shared_from_this(), originalParent.second); } - mIdOutParents[inId] = newNodeOutID; + mIdOutParents[inId] = newNodeoutId; } /////////////////////////////////////////////////////// @@ -179,9 +179,8 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeOutID) /////////////////////////////////////////////////////// void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) { - assert((otherInId != gk_IODefaultIndex) && (otherInId < otherNode->nbInputs()) && - "Input index out of bound."); - assert((outId != gk_IODefaultIndex) && (outId < nbOutputs()) && "Output index out of bound."); + assert((otherInId < otherNode->nbInputs()) && "Input index out of bound."); + assert((outId < nbOutputs()) && "Output index out of bound."); if (otherNode->input(otherInId).second != gk_IODefaultIndex) { std::printf("Warning, the %d-th Parent of the child node already existed.\n", otherInId); } @@ -189,24 +188,22 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou otherNode->setInputId(otherInId, outId); otherNode->getOperator()->associateInput(otherInId, getOperator()->getRawOutput(outId)); // manage nodes - mChildren[outId].push_back(otherNode); + mChildren[outId].push_back(std::weak_ptr<Node>(otherNode)); mIdInChildren[outId].push_back(otherInId); otherNode->addParent(shared_from_this(), otherInId); } -void Aidge::Node::addChildView(std::shared_ptr<GraphView> other_graph, const IOIndex_t outID, +void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId, std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { - assert((otherInId.second != gk_IODefaultIndex) && - (otherInId.second < otherInId.first->nbInputs()) && - "Other graph input index out of bound."); - assert((outID != gk_IODefaultIndex) && (outID < nbOutputs()) && "Output index out of bound."); - std::set<std::shared_ptr<Node>> inNodes = other_graph->inputNodes(); + assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound."); + assert((outId < nbOutputs()) && "Output index out of bound."); + std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes(); if (inNodes.size() == std::size_t(0)) { // no input Node printf("Cannot add GraphView to the Node. No input node detected.\n"); } else // inNodes.size() >= 1 { assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node - addChildOp(otherInId.first, outID, otherInId.second); + addChildOp(otherInId.first, outId, otherInId.second); } } @@ -256,24 +253,36 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) { std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { std::set<std::shared_ptr<Node>> children; - for (const std::vector<std::shared_ptr<Node>> &childrenOfOneOutput : mChildren) { - children.insert(childrenOfOneOutput.begin(), childrenOfOneOutput.end()); + for (const auto &childrenOfOneOutput : mChildren) { + for (const auto &oneChild : childrenOfOneOutput) { + children.insert(oneChild.lock()); + } } return children; } -std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { return mChildren; } +std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { + std::vector<std::vector<std::shared_ptr<Node>>> children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); + for (std::size_t outId = 0; outId < mChildren.size(); ++outId) { + children[outId] = getChildren(outId); + } + return children; +} -std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outID) const { - assert((outID != gk_IODefaultIndex) && (outID < nbOutputs()) && "Output index out of bound."); - return mChildren[outID]; +std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const { + assert((outId < nbOutputs()) && "Output index out of bound."); + std::vector<std::shared_ptr<Node>> children = std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); + for (std::size_t i = 0; i < mChildren[outId].size(); ++i) { + children.push_back(mChildren[outId][i].lock()); + } + return children; } bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) { - assert((outId != gk_IODefaultIndex) && (outId < nbOutputs()) && "Child index out of bound."); + assert((outId < nbOutputs()) && "Child index out of bound."); bool removed = false; for (std::size_t j = 0; j < mChildren[outId].size(); ++j) { - if (mChildren[outId][j] == nodePtr) { + if (mChildren[outId][j].lock() == nodePtr) { mChildren[outId].erase(mChildren[outId].begin() + j); mIdInChildren[outId].erase(mIdInChildren[outId].begin() + j); removed = true; @@ -301,7 +310,7 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) { child.first->removeParent(child.second); } - mChildren[i] = std::vector<std::shared_ptr<Node>>(); + mChildren[i] = std::vector<std::weak_ptr<Node>>(); mIdInChildren[i] = std::vector<IOIndex_t>(); } // removing this Node from every GraphView it belongs to diff --git a/src/graph/OpArgs.cpp b/src/graph/OpArgs.cpp index 52018e1cddca6ab5b8d709f4bc7b64bf067ecad2..3994a111d0881268d8768b2cb5843df65f7b4d17 100644 --- a/src/graph/OpArgs.cpp +++ b/src/graph/OpArgs.cpp @@ -22,10 +22,10 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::initializer_list<OpArgs assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size()); /* * /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted. - * Prefer a functional descrition for detailed inputs + * Prefer a functional description for detailed inputs */ for (const std::shared_ptr<Node>& node_ptr : gv->outputNodes()) { - node_ptr -> addChild(elt.node()); // already check that node_ptr->nbOutput == 1 + node_ptr -> addChild(elt.node()); // already checks that node_ptr->nbOutput() == 1 } gv->add(elt.node()); } diff --git a/unit_tests/operator/Test_GenericOperator.cpp b/unit_tests/operator/Test_GenericOperator.cpp index ff41ed468e5b84bf3455c25b327a91730967d3c6..886326214a4a285fb32e5909da5114d74782ee46 100644 --- a/unit_tests/operator/Test_GenericOperator.cpp +++ b/unit_tests/operator/Test_GenericOperator.cpp @@ -22,33 +22,44 @@ TEST_CASE("[core/operators] GenericOp(add & get parameters)", "[Operator]") { GenericOperator_Op Testop("TestOp", 1, 1, 1); int value = 5; const char* key = "intParam"; - Testop.addParameter<int>(key, value); + Testop.addParameter(key, value); REQUIRE(Testop.getParameter<int>(key) == value); } SECTION("LONG") { GenericOperator_Op Testop("TestOp", 1, 1, 1); long value = 3; const char* key = "longParam"; - Testop.addParameter<long>(key, value); + Testop.addParameter(key, value); REQUIRE(Testop.getParameter<long>(key) == value); } SECTION("FLOAT") { GenericOperator_Op Testop("TestOp", 1, 1, 1); float value = 2.0; const char* key = "floatParam"; - Testop.addParameter<float>(key, value); + Testop.addParameter(key, value); REQUIRE(Testop.getParameter<float>(key) == value); + } + SECTION("VECTOR<BOOL>") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + std::vector<bool> value = {true, false, false, true, true}; + const char* key = "vect"; + Testop.addParameter(key, value); + + REQUIRE(Testop.getParameter<std::vector<bool>>(key).size() == value.size()); + for (std::size_t i=0; i < value.size(); ++i){ + REQUIRE(Testop.getParameter<std::vector<bool>>(key)[i] == value[i]); + } } SECTION("VECTOR<INT>") { GenericOperator_Op Testop("TestOp", 1, 1, 1); - std::vector<int> value = {1, 2}; + std::vector<int> value = {1, 2, 3, 4, 5, 6, 7, 8, 9}; const char* key = "vect"; - Testop.addParameter<std::vector<int>>(key, value); + Testop.addParameter(key, value); REQUIRE(Testop.getParameter<std::vector<int>>(key).size() == value.size()); for (std::size_t i=0; i < value.size(); ++i){ REQUIRE(Testop.getParameter<std::vector<int>>(key)[i] == value[i]); - } + } } SECTION("MULTIPLE PARAMS") { /*