Skip to content
Snippets Groups Projects

Update operators implementation

Merged Maxence Naud requested to merge OperatorTensor into master
53 files
+ 1770
726
Compare changes
  • Side-by-side
  • Inline
Files
53
@@ -49,24 +49,24 @@ class test_recipies(unittest.TestCase):
@@ -49,24 +49,24 @@ class test_recipies(unittest.TestCase):
np_shift = np.array([0.05]).astype(np.float32)
np_shift = np.array([0.05]).astype(np.float32)
np_mean = np.array([0.05]).astype(np.float32)
np_mean = np.array([0.05]).astype(np.float32)
np_var = np.array([0.05]).astype(np.float32)
np_var = np.array([0.05]).astype(np.float32)
conv.input(1)[0].get_operator().set_output_tensor(aidge_core.Tensor(np_weights))
conv.input(1)[0].get_operator().set_output(0, aidge_core.Tensor(np_weights))
conv.input(2)[0].get_operator().set_output_tensor(aidge_core.Tensor(np_bias))
conv.input(2)[0].get_operator().set_output(0, aidge_core.Tensor(np_bias))
bn.input(1)[0].get_operator().set_output_tensor(aidge_core.Tensor(np_scale))
bn.input(1)[0].get_operator().set_output(0, aidge_core.Tensor(np_scale))
bn.input(2)[0].get_operator().set_output_tensor(aidge_core.Tensor(np_shift))
bn.input(2)[0].get_operator().set_output(0, aidge_core.Tensor(np_shift))
bn.input(3)[0].get_operator().set_output_tensor(aidge_core.Tensor(np_mean))
bn.input(3)[0].get_operator().set_output(0, aidge_core.Tensor(np_mean))
bn.input(4)[0].get_operator().set_output_tensor(aidge_core.Tensor(np_var))
bn.input(4)[0].get_operator().set_output(0, aidge_core.Tensor(np_var))
scheduler0 = aidge_core.SequentialScheduler(graph_view)
scheduler0 = aidge_core.SequentialScheduler(graph_view)
scheduler0.forward()
scheduler0.forward()
for outNode in graph_view.get_output_nodes():
for outNode in graph_view.get_output_nodes():
output_aidge0 = outNode.get_operator().output(0)
output_aidge0 = outNode.get_operator().get_output(0)
aidge_core.fuse_batchnorm(graph_view)
aidge_core.fuse_batchnorm(graph_view)
scheduler1 = aidge_core.SequentialScheduler(graph_view)
scheduler1 = aidge_core.SequentialScheduler(graph_view)
scheduler1.forward()
scheduler1.forward()
for outNode in graph_view.get_output_nodes():
for outNode in graph_view.get_output_nodes():
output_aidge1 = outNode.get_operator().output(0)
output_aidge1 = outNode.get_operator().get_output(0)
self.assertTrue(aidge_core.approx_eq(output_aidge0, output_aidge1, 0.000001, 0.0001))
self.assertTrue(aidge_core.approx_eq(output_aidge0, output_aidge1, 0.000001, 0.0001))
Loading