-
Olivier BICHLER authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
RoundImpl.cpp 2.01 KiB
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <numeric>
#include <vector>

#include <cuda_fp16.h>

#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/RoundImpl.hpp"
#include "aidge/backend/cuda/operator/RoundImpl_CUDA_kernels.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/Round.hpp"
#include "aidge/utils/Types.h"
/**
 * @brief Run the Round operator forward pass on the CUDA backend.
 *
 * Validates input #0 and dispatches to the templated forward_<T>() according
 * to the output tensor data type. Supported types are Float64 (double),
 * Float32 (float) and Float16 (half).
 *
 * Errors are reported through AIDGE_ASSERT / AIDGE_THROW_OR_ABORT in these cases:
 *  - input #0 is missing or has no implementation;
 *  - the input and output data types differ;
 *  - the output data type is not one of the supported types.
 */
void Aidge::RoundImpl_cuda::forward() {
    const Round_Op& op = static_cast<const Round_Op&>(mOp);

    // Check inputs
    AIDGE_ASSERT(op.getInput(0), "missing input in Round operator");
    AIDGE_ASSERT(op.getInput(0)->hasImpl(), "cannot run Round forward because the 0-th input has no implementation.");

    // forward_<T>() reinterprets BOTH the input and the output buffers as T,
    // where T is selected from the output data type below. An input stored with
    // a different data type would be read with the wrong element size, so reject
    // it explicitly instead of silently producing garbage.
    AIDGE_ASSERT(op.getInput(0)->dataType() == op.getOutput(0)->dataType(),
                 "Round forward on CUDA requires input and output to have the same data type.");

    switch (op.getOutput(0)->dataType()) {
        case DataType::Float64:
            forward_<double>();
            break;
        case DataType::Float32:
            forward_<float>();
            break;
        case DataType::Float16:
            forward_<half>();
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
    }
}
template <class T>
void Aidge::RoundImpl_cuda::forward_()
{
const Round_Op& op = static_cast<const Round_Op&>(mOp);
int size = op.getInput(0)->size();
const T* inputPtr = static_cast<T*>(op.getInput(0)->getImpl()->rawPtr());
T* outputPtr = static_cast<T*>(op.getOutput(0)->getImpl()->rawPtr());
Aidge::roundForward<T>(inputPtr,outputPtr,size,op.roundingMode());
}