Commit 759906bf authored by Maxence Naud

Update python binding and tests

parent 585f986f
@@ -37,15 +37,15 @@ class ExportNode(ABC):
         for idx, parent_node in enumerate(self.node.get_parents()):
             self.inputs.append(parent_node)
             if parent_node is not None:
-                self.inputs_dims.append(self.operator.input(idx).dims())
+                self.inputs_dims.append(self.operator.get_input(idx).dims())
             else:
                 self.inputs_dims.append(None)
 
         for idx, child_node in enumerate(self.node.get_children()):
             self.outputs.append(child_node)
             # Dirty hot fix, change it quickly
-            self.outputs_dims.append(self.operator.output(0).dims())
+            self.outputs_dims.append(self.operator.get_output(0).dims())
 
     @abstractmethod
     def export(self, export_folder:str, list_configs:list):
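For reference, a minimal self-contained sketch of the renamed accessors used above, built only from the updated tests (node names and dimensions are illustrative, not part of the commit): tensors are now retrieved with get_input()/get_output() instead of input()/output().

import aidge_core

# A Producer feeding a GenericOperator with the same arity as in the updated tests.
inp = aidge_core.Producer([25, 25], name="In")
node = aidge_core.GenericOperator("genOp", 1, 0, 1, name="genOp")
_ = aidge_core.sequential([inp, node])

op = node.get_operator()
# Same pattern as ExportNode above: only query dims for connected parents.
for idx, parent in enumerate(node.get_parents()):
    if parent is not None:
        print(op.get_input(idx).dims())   # e.g. [25, 25] coming from the Producer
    else:
        print(None)                       # unconnected input: no tensor to inspect
print(op.get_output(0).dims())            # [] until output dims are computed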
@@ -16,14 +16,14 @@ class test_operator_binding(unittest.TestCase):
     Can be remove in later stage of the developpement.
     """
     def setUp(self):
-        self.generic_operator = aidge_core.GenericOperator("FakeConv", 1, 1, 1).get_operator()
+        self.generic_operator = aidge_core.GenericOperator("FakeConv", 1, 0, 1).get_operator()
 
     def tearDown(self):
         pass
 
     def test_default_name(self):
         op_type = "Conv"
-        gop = aidge_core.GenericOperator(op_type, 1, 1, 1, "FictiveName")
+        gop = aidge_core.GenericOperator(op_type, 1, 0, 1, "FictiveName")
         # check node name is not operator type
         self.assertNotEqual(gop.name(), "Conv")
         # check node name is not default
@@ -95,12 +95,12 @@ class test_operator_binding(unittest.TestCase):
     def test_compute_output_dims(self):
         in_dims=[25, 25]
         input = aidge_core.Producer(in_dims, name="In")
-        genOp = aidge_core.GenericOperator("genOp", 1, 1, 1, name="genOp")
+        genOp = aidge_core.GenericOperator("genOp", 1, 0, 1, name="genOp")
         _ = aidge_core.sequential([input, genOp])
-        self.assertListEqual(genOp.get_operator().output(0).dims(), [])
+        self.assertListEqual(genOp.get_operator().get_output(0).dims(), [])
         genOp.get_operator().set_compute_output_dims(lambda x:x)
         genOp.get_operator().compute_output_dims()
-        self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims)
+        self.assertListEqual(genOp.get_operator().get_output(0).dims(), in_dims)
 
     def test_set_impl(self):
@@ -116,7 +116,7 @@ class test_operator_binding(unittest.TestCase):
         """
         self.idx += 1
 
-        generic_node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu")
+        generic_node = aidge_core.GenericOperator("Relu", 1, 0, 1, name="myReLu")
         generic_op = generic_node.get_operator()
         customImpl = PythonCustomImpl(generic_op)
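The same test file exercises the dimension-propagation hook with the new accessor names. A hedged usage sketch of that flow (the identity lambda simply forwards the dims it receives; everything else follows test_compute_output_dims above):

import aidge_core

in_dims = [25, 25]
inp = aidge_core.Producer(in_dims, name="In")
gen = aidge_core.GenericOperator("genOp", 1, 0, 1, name="genOp")
_ = aidge_core.sequential([inp, gen])

op = gen.get_operator()
assert op.get_output(0).dims() == []           # nothing computed yet
op.set_compute_output_dims(lambda dims: dims)  # user-provided propagation rule
op.compute_output_dims()
assert op.get_output(0).dims() == in_dims      # output dims now mirror the input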
@@ -32,15 +32,17 @@ class test_attributes(unittest.TestCase):
         self.assertEqual(conv_op.get_attr("KernelDims"), k_dims)
 
     def test_fc(self):
+        in_channels = 4
         out_channels = 8
         nb_bias = True
-        fc_op = aidge_core.FC(out_channels, nb_bias).get_operator()
+        fc_op = aidge_core.FC(in_channels, out_channels, nb_bias).get_operator()
         self.assertEqual(fc_op.get_attr("OutChannels"), out_channels)
         self.assertEqual(fc_op.get_attr("NoBias"), nb_bias)
 
     def test_matmul(self):
+        in_channels = 4
         out_channels = 8
-        matmul_op = aidge_core.MatMul(out_channels).get_operator()
+        matmul_op = aidge_core.MatMul(in_channels, out_channels).get_operator()
         self.assertEqual(matmul_op.get_attr("OutChannels"), out_channels)
 
     def test_producer_1D(self):
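A short sketch of the updated constructor signatures, assuming the argument order shown in the tests above (in_channels first, then out_channels, plus the bias flag for FC; the values are arbitrary):

import aidge_core

fc_op = aidge_core.FC(4, 8, True).get_operator()     # in_channels=4, out_channels=8, NoBias=True
print(fc_op.get_attr("OutChannels"))                 # 8
print(fc_op.get_attr("NoBias"))                      # True

matmul_op = aidge_core.MatMul(4, 8).get_operator()   # in_channels=4, out_channels=8
print(matmul_op.get_attr("OutChannels"))             # 8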
@@ -22,8 +22,8 @@ class test_recipies(unittest.TestCase):
     def test_remove_flatten(self):
         graph_view = aidge_core.sequential([
-            aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"),
-            aidge_core.FC(50, name='0')
+            aidge_core.GenericOperator("Flatten", 1, 0, 1, name="Flatten0"),
+            aidge_core.FC(10, 50, name='0')
         ])
         old_nodes = graph_view.get_nodes()
         aidge_core.remove_flatten(graph_view)
 
@@ -33,9 +33,9 @@ class test_recipies(unittest.TestCase):
         self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
 
     def test_fuse_matmul_add(self):
-        matmul0 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul0")
+        matmul0 = aidge_core.MatMul(1, 1, name="MatMul0")
         add0 = aidge_core.Add(2, name="Add0")
-        matmul1 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul1")
+        matmul1 = aidge_core.MatMul(1, 1, name="MatMul1")
         add1 = aidge_core.Add(2, name="Add1")
         graph_view = aidge_core.sequential([matmul0, add0, matmul1, add1])
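With the new signatures, the recipe tests build their graphs from real FC/MatMul nodes rather than GenericOperator stand-ins. As a usage sketch, the remove_flatten recipe from the first recipe hunk can be run end to end like this (node names are illustrative):

import aidge_core

graph_view = aidge_core.sequential([
    aidge_core.GenericOperator("Flatten", 1, 0, 1, name="Flatten0"),
    aidge_core.FC(10, 50, name="0"),
])
old_nodes = graph_view.get_nodes()
aidge_core.remove_flatten(graph_view)
# The recipe only removes nodes: everything left was already in the graph.
assert all(n in old_nodes for n in graph_view.get_nodes())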
@@ -20,10 +20,8 @@ namespace Aidge {
 void init_Operator(py::module& m){
     py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator")
     .def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
-    // .def("set_output", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
     .def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx"))
     .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
-    // .def("set_input", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
     .def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx"))
     .def("nb_inputs", &Operator::nbInputs)
     .def("nb_data", &Operator::nbData)
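Seen from Python, the bindings kept here expose the raw, untyped side of an operator; the commented-out rvalue overloads are dropped, presumably because Python callers cannot take advantage of a move overload anyway. A hedged sketch (it assumes the concrete operator classes declare Operator as a pybind base, which is the hierarchy this commit wires up):

import aidge_core

op = aidge_core.GenericOperator("FakeConv", 1, 0, 1).get_operator()
print(op.nb_inputs())        # total number of input slots
print(op.nb_data())          # number of data (non-parameter) inputs
raw = op.get_raw_output(0)   # untyped Data handle; the typed view lives on OperatorTensor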
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Operator.hpp"
#include <pybind11/stl.h>
namespace py = pybind11;
namespace Aidge {
void init_OperatorTensor(py::module& m){
py::class_<OperatorTensor, std::shared_ptr<OperatorTensor>, Operator>(m, "OperatorTensor")
.def("get_output", &OperatorTensor::getOutput, py::arg("outputIdx"))
.def("get_input", &OperatorTensor::getInput, py::arg("inputIdx"))
;
}
}
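The new OperatorTensor binding adds the tensor-typed view on top of the raw Operator one: get_input()/get_output() return Tensor objects whose dims() can be queried directly, which is what the updated ExportNode and tests rely on. A minimal sketch, assuming a graph wired as in the tests:

import aidge_core

inp = aidge_core.Producer([25, 25], name="In")
gen = aidge_core.GenericOperator("genOp", 1, 0, 1, name="genOp")
_ = aidge_core.sequential([inp, gen])

op = gen.get_operator()
print(op.get_input(0).dims())    # typed access: a Tensor with dims [25, 25]
print(op.get_raw_input(0))       # raw access from the Operator base: an untyped Data handle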
@@ -19,6 +19,7 @@ void init_Tensor(py::module&);
 void init_OperatorImpl(py::module&);
 void init_Attributes(py::module&);
 void init_Operator(py::module&);
+void init_OperatorTensor(py::module&);
 void init_Add(py::module&);
 void init_AvgPooling(py::module&);
@@ -65,6 +66,7 @@ void init_Aidge(py::module& m){
     init_OperatorImpl(m);
     init_Attributes(m);
     init_Operator(m);
+    init_OperatorTensor(m);
     init_Add(m);
     init_AvgPooling(m);
     init_BatchNorm(m);
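Ordering matters here: init_Operator(m) has to run before init_OperatorTensor(m), since pybind11 requires a base class (Operator) to be registered before any class that lists it as a base (OperatorTensor). Once both are registered, the classes are plain attributes of the module; a trivial check:

import aidge_core

assert hasattr(aidge_core, "Operator")
assert hasattr(aidge_core, "OperatorTensor")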