Skip to content
Snippets Groups Projects
Commit ef861e58 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed binding

parent c7119497
No related branches found
No related tags found
1 merge request!16Unified interface for attributes
This commit is part of merge request !16. Comments created here will be created in the context of that merge request.
......@@ -27,21 +27,21 @@ class test_parameters(unittest.TestCase):
out_channels = 8
k_dims = [2, 2]
conv_op = aidge_core.Conv2D(in_channels , out_channels, k_dims).get_operator()
self.assertEqual(conv_op.get("InChannels"), in_channels)
self.assertEqual(conv_op.get("OutChannels"), out_channels)
self.assertEqual(conv_op.get("KernelDims"), k_dims)
self.assertEqual(conv_op.get_parameter("InChannels"), in_channels)
self.assertEqual(conv_op.get_parameter("OutChannels"), out_channels)
self.assertEqual(conv_op.get_parameter("KernelDims"), k_dims)
def test_fc(self):
out_channels = 8
nb_bias = True
fc_op = aidge_core.FC(out_channels, nb_bias).get_operator()
self.assertEqual(fc_op.get("OutChannels"), out_channels)
self.assertEqual(fc_op.get("NoBias"), nb_bias)
self.assertEqual(fc_op.get_parameter("OutChannels"), out_channels)
self.assertEqual(fc_op.get_parameter("NoBias"), nb_bias)
def test_matmul(self):
out_channels = 8
matmul_op = aidge_core.Matmul(out_channels).get_operator()
self.assertEqual(matmul_op.get("OutChannels"), out_channels)
self.assertEqual(matmul_op.get_parameter("OutChannels"), out_channels)
def test_producer_1D(self):
dims = [5]
......@@ -71,7 +71,7 @@ class test_parameters(unittest.TestCase):
def test_leaky_relu(self):
negative_slope = 0.25
leakyrelu_op = aidge_core.LeakyReLU(negative_slope).get_operator()
self.assertEqual(leakyrelu_op.get("NegativeSlope"), negative_slope)
self.assertEqual(leakyrelu_op.get_parameter("NegativeSlope"), negative_slope)
if __name__ == '__main__':
unittest.main()
......@@ -20,7 +20,7 @@ namespace py = pybind11;
namespace Aidge {
void init_GenericOperator(py::module& m) {
py::class_<GenericOperator_Op, std::shared_ptr<GenericOperator_Op>, Operator>(m, "GenericOperatorOp",
py::class_<GenericOperator_Op, std::shared_ptr<GenericOperator_Op>, Operator, DynamicParameters>(m, "GenericOperatorOp",
py::multiple_inheritance());
m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nbDataIn"), py::arg("nbIn"), py::arg("nbOut"),
py::arg("name") = "");
......
......@@ -11,7 +11,7 @@ void init_Parameters(py::module& m){
.def("get_parameters_name", &Parameters::getParametersName)
.def("get_parameter", &Parameters::getPy, py::arg("name"));
py::class_<DynamicParameters, std::shared_ptr<DynamicParameters>>(m, "DynamicParameters")
py::class_<DynamicParameters, std::shared_ptr<DynamicParameters>, Parameters>(m, "DynamicParameters")
.def("add_parameter", &DynamicParameters::addParameter<bool>)
.def("add_parameter", &DynamicParameters::addParameter<int>)
.def("add_parameter", &DynamicParameters::addParameter<float>)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment