Skip to content
Snippets Groups Projects
Commit 1fb848bd authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix python tests

parent f8484175
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!88Basic supervised learning
......@@ -18,7 +18,7 @@ GLOBAL_CPT = 0
class testImpl(aidge_core.OperatorImpl):
def __init__(self, op: aidge_core.Operator):
aidge_core.OperatorImpl.__init__(self, op) # Required to avoid type error !
aidge_core.OperatorImpl.__init__(self, op, 'cpu') # Required to avoid type error !
def forward(self):
global GLOBAL_CPT
......
......@@ -108,7 +108,7 @@ class test_operator_binding(unittest.TestCase):
"""Dummy implementation to test that C++ call python code
"""
def __init__(self, op: aidge_core.Operator):
aidge_core.OperatorImpl.__init__(self, op) # Recquired to avoid type error !
aidge_core.OperatorImpl.__init__(self, op, 'test_impl') # Recquired to avoid type error !
self.idx = 0
def forward(self):
......
......@@ -12,8 +12,8 @@
#ifndef AIDGE_CORE_OPERATOR_POW_H_
#define AIDGE_CORE_OPERATOR_POW_H_
#include <cassert>
#include <memory>
#include <string>
#include <vector>
#include "aidge/utils/Registrar.hpp"
......
......@@ -117,7 +117,7 @@ public:
void init_OperatorImpl(py::module& m){
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>())
.def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>())
.def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
......
......@@ -46,24 +46,26 @@ Aidge::Producer_Op::Producer_Op(const Aidge::Producer_Op& op)
: OperatorTensor(op),
Attributes_(op)
{
mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0)));
if (mOutputs[0]->hasImpl()) {
if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){
setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this));
}
else {
mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend());
}
if (op.mImpl){
SET_IMPL_MACRO(Producer_Op, *this, op.backend());
} else {
mImpl = nullptr;
}
// mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0)));
// if (mOutputs[0]->hasImpl()) {
// if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){
// setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this));
// }
// else {
// mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend());
// }
// } else {
// mImpl = nullptr;
// }
}
void Aidge::Producer_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
if (Registrar<Producer_Op>::exists(name)) {
setImpl(Registrar<Producer_Op>::create(name)(*this));
} else {
mImpl = std::make_shared<OperatorImpl>(*this, name);
}
SET_IMPL_MACRO(Producer_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment