Skip to content
Snippets Groups Projects
Commit 1e582ba0 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed wrong scaling types

parent 4a03726d
No related branches found
No related tags found
1 merge request!32version 0.2.1
Pipeline #44284 failed
......@@ -45,18 +45,14 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
&strides[0]));
}
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<double>(input);
break;
case DataType::Float32:
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
// Do the actual forward computation
// Template is only for scaling parameters, which are always in float
// excepted when the convolution is performed in double precision.
if (op.getOutput(0)->dataType() == DataType::Float64) {
forward_<double>(input);
}
else {
forward_<float>(input);
}
}
......
......@@ -45,18 +45,14 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
&strides[0]));
}
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<double>(input);
break;
case DataType::Float32:
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
// Do the actual forward computation
// Template is only for scaling parameters, which are always in float
// excepted when the convolution is performed in double precision.
if (op.getOutput(0)->dataType() == DataType::Float64) {
forward_<double>(input);
}
else {
forward_<float>(input);
}
}
......
......@@ -37,18 +37,14 @@ void Aidge::ReLUImpl_cuda::forward() {
#endif
}
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<double>(input);
break;
case DataType::Float32:
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
// Do the actual forward computation
// Template is only for scaling parameters, which are always in float
// excepted when the convolution is performed in double precision.
if (op.getOutput(0)->dataType() == DataType::Float64) {
forward_<double>(input);
}
else {
forward_<float>(input);
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment