Skip to content
Snippets Groups Projects
Commit c41f8643 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

minor cleanings

parent e8564e81
No related branches found
No related tags found
2 merge requests!32version 0.2.1,!14MobileNet operators
/******************************************************************************** /********************************************************************************
* Copyright (c) 2023 CEA-List * Copyright (c) 2024 CEA-List
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at * terms of the Eclipse Public License 2.0 which is available at
......
/******************************************************************************** /********************************************************************************
* Copyright (c) 2023 CEA-List * Copyright (c) 2024 CEA-List
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at * terms of the Eclipse Public License 2.0 which is available at
......
...@@ -45,18 +45,11 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() { ...@@ -45,18 +45,11 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
&strides[0])); &strides[0]));
} }
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { if (op.getOutput(0)->dataType() == DataType::Float64) {
case DataType::Float64: forward_<double>(input);
forward_<double>(input); }
break; else {
case DataType::Float32: forward_<float>(input);
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
} }
} }
......
/******************************************************************************** /********************************************************************************
* Copyright (c) 2023 CEA-List * Copyright (c) 2024 CEA-List
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at * terms of the Eclipse Public License 2.0 which is available at
......
/******************************************************************************** /********************************************************************************
* Copyright (c) 2023 CEA-List * Copyright (c) 2024 CEA-List
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at * terms of the Eclipse Public License 2.0 which is available at
...@@ -39,18 +39,11 @@ void Aidge::GlobalAveragePoolingImpl_cuda::forward() { ...@@ -39,18 +39,11 @@ void Aidge::GlobalAveragePoolingImpl_cuda::forward() {
); );
} }
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { if (op.getOutput(0)->dataType() == DataType::Float64) {
case DataType::Float64: forward_<double>(input);
forward_<double>(input); }
break; else {
case DataType::Float32: forward_<float>(input);
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
} }
} }
......
...@@ -45,18 +45,11 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward() { ...@@ -45,18 +45,11 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
&strides[0])); &strides[0]));
} }
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { if (op.getOutput(0)->dataType() == DataType::Float64) {
case DataType::Float64: forward_<double>(input);
forward_<double>(input); }
break; else {
case DataType::Float32: forward_<float>(input);
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
} }
} }
......
/******************************************************************************** /********************************************************************************
* Copyright (c) 2023 CEA-List * Copyright (c) 2024 CEA-List
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at * terms of the Eclipse Public License 2.0 which is available at
...@@ -37,18 +37,11 @@ void Aidge::ReLUImpl_cuda::forward() { ...@@ -37,18 +37,11 @@ void Aidge::ReLUImpl_cuda::forward() {
#endif #endif
} }
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { if (op.getOutput(0)->dataType() == DataType::Float64) {
case DataType::Float64: forward_<double>(input);
forward_<double>(input); }
break; else {
case DataType::Float32: forward_<float>(input);
forward_<float>(input);
break;
case DataType::Float16:
forward_<half>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
} }
} }
......
/******************************************************************************** /********************************************************************************
* Copyright (c) 2023 CEA-List * Copyright (c) 2024 CEA-List
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at * terms of the Eclipse Public License 2.0 which is available at
...@@ -15,12 +15,11 @@ ...@@ -15,12 +15,11 @@
#include <thread> // std::this_thread::sleep_for #include <thread> // std::this_thread::sleep_for
#include <vector> #include <vector>
#include "aidge/utils/Types.h"
#include "aidge/operator/Reshape.hpp"
#include "aidge/backend/cuda/data/TensorImpl.hpp" #include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/ReshapeImpl.hpp" #include "aidge/backend/cuda/operator/ReshapeImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp" #include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/operator/Reshape.hpp"
#include "aidge/utils/Types.h"
void Aidge::ReshapeImpl_cuda::forward() { void Aidge::ReshapeImpl_cuda::forward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment