diff --git a/aidge_core/export/node_export.py b/aidge_core/export/node_export.py index 980cb05a5814b7476d64757353e393ad6130218b..bea61551d6b4363d234fba4df6138ccef3154331 100644 --- a/aidge_core/export/node_export.py +++ b/aidge_core/export/node_export.py @@ -37,15 +37,15 @@ class ExportNode(ABC): for idx, parent_node in enumerate(self.node.get_parents()): self.inputs.append(parent_node) if parent_node is not None: - self.inputs_dims.append(self.operator.input(idx).dims()) + self.inputs_dims.append(self.operator.get_input(idx).dims()) else: self.inputs_dims.append(None) for idx, child_node in enumerate(self.node.get_children()): self.outputs.append(child_node) - + # Dirty hot fix, change it quickly - self.outputs_dims.append(self.operator.output(0).dims()) + self.outputs_dims.append(self.operator.get_output(0).dims()) @abstractmethod def export(self, export_folder:str, list_configs:list): diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index 7bd1e730a973810db89aa786b52fa05c53c43590..825ca6100382116443699a00bcff27b9bbca028a 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -16,14 +16,14 @@ class test_operator_binding(unittest.TestCase): Can be remove in later stage of the developpement. """ def setUp(self): - self.generic_operator = aidge_core.GenericOperator("FakeConv", 1, 1, 1).get_operator() + self.generic_operator = aidge_core.GenericOperator("FakeConv", 1, 0, 1).get_operator() def tearDown(self): pass def test_default_name(self): op_type = "Conv" - gop = aidge_core.GenericOperator(op_type, 1, 1, 1, "FictiveName") + gop = aidge_core.GenericOperator(op_type, 1, 0, 1, "FictiveName") # check node name is not operator type self.assertNotEqual(gop.name(), "Conv") # check node name is not default @@ -95,12 +95,12 @@ class test_operator_binding(unittest.TestCase): def test_compute_output_dims(self): in_dims=[25, 25] input = aidge_core.Producer(in_dims, name="In") - genOp = aidge_core.GenericOperator("genOp", 1, 1, 1, name="genOp") + genOp = aidge_core.GenericOperator("genOp", 1, 0, 1, name="genOp") _ = aidge_core.sequential([input, genOp]) - self.assertListEqual(genOp.get_operator().output(0).dims(), []) + self.assertListEqual(genOp.get_operator().get_output(0).dims(), []) genOp.get_operator().set_compute_output_dims(lambda x:x) genOp.get_operator().compute_output_dims() - self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims) + self.assertListEqual(genOp.get_operator().get_output(0).dims(), in_dims) def test_set_impl(self): @@ -116,7 +116,7 @@ class test_operator_binding(unittest.TestCase): """ self.idx += 1 - generic_node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu") + generic_node = aidge_core.GenericOperator("Relu", 1, 0, 1, name="myReLu") generic_op = generic_node.get_operator() customImpl = PythonCustomImpl(generic_op) diff --git a/aidge_core/unit_tests/test_parameters.py b/aidge_core/unit_tests/test_parameters.py index 566650713c36236c19763f466ee906970466c02e..620beb160fb3494f156c1a4b512d386447081154 100644 --- a/aidge_core/unit_tests/test_parameters.py +++ b/aidge_core/unit_tests/test_parameters.py @@ -32,15 +32,17 @@ class test_attributes(unittest.TestCase): self.assertEqual(conv_op.get_attr("KernelDims"), k_dims) def test_fc(self): + in_channels = 4 out_channels = 8 nb_bias = True - fc_op = aidge_core.FC(out_channels, nb_bias).get_operator() + fc_op = aidge_core.FC(in_channels, out_channels, nb_bias).get_operator() self.assertEqual(fc_op.get_attr("OutChannels"), out_channels) self.assertEqual(fc_op.get_attr("NoBias"), nb_bias) def test_matmul(self): + in_channels = 4 out_channels = 8 - matmul_op = aidge_core.MatMul(out_channels).get_operator() + matmul_op = aidge_core.MatMul(in_channels, out_channels).get_operator() self.assertEqual(matmul_op.get_attr("OutChannels"), out_channels) def test_producer_1D(self): diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 883cbffc0f22c6a3d009f643dadf0aec9eb3f8fc..6cf89a45fd0d4cf1dc970d199d074e886b131896 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -22,8 +22,8 @@ class test_recipies(unittest.TestCase): def test_remove_flatten(self): graph_view = aidge_core.sequential([ - aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"), - aidge_core.FC(50, name='0') + aidge_core.GenericOperator("Flatten", 1, 0, 1, name="Flatten0"), + aidge_core.FC(10, 50, name='0') ]) old_nodes = graph_view.get_nodes() aidge_core.remove_flatten(graph_view) @@ -33,9 +33,9 @@ class test_recipies(unittest.TestCase): self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()])) def test_fuse_matmul_add(self): - matmul0 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul0") + matmul0 = aidge_core.MatMul(1, 1, name="MatMul0") add0 = aidge_core.Add(2, name="Add0") - matmul1 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul1") + matmul1 = aidge_core.MatMul(1, 1, name="MatMul1") add1 = aidge_core.Add(2, name="Add1") graph_view = aidge_core.sequential([matmul0, add0, matmul1, add1]) diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index 09b1abef1e5ea66c6843594e3fa3beb20ec10740..f9482eda2f93b5492cfcc89175da69d140f23df8 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -20,10 +20,8 @@ namespace Aidge { void init_Operator(py::module& m){ py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data")) - // .def("set_output", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data")) .def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx")) .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) - // .def("set_input", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) .def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx")) .def("nb_inputs", &Operator::nbInputs) .def("nb_data", &Operator::nbData) diff --git a/python_binding/operator/pybind_OperatorTensor.cpp b/python_binding/operator/pybind_OperatorTensor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..605ee27f491b02e5b4eeeb77287f5332bfd20bb0 --- /dev/null +++ b/python_binding/operator/pybind_OperatorTensor.cpp @@ -0,0 +1,26 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Operator.hpp" +#include <pybind11/stl.h> + +namespace py = pybind11; +namespace Aidge { +void init_OperatorTensor(py::module& m){ + py::class_<OperatorTensor, std::shared_ptr<OperatorTensor>, Operator>(m, "OperatorTensor") + .def("get_output", &OperatorTensor::getOutput, py::arg("outputIdx")) + .def("get_input", &OperatorTensor::getInput, py::arg("inputIdx")) + ; +} +} diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 6cc597b5ee934e4a3b849d45e92e5cb62be1b312..e625978ba2f15d4aff9e847e18ebc8076f31a165 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -19,6 +19,7 @@ void init_Tensor(py::module&); void init_OperatorImpl(py::module&); void init_Attributes(py::module&); void init_Operator(py::module&); +void init_OperatorTensor(py::module&); void init_Add(py::module&); void init_AvgPooling(py::module&); @@ -65,6 +66,7 @@ void init_Aidge(py::module& m){ init_OperatorImpl(m); init_Attributes(m); init_Operator(m); + init_OperatorTensor(m); init_Add(m); init_AvgPooling(m); init_BatchNorm(m);