Skip to content
Snippets Groups Projects
Commit a2d3a309 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed test issues

parent 098edff3
No related branches found
No related tags found
1 merge request!51Improve export
Pipeline #76080 passed with warnings
......@@ -25,10 +25,10 @@ def initFiller(model):
value.set_backend("cpu")
tuple_out = node.output(0)[0]
# No conv in current network
if tuple_out[0].type() == "Conv" and tuple_out[1] == 1:
if tuple_out[0].type() == "Conv2D" and tuple_out[1] == 1:
# Conv weight
aidge_core.xavier_uniform_filler(value)
elif tuple_out[0].type() == "Conv" and tuple_out[1] == 2:
elif tuple_out[0].type() == "Conv2D" and tuple_out[1] == 2:
# Conv bias
aidge_core.constant_filler(value, 0.01)
elif tuple_out[0].type() == "FC" and tuple_out[1] == 1:
......@@ -520,7 +520,7 @@ class test_operator_export(unittest.TestCase):
aidge_core.ConvDepthWise2D(nb_channels=3, kernel_dims=(3, 3), name="conv")
])
self.unit_test_export(model, "ConvDepthWise2D", [[1, 3, 12, 12]], False, False)
self.unit_test_export(model, "ConvDepthWise2D", [[1, 3, 12, 12]])
def test_max_pooling(self):
print("MaxPooling2D")
......@@ -529,7 +529,7 @@ class test_operator_export(unittest.TestCase):
])
model.set_ordered_outputs([(model.get_node("max_pool"), 0)], True)
self.unit_test_export(model, "MaxPooling2D", [[1, 2, 12, 12]], False, False)
self.unit_test_export(model, "MaxPooling2D", [[1, 2, 12, 12]])
def test_avg_pooling(self):
print("AvgPooling2D")
......@@ -537,7 +537,7 @@ class test_operator_export(unittest.TestCase):
aidge_core.AvgPooling2D(kernel_dims=(3, 3), name="avg_pool")
])
self.unit_test_export(model, "AvgPooling2D", [[1, 2, 12, 12]], False, False)
self.unit_test_export(model, "AvgPooling2D", [[1, 2, 12, 12]])
def test_pad2D(self):
print("Pad2D")
......@@ -668,5 +668,12 @@ class test_operator_export(unittest.TestCase):
initFiller(model)
self.unit_test_export(model, "Conv", [[1, 1, 9, 9]])
def test_Conv2(self):
model = aidge_core.sequential([
aidge_core.Conv2D(2, 2, [3, 3], name="InputNode")
])
initFiller(model)
self.unit_test_export(model, "Conv2", [[1, 2, 9, 9]])
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment