Skip to content
Snippets Groups Projects
Commit 6a3c1b3b authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Cyril Moineau
Browse files

Fixed wrong scaling types

parent 6605927b
No related branches found
No related tags found
1 merge request!200.2.1
......@@ -45,18 +45,14 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
&strides[0]));
}
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<double>(input);
break;
case DataType::Float32:
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
// Do the actual forward computation
// Template is only for scaling parameters, which are always in float
// excepted when the convolution is performed in double precision.
if (op.getOutput(0)->dataType() == DataType::Float64) {
forward_<double>(input);
}
else {
forward_<float>(input);
}
}
......
......@@ -45,18 +45,14 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
&strides[0]));
}
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<double>(input);
break;
case DataType::Float32:
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
// Do the actual forward computation
// Template is only for scaling parameters, which are always in float
// excepted when the convolution is performed in double precision.
if (op.getOutput(0)->dataType() == DataType::Float64) {
forward_<double>(input);
}
else {
forward_<float>(input);
}
}
......
......@@ -37,18 +37,14 @@ void Aidge::ReLUImpl_cuda::forward() {
#endif
}
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<double>(input);
break;
case DataType::Float32:
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
// Do the actual forward computation
// Template is only for scaling parameters, which are always in float
// excepted when the convolution is performed in double precision.
if (op.getOutput(0)->dataType() == DataType::Float64) {
forward_<double>(input);
}
else {
forward_<float>(input);
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment