Skip to content
Snippets Groups Projects
Commit c56937fd authored by Maxence Naud's avatar Maxence Naud
Browse files

fix missing access to removed attributes

parent 76615539
No related branches found
No related tags found
No related merge requests found
......@@ -36,7 +36,7 @@ class test_attributes(unittest.TestCase):
out_channels = 8
nb_bias = True
fc_op = aidge_core.FC(in_channels, out_channels, nb_bias).get_operator()
self.assertEqual(fc_op.get_attr("OutChannels"), out_channels)
self.assertEqual(fc_op.out_channels(), out_channels)
self.assertEqual(fc_op.get_attr("NoBias"), nb_bias)
def test_producer_1D(self):
......
......@@ -112,7 +112,7 @@ public:
DimSize_t outChannels() const {
if (!getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of input channel imposed.");
AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of output channel imposed.");
}
return getInput(1)->template dims<DIM+2>()[0];
}
......
......@@ -73,6 +73,13 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
DimSize_t outChannels() const {
if (!getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Fully Connected (FC) operator has no weight Tensor associated so no specific number of output channel imposed.");
}
return getInput(1)->template dims<2>()[0];
}
static const std::vector<std::string> getInputsName() {
return {"data_input", "weight", "bias"};
}
......
......@@ -41,7 +41,9 @@ template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) {
py::arg("no_bias"))
.def_static("get_inputs_name", &ConvDepthWise_Op<DIM>::getInputsName)
.def_static("get_outputs_name", &ConvDepthWise_Op<DIM>::getOutputsName)
.def_static("attributes_name", &ConvDepthWise_Op<DIM>::staticGetAttrsName);
.def_static("attributes_name", &ConvDepthWise_Op<DIM>::staticGetAttrsName)
.def("nb_channels", &ConvDepthWise_Op<DIM>::nbChannels);
declare_registrable<ConvDepthWise_Op<DIM>>(m, pyClassName);
m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const DimSize_t nb_channels,
const std::vector<DimSize_t>& kernel_dims,
......
......@@ -25,7 +25,8 @@ void declare_FC(py::module &m) {
.def(py::init<bool>(), py::arg("no_bias"))
.def_static("get_inputs_name", &FC_Op::getInputsName)
.def_static("get_outputs_name", &FC_Op::getOutputsName)
.def_static("attributes_name", &FC_Op::staticGetAttrsName);
.def_static("attributes_name", &FC_Op::staticGetAttrsName)
.def("out_channels", &FC_Op::outChannels);
declare_registrable<FC_Op>(m, "FCOp");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment