Skip to content
Snippets Groups Projects
Commit 759906bf authored by Maxence Naud's avatar Maxence Naud
Browse files

Update python binding and tests

parent 585f986f
No related branches found
No related tags found
2 merge requests!46Remove Operator reference to Tensor,!20Draft: Introduction of Tiling
...@@ -37,15 +37,15 @@ class ExportNode(ABC): ...@@ -37,15 +37,15 @@ class ExportNode(ABC):
for idx, parent_node in enumerate(self.node.get_parents()): for idx, parent_node in enumerate(self.node.get_parents()):
self.inputs.append(parent_node) self.inputs.append(parent_node)
if parent_node is not None: if parent_node is not None:
self.inputs_dims.append(self.operator.input(idx).dims()) self.inputs_dims.append(self.operator.get_input(idx).dims())
else: else:
self.inputs_dims.append(None) self.inputs_dims.append(None)
for idx, child_node in enumerate(self.node.get_children()): for idx, child_node in enumerate(self.node.get_children()):
self.outputs.append(child_node) self.outputs.append(child_node)
# Dirty hot fix, change it quickly # Dirty hot fix, change it quickly
self.outputs_dims.append(self.operator.output(0).dims()) self.outputs_dims.append(self.operator.get_output(0).dims())
@abstractmethod @abstractmethod
def export(self, export_folder:str, list_configs:list): def export(self, export_folder:str, list_configs:list):
......
...@@ -16,14 +16,14 @@ class test_operator_binding(unittest.TestCase): ...@@ -16,14 +16,14 @@ class test_operator_binding(unittest.TestCase):
Can be remove in later stage of the developpement. Can be remove in later stage of the developpement.
""" """
def setUp(self): def setUp(self):
self.generic_operator = aidge_core.GenericOperator("FakeConv", 1, 1, 1).get_operator() self.generic_operator = aidge_core.GenericOperator("FakeConv", 1, 0, 1).get_operator()
def tearDown(self): def tearDown(self):
pass pass
def test_default_name(self): def test_default_name(self):
op_type = "Conv" op_type = "Conv"
gop = aidge_core.GenericOperator(op_type, 1, 1, 1, "FictiveName") gop = aidge_core.GenericOperator(op_type, 1, 0, 1, "FictiveName")
# check node name is not operator type # check node name is not operator type
self.assertNotEqual(gop.name(), "Conv") self.assertNotEqual(gop.name(), "Conv")
# check node name is not default # check node name is not default
...@@ -95,12 +95,12 @@ class test_operator_binding(unittest.TestCase): ...@@ -95,12 +95,12 @@ class test_operator_binding(unittest.TestCase):
def test_compute_output_dims(self): def test_compute_output_dims(self):
in_dims=[25, 25] in_dims=[25, 25]
input = aidge_core.Producer(in_dims, name="In") input = aidge_core.Producer(in_dims, name="In")
genOp = aidge_core.GenericOperator("genOp", 1, 1, 1, name="genOp") genOp = aidge_core.GenericOperator("genOp", 1, 0, 1, name="genOp")
_ = aidge_core.sequential([input, genOp]) _ = aidge_core.sequential([input, genOp])
self.assertListEqual(genOp.get_operator().output(0).dims(), []) self.assertListEqual(genOp.get_operator().get_output(0).dims(), [])
genOp.get_operator().set_compute_output_dims(lambda x:x) genOp.get_operator().set_compute_output_dims(lambda x:x)
genOp.get_operator().compute_output_dims() genOp.get_operator().compute_output_dims()
self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims) self.assertListEqual(genOp.get_operator().get_output(0).dims(), in_dims)
def test_set_impl(self): def test_set_impl(self):
...@@ -116,7 +116,7 @@ class test_operator_binding(unittest.TestCase): ...@@ -116,7 +116,7 @@ class test_operator_binding(unittest.TestCase):
""" """
self.idx += 1 self.idx += 1
generic_node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu") generic_node = aidge_core.GenericOperator("Relu", 1, 0, 1, name="myReLu")
generic_op = generic_node.get_operator() generic_op = generic_node.get_operator()
customImpl = PythonCustomImpl(generic_op) customImpl = PythonCustomImpl(generic_op)
......
...@@ -32,15 +32,17 @@ class test_attributes(unittest.TestCase): ...@@ -32,15 +32,17 @@ class test_attributes(unittest.TestCase):
self.assertEqual(conv_op.get_attr("KernelDims"), k_dims) self.assertEqual(conv_op.get_attr("KernelDims"), k_dims)
def test_fc(self): def test_fc(self):
in_channels = 4
out_channels = 8 out_channels = 8
nb_bias = True nb_bias = True
fc_op = aidge_core.FC(out_channels, nb_bias).get_operator() fc_op = aidge_core.FC(in_channels, out_channels, nb_bias).get_operator()
self.assertEqual(fc_op.get_attr("OutChannels"), out_channels) self.assertEqual(fc_op.get_attr("OutChannels"), out_channels)
self.assertEqual(fc_op.get_attr("NoBias"), nb_bias) self.assertEqual(fc_op.get_attr("NoBias"), nb_bias)
def test_matmul(self): def test_matmul(self):
in_channels = 4
out_channels = 8 out_channels = 8
matmul_op = aidge_core.MatMul(out_channels).get_operator() matmul_op = aidge_core.MatMul(in_channels, out_channels).get_operator()
self.assertEqual(matmul_op.get_attr("OutChannels"), out_channels) self.assertEqual(matmul_op.get_attr("OutChannels"), out_channels)
def test_producer_1D(self): def test_producer_1D(self):
......
...@@ -22,8 +22,8 @@ class test_recipies(unittest.TestCase): ...@@ -22,8 +22,8 @@ class test_recipies(unittest.TestCase):
def test_remove_flatten(self): def test_remove_flatten(self):
graph_view = aidge_core.sequential([ graph_view = aidge_core.sequential([
aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"), aidge_core.GenericOperator("Flatten", 1, 0, 1, name="Flatten0"),
aidge_core.FC(50, name='0') aidge_core.FC(10, 50, name='0')
]) ])
old_nodes = graph_view.get_nodes() old_nodes = graph_view.get_nodes()
aidge_core.remove_flatten(graph_view) aidge_core.remove_flatten(graph_view)
...@@ -33,9 +33,9 @@ class test_recipies(unittest.TestCase): ...@@ -33,9 +33,9 @@ class test_recipies(unittest.TestCase):
self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()])) self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
def test_fuse_matmul_add(self): def test_fuse_matmul_add(self):
matmul0 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul0") matmul0 = aidge_core.MatMul(1, 1, name="MatMul0")
add0 = aidge_core.Add(2, name="Add0") add0 = aidge_core.Add(2, name="Add0")
matmul1 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul1") matmul1 = aidge_core.MatMul(1, 1, name="MatMul1")
add1 = aidge_core.Add(2, name="Add1") add1 = aidge_core.Add(2, name="Add1")
graph_view = aidge_core.sequential([matmul0, add0, matmul1, add1]) graph_view = aidge_core.sequential([matmul0, add0, matmul1, add1])
......
...@@ -20,10 +20,8 @@ namespace Aidge { ...@@ -20,10 +20,8 @@ namespace Aidge {
void init_Operator(py::module& m){ void init_Operator(py::module& m){
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator")
.def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data")) .def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
// .def("set_output", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
.def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx")) .def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx"))
.def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
// .def("set_input", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
.def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx")) .def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx"))
.def("nb_inputs", &Operator::nbInputs) .def("nb_inputs", &Operator::nbInputs)
.def("nb_data", &Operator::nbData) .def("nb_data", &Operator::nbData)
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Operator.hpp"
#include <pybind11/stl.h>
namespace py = pybind11;
namespace Aidge {
void init_OperatorTensor(py::module& m){
py::class_<OperatorTensor, std::shared_ptr<OperatorTensor>, Operator>(m, "OperatorTensor")
.def("get_output", &OperatorTensor::getOutput, py::arg("outputIdx"))
.def("get_input", &OperatorTensor::getInput, py::arg("inputIdx"))
;
}
}
...@@ -19,6 +19,7 @@ void init_Tensor(py::module&); ...@@ -19,6 +19,7 @@ void init_Tensor(py::module&);
void init_OperatorImpl(py::module&); void init_OperatorImpl(py::module&);
void init_Attributes(py::module&); void init_Attributes(py::module&);
void init_Operator(py::module&); void init_Operator(py::module&);
void init_OperatorTensor(py::module&);
void init_Add(py::module&); void init_Add(py::module&);
void init_AvgPooling(py::module&); void init_AvgPooling(py::module&);
...@@ -65,6 +66,7 @@ void init_Aidge(py::module& m){ ...@@ -65,6 +66,7 @@ void init_Aidge(py::module& m){
init_OperatorImpl(m); init_OperatorImpl(m);
init_Attributes(m); init_Attributes(m);
init_Operator(m); init_Operator(m);
init_OperatorTensor(m);
init_Add(m); init_Add(m);
init_AvgPooling(m); init_AvgPooling(m);
init_BatchNorm(m); init_BatchNorm(m);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment