Skip to content
Snippets Groups Projects
Commit 31f049fc authored by Olivier BICHLER
Browse files

Use new conversion facilities from code for Conv

parent 4ae84dbb
No related branches found
No related tags found
2 merge requests: !29 Temporary master branch, !26 Draft: Add Convert operator (a.k.a. Transmitter)
Pipeline #35472 failed
...@@ -33,14 +33,35 @@ void Aidge::ConvImpl2D_cpu::forward() { ...@@ -33,14 +33,35 @@ void Aidge::ConvImpl2D_cpu::forward() {
assert(mOp.getRawInput(2) && "missing input #2"); assert(mOp.getRawInput(2) && "missing input #2");
// Find the correct kernel type // Find the correct kernel type
auto kernelFunc = const auto outputDataType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
Registrar<ConvImpl2DForward_cpu>::create({std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), const Registrar<ConvImpl2DForward_cpu>::registrar_key registrarKey = {
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(),
outputDataType};
Registrar<ConvImpl2DForward_cpu>::registrar_type kernelFunc;
if (Registrar<ConvImpl2DForward_cpu>::exists(registrarKey)) {
// One exists with the right inputs/output types
kernelFunc = Registrar<ConvImpl2DForward_cpu>::create(registrarKey);
}
else {
// Otherwise, fallback to the kernel with all types matching output type
kernelFunc = Registrar<ConvImpl2DForward_cpu>::create({
outputDataType, outputDataType, outputDataType, outputDataType});
}
// Convert input data (no overhead if not needed!)
// TODO: right now, if needed, memory will be allocated/deallocated at each
// call to forward(). We might put the following shared_ptr as members of
// this class to avoid that.
std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCast(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCast(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCast(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
// Call kernel // Call kernel
kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(), kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(), input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment