Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
RoundImpl.cpp 2.01 KiB
/********************************************************************************
 * Copyright (c) 2024 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

#include <cuda_fp16.h>
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/RoundImpl.hpp"
#include "aidge/backend/cuda/operator/RoundImpl_CUDA_kernels.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/Round.hpp"
#include "aidge/utils/Types.h"

void Aidge::RoundImpl_cuda::forward() {
    const Round_Op& op = static_cast<const Round_Op&>(mOp);
    // Check inputs
    AIDGE_ASSERT(op.getInput(0), "missing input in Round operator");
    AIDGE_ASSERT(op.getInput(0)->hasImpl(), "cannot run Round forward because the 0-th input has no implementation.");
    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
        case DataType::Float64:
            forward_<double>();
            break;
        case DataType::Float32:
            forward_<float>();
            break;
        case DataType::Float16:
            forward_<half>();
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
    }
}

template <class T>
void Aidge::RoundImpl_cuda::forward_() 
{
    const Round_Op& op = static_cast<const Round_Op&>(mOp);
    int size = op.getInput(0)->size();
    const T* inputPtr = static_cast<T*>(op.getInput(0)->getImpl()->rawPtr());
    T* outputPtr = static_cast<T*>(op.getOutput(0)->getImpl()->rawPtr());
    Aidge::roundForward<T>(inputPtr,outputPtr,size,op.roundingMode());
}