Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
FCImpl.cpp 5.09 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <cassert>
#include <chrono>    // std::chrono::milliseconds
#include <memory>    // std::shared_ptr, std::unique_ptr
#include <numeric>   // std::accumulate
#include <stdexcept> // std::runtime_error
#include <thread>    // std::this_thread::sleep_for
#include <vector>

#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/FCImpl.hpp"
#include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/utils/Types.h"

void Aidge::FCImpl_cuda::forward() {
    assert(mOp.getRawInput(0) && "missing input #0");
    assert(mOp.getRawInput(1) && "missing input #1");
    assert(mOp.getRawInput(2) && "missing input #2");

    std::shared_ptr<Tensor> inputFallback, input1Fallback, input2Fallback;
    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));

    const auto& fcOp = static_cast<const FC_Op&>(mOp);
    bool noBias = fcOp.template getAttr<FCAttr::NoBias>();
    std::size_t outChannels = static_cast<std::size_t>(fcOp.template getAttr<FCAttr::OutChannels>());

    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
        case DataType::Float64:
            forward_<double>(input0, input1, input2, noBias, outChannels);
            break;
        case DataType::Float32:
            forward_<float>(input0, input1, input2, noBias, outChannels);
            break;
        case DataType::Float16:
            forward_<half>(input0, input1, input2, noBias, outChannels);
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
    }
}

template<class T>
void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, bool noBias, std::size_t outChannels)
{

    const T * input = static_cast<const T*>(input0.getImpl()->rawPtr());
    const T * weights = static_cast<const T*>(input1.getImpl()->rawPtr());
    T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());

    int n = outChannels;
    int m = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->size()/n;
    int k = input0.size()/m;
    int lda = k;
    int ldb = k;
    int ldc = n;
    const T alpha = 1.0f;
    const T beta = 0.0f;
    CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(),
                                    CUBLAS_OP_T,
                                    CUBLAS_OP_N,
                                    n,
                                    m,
                                    k,
                                    reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                    weights,
                                    ldb,
                                    input, 
                                    lda,
                                    reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&beta),
                                    output,
                                    ldc));

    if(!noBias){
        T* onesVector;
        CHECK_CUDA_STATUS(cudaMalloc((void**)&onesVector, m * sizeof(T)));
        // Fill the vector with ones
        std::vector<T> onesVec(m, T(1.0));
        CHECK_CUDA_STATUS(cudaMemcpy(onesVector,
                                    &onesVec[0],
                                    m * sizeof(T),
                                    cudaMemcpyHostToDevice));
        const T * biases = static_cast<const T*>(input2.getImpl()->rawPtr());

        CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(),
                                       CUBLAS_OP_N,
                                       CUBLAS_OP_N,
                                       n,
                                       m,
                                       1,
                                       reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                       biases,
                                       n,
                                       onesVector,
                                       1,
                                       reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                       output,
                                       n));

        cudaFree(onesVector);
    }

}