Skip to content
Snippets Groups Projects
Commit 1fb848bd authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix python tests

parent f8484175
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!88Basic supervised learning
...@@ -18,7 +18,7 @@ GLOBAL_CPT = 0 ...@@ -18,7 +18,7 @@ GLOBAL_CPT = 0
class testImpl(aidge_core.OperatorImpl): class testImpl(aidge_core.OperatorImpl):
def __init__(self, op: aidge_core.Operator): def __init__(self, op: aidge_core.Operator):
aidge_core.OperatorImpl.__init__(self, op) # Required to avoid type error ! aidge_core.OperatorImpl.__init__(self, op, 'cpu') # Required to avoid type error !
def forward(self): def forward(self):
global GLOBAL_CPT global GLOBAL_CPT
......
...@@ -108,7 +108,7 @@ class test_operator_binding(unittest.TestCase): ...@@ -108,7 +108,7 @@ class test_operator_binding(unittest.TestCase):
"""Dummy implementation to test that C++ call python code """Dummy implementation to test that C++ call python code
""" """
def __init__(self, op: aidge_core.Operator): def __init__(self, op: aidge_core.Operator):
aidge_core.OperatorImpl.__init__(self, op) # Recquired to avoid type error ! aidge_core.OperatorImpl.__init__(self, op, 'test_impl') # Recquired to avoid type error !
self.idx = 0 self.idx = 0
def forward(self): def forward(self):
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
#ifndef AIDGE_CORE_OPERATOR_POW_H_ #ifndef AIDGE_CORE_OPERATOR_POW_H_
#define AIDGE_CORE_OPERATOR_POW_H_ #define AIDGE_CORE_OPERATOR_POW_H_
#include <cassert>
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
......
...@@ -117,7 +117,7 @@ public: ...@@ -117,7 +117,7 @@ public:
void init_OperatorImpl(py::module& m){ void init_OperatorImpl(py::module& m){
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>()) .def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>())
.def("forward", &OperatorImpl::forward) .def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward) .def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData) .def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
......
...@@ -46,24 +46,26 @@ Aidge::Producer_Op::Producer_Op(const Aidge::Producer_Op& op) ...@@ -46,24 +46,26 @@ Aidge::Producer_Op::Producer_Op(const Aidge::Producer_Op& op)
: OperatorTensor(op), : OperatorTensor(op),
Attributes_(op) Attributes_(op)
{ {
mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0))); if (op.mImpl){
if (mOutputs[0]->hasImpl()) { SET_IMPL_MACRO(Producer_Op, *this, op.backend());
if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){
setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this));
}
else {
mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend());
}
} else { } else {
mImpl = nullptr; mImpl = nullptr;
} }
// mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0)));
// if (mOutputs[0]->hasImpl()) {
// if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){
// setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this));
// }
// else {
// mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend());
// }
// } else {
// mImpl = nullptr;
// }
} }
void Aidge::Producer_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { void Aidge::Producer_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
if (Registrar<Producer_Op>::exists(name)) { SET_IMPL_MACRO(Producer_Op, *this, name);
setImpl(Registrar<Producer_Op>::create(name)(*this));
} else {
mImpl = std::make_shared<OperatorImpl>(*this, name);
}
mOutputs[0]->setBackend(name, device); mOutputs[0]->setBackend(name, device);
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment