diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 87b54afbfd0b4c2d3bb57812d07575bc0e255626..6def25ba76e6ced9a1a623f0817e1f58c10631e1 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -33,14 +33,35 @@ void Aidge::ConvImpl2D_cpu::forward() {
     assert(mOp.getRawInput(2) && "missing input #2");
 
     // Find the correct kernel type
-    auto kernelFunc =
-            Registrar<ConvImpl2DForward_cpu>::create({std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-                                                      std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
-                                                      std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(),
-                                                      std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+    const auto outputDataType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
+    const Registrar<ConvImpl2DForward_cpu>::registrar_key registrarKey = {
+        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
+        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
+        std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(),
+        outputDataType};
+
+    Registrar<ConvImpl2DForward_cpu>::registrar_type kernelFunc;
+    if (Registrar<ConvImpl2DForward_cpu>::exists(registrarKey)) {
+        // One exists with the right inputs/output types
+        kernelFunc = Registrar<ConvImpl2DForward_cpu>::create(registrarKey);
+    }
+    else {
+        // Otherwise, fall back to the kernel whose types all match the output type
+        kernelFunc = Registrar<ConvImpl2DForward_cpu>::create({
+            outputDataType, outputDataType, outputDataType, outputDataType});
+    }
+
+    // Convert input data (no overhead if not needed!)
+    // TODO: right now, if needed, memory will be allocated/deallocated at each
+    // call to forward(). We might put the following shared_ptr as members of
+    // this class to avoid that.
+    std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
+    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCast(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCast(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCast(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
 
     // Call kernel
     kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
+        input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
+        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
 }