diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp index 87b54afbfd0b4c2d3bb57812d07575bc0e255626..6def25ba76e6ced9a1a623f0817e1f58c10631e1 100644 --- a/src/operator/ConvImpl.cpp +++ b/src/operator/ConvImpl.cpp @@ -33,14 +33,35 @@ void Aidge::ConvImpl2D_cpu::forward() { assert(mOp.getRawInput(2) && "missing input #2"); // Find the correct kernel type - auto kernelFunc = - Registrar<ConvImpl2DForward_cpu>::create({std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(), - std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); + const auto outputDataType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType(); + const Registrar<ConvImpl2DForward_cpu>::registrar_key registrarKey = { + std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), + std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(), + std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(), + outputDataType}; + + Registrar<ConvImpl2DForward_cpu>::registrar_type kernelFunc; + if (Registrar<ConvImpl2DForward_cpu>::exists(registrarKey)) { + // One exists with the right inputs/output types + kernelFunc = Registrar<ConvImpl2DForward_cpu>::create(registrarKey); + } + else { + // Otherwise, fallback to the kernel with all types matching output type + kernelFunc = Registrar<ConvImpl2DForward_cpu>::create({ + outputDataType, outputDataType, outputDataType, outputDataType}); + } + + // Convert input data (no overhead if not needed!) + // TODO: right now, if needed, memory will be allocated/deallocated at each + // call to forward(). We might put the following shared_ptr as members of + // this class to avoid that. + std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback; + const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCast(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); + const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCast(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); + const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCast(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); // Call kernel kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); + input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(), + std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); }