Skip to content
Snippets Groups Projects

Make forwardDims() optional and handle data dependency

Merged Olivier BICHLER requested to merge fowarddims into dev
Compare and
81 files
+ 1417
443
Compare changes
  • Side-by-side
  • Inline
Files
81
@@ -39,7 +39,7 @@ class test_OperatorImpl(unittest.TestCase):
global GLOBAL_CPT
matmul = aidge_core.GenericOperator("MatMul", 1, 0, 1, name="MatMul0")
generic_matmul_op = matmul.get_operator()
generic_matmul_op.set_compute_output_dims(lambda x: x)
generic_matmul_op.set_forward_dims(lambda x: x)
generic_matmul_op.set_impl(testImpl(generic_matmul_op))
generic_matmul_op.forward()
self.assertEqual(GLOBAL_CPT, 1)
@@ -52,6 +52,7 @@ class test_OperatorImpl(unittest.TestCase):
self.assertTrue("cpu" in aidge_core.get_keys_ConvOp2D())
conv = aidge_core.Conv2D(2,2,[1,1], name="Conv0")
conv.get_operator().set_backend("cpu")
conv.get_operator().set_input(0, aidge_core.Tensor(np.arange(18).reshape(1,2,3,3)))
conv.get_operator().forward()
self.assertEqual(GLOBAL_CPT, 1)
@@ -65,6 +66,7 @@ class test_OperatorImpl(unittest.TestCase):
conv = aidge_core.Conv2D(2,2,[1,1], name="Conv0")
model = aidge_core.sequential([conv])
model.set_backend("cpu")
conv.get_operator().set_input(0, aidge_core.Tensor(np.arange(18).reshape(1,2,3,3)))
conv.get_operator().forward()
self.assertEqual(GLOBAL_CPT, 1)
Loading