diff --git a/aidge_core/unit_tests/test_impl.py b/aidge_core/unit_tests/test_impl.py index ad7ee666ebb56941cdc426220cd117a0e3f8b8d1..4aacfafd7d51830dc89b7b30ea5ebf521a13fe30 100644 --- a/aidge_core/unit_tests/test_impl.py +++ b/aidge_core/unit_tests/test_impl.py @@ -18,7 +18,7 @@ GLOBAL_CPT = 0 class testImpl(aidge_core.OperatorImpl): def __init__(self, op: aidge_core.Operator): - aidge_core.OperatorImpl.__init__(self, op) # Required to avoid type error ! + aidge_core.OperatorImpl.__init__(self, op, 'cpu') # Required to avoid type error ! def forward(self): global GLOBAL_CPT diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index c541ae0e03459a0a7200795bc2d3c6b70c13be3b..c94960733b24444218b1209463adbda11b89f6e8 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -108,7 +108,7 @@ class test_operator_binding(unittest.TestCase): """Dummy implementation to test that C++ call python code """ def __init__(self, op: aidge_core.Operator): - aidge_core.OperatorImpl.__init__(self, op) # Recquired to avoid type error ! + aidge_core.OperatorImpl.__init__(self, op, 'test_impl') # Recquired to avoid type error ! self.idx = 0 def forward(self): diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index aadbf92c4ba02aa69665a9994afb93fa5461a402..f2becdc60ceb44c19e341496f71e09f061cea55f 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -12,8 +12,8 @@ #ifndef AIDGE_CORE_OPERATOR_POW_H_ #define AIDGE_CORE_OPERATOR_POW_H_ -#include <cassert> #include <memory> +#include <string> #include <vector> #include "aidge/utils/Registrar.hpp" diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 4574bd8f86172798e6f72ea768dc04879c7d5e5c..97cf817176c733000eda8da6c6a213ccc22f1dc4 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -117,7 +117,7 @@ public: void init_OperatorImpl(py::module& m){ py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) - .def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>()) + .def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>()) .def("forward", &OperatorImpl::forward) .def("backward", &OperatorImpl::backward) .def("get_nb_required_data", &OperatorImpl::getNbRequiredData) diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp index 4a63b207ca417df83f50c0b94ea988ea8739048e..990809b6c7d98e3d518f31fd9630366414e629d0 100644 --- a/src/operator/Producer.cpp +++ b/src/operator/Producer.cpp @@ -46,24 +46,26 @@ Aidge::Producer_Op::Producer_Op(const Aidge::Producer_Op& op) : OperatorTensor(op), Attributes_(op) { - mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0))); - if (mOutputs[0]->hasImpl()) { - if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){ - setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this)); - } - else { - mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend()); - } + if (op.mImpl){ + SET_IMPL_MACRO(Producer_Op, *this, op.backend()); } else { mImpl = nullptr; } + // mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0))); + // if (mOutputs[0]->hasImpl()) { + // if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){ + // setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this)); + // } + // else { + // mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend()); + // } + + // } else { + // mImpl = nullptr; + // } } void Aidge::Producer_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { - if (Registrar<Producer_Op>::exists(name)) { - setImpl(Registrar<Producer_Op>::create(name)(*this)); - } else { - mImpl = std::make_shared<OperatorImpl>(*this, name); - } + SET_IMPL_MACRO(Producer_Op, *this, name); mOutputs[0]->setBackend(name, device); } \ No newline at end of file