From 567928aac19490de14ff40c24395615c66bfb50b Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 22 Mar 2024 11:06:08 +0000 Subject: [PATCH] Fix python tests --- aidge_core/unit_tests/test_impl.py | 2 +- .../unit_tests/test_operator_binding.py | 2 +- include/aidge/operator/Pow.hpp | 2 +- .../backend/pybind_OperatorImpl.cpp | 2 +- src/operator/Producer.cpp | 28 ++++++++++--------- 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/aidge_core/unit_tests/test_impl.py b/aidge_core/unit_tests/test_impl.py index ad7ee666e..4aacfafd7 100644 --- a/aidge_core/unit_tests/test_impl.py +++ b/aidge_core/unit_tests/test_impl.py @@ -18,7 +18,7 @@ GLOBAL_CPT = 0 class testImpl(aidge_core.OperatorImpl): def __init__(self, op: aidge_core.Operator): - aidge_core.OperatorImpl.__init__(self, op) # Required to avoid type error ! + aidge_core.OperatorImpl.__init__(self, op, 'cpu') # Required to avoid type error ! def forward(self): global GLOBAL_CPT diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index c541ae0e0..c94960733 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -108,7 +108,7 @@ class test_operator_binding(unittest.TestCase): """Dummy implementation to test that C++ call python code """ def __init__(self, op: aidge_core.Operator): - aidge_core.OperatorImpl.__init__(self, op) # Recquired to avoid type error ! + aidge_core.OperatorImpl.__init__(self, op, 'test_impl') # Recquired to avoid type error ! self.idx = 0 def forward(self): diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index aadbf92c4..f2becdc60 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -12,8 +12,8 @@ #ifndef AIDGE_CORE_OPERATOR_POW_H_ #define AIDGE_CORE_OPERATOR_POW_H_ -#include <cassert> #include <memory> +#include <string> #include <vector> #include "aidge/utils/Registrar.hpp" diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 4574bd8f8..97cf81717 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -117,7 +117,7 @@ public: void init_OperatorImpl(py::module& m){ py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) - .def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>()) + .def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>()) .def("forward", &OperatorImpl::forward) .def("backward", &OperatorImpl::backward) .def("get_nb_required_data", &OperatorImpl::getNbRequiredData) diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp index 4a63b207c..990809b6c 100644 --- a/src/operator/Producer.cpp +++ b/src/operator/Producer.cpp @@ -46,24 +46,26 @@ Aidge::Producer_Op::Producer_Op(const Aidge::Producer_Op& op) : OperatorTensor(op), Attributes_(op) { - mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0))); - if (mOutputs[0]->hasImpl()) { - if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){ - setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this)); - } - else { - mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend()); - } + if (op.mImpl){ + SET_IMPL_MACRO(Producer_Op, *this, op.backend()); } else { mImpl = nullptr; } + // mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0))); + // if (mOutputs[0]->hasImpl()) { + // if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){ + // setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this)); + // } + // else { + // mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend()); + // } + + // } else { + // mImpl = nullptr; + // } } void Aidge::Producer_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { - if (Registrar<Producer_Op>::exists(name)) { - setImpl(Registrar<Producer_Op>::create(name)(*this)); - } else { - mImpl = std::make_shared<OperatorImpl>(*this, name); - } + SET_IMPL_MACRO(Producer_Op, *this, name); mOutputs[0]->setBackend(name, device); } \ No newline at end of file -- GitLab