diff --git a/aidge_core/unit_tests/test_parameters.py b/aidge_core/unit_tests/test_parameters.py index 342e4d9666910e90960b762277e8914c68df28b5..e7b16963f4c26e5d014ce90fa289c043e2eb0be4 100644 --- a/aidge_core/unit_tests/test_parameters.py +++ b/aidge_core/unit_tests/test_parameters.py @@ -36,7 +36,7 @@ class test_attributes(unittest.TestCase): out_channels = 8 nb_bias = True fc_op = aidge_core.FC(in_channels, out_channels, nb_bias).get_operator() - self.assertEqual(fc_op.get_attr("OutChannels"), out_channels) + self.assertEqual(fc_op.out_channels(), out_channels) self.assertEqual(fc_op.get_attr("NoBias"), nb_bias) def test_producer_1D(self): diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index f1c344063acdf1c530f10e5be0629c90bf6235f7..c30282f3438889e233f3d9ed22ab7c7e795b2951 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -112,7 +112,7 @@ public: DimSize_t outChannels() const { if (!getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of input channel imposed."); + AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of output channel imposed."); } return getInput(1)->template dims<DIM+2>()[0]; } diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 0b87ff7882c6bf0745cdc52a470b83a8276ad9d5..9f10970c4fd5b21a1cb92b334167d353f066e05b 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -73,6 +73,13 @@ public: void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + DimSize_t outChannels() const { + if (!getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Fully Connected (FC) operator has no weight Tensor associated so no specific number of output channel imposed."); + } + return getInput(1)->template dims<2>()[0]; + } + static const std::vector<std::string> getInputsName() { return {"data_input", "weight", "bias"}; } diff --git a/python_binding/operator/pybind_ConvDepthWise.cpp b/python_binding/operator/pybind_ConvDepthWise.cpp index c0c494e8db29c9520cc88fb1fb622cfa831fe9f3..ce286094d6606d8b7161acf9e3fb3c6cbcbb88c9 100644 --- a/python_binding/operator/pybind_ConvDepthWise.cpp +++ b/python_binding/operator/pybind_ConvDepthWise.cpp @@ -41,7 +41,9 @@ template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) { py::arg("no_bias")) .def_static("get_inputs_name", &ConvDepthWise_Op<DIM>::getInputsName) .def_static("get_outputs_name", &ConvDepthWise_Op<DIM>::getOutputsName) - .def_static("attributes_name", &ConvDepthWise_Op<DIM>::staticGetAttrsName); + .def_static("attributes_name", &ConvDepthWise_Op<DIM>::staticGetAttrsName) + .def("nb_channels", &ConvDepthWise_Op<DIM>::nbChannels); + declare_registrable<ConvDepthWise_Op<DIM>>(m, pyClassName); m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const DimSize_t nb_channels, const std::vector<DimSize_t>& kernel_dims, diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp index 989f88d1e62d354004da9c2d81a6006f90e47a9d..6cff90d0ad3aacf4cf8a465408eb490e3f21abda 100644 --- a/python_binding/operator/pybind_FC.cpp +++ b/python_binding/operator/pybind_FC.cpp @@ -25,7 +25,8 @@ void declare_FC(py::module &m) { .def(py::init<bool>(), py::arg("no_bias")) .def_static("get_inputs_name", &FC_Op::getInputsName) .def_static("get_outputs_name", &FC_Op::getOutputsName) - .def_static("attributes_name", &FC_Op::staticGetAttrsName); + .def_static("attributes_name", &FC_Op::staticGetAttrsName) + .def("out_channels", &FC_Op::outChannels); declare_registrable<FC_Op>(m, "FCOp");