/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <cstddef> // std::size_t
#include <vector>  // std::vector (host-side ones vector)

#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/FCImpl.hpp"
#include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/utils/Types.h"

void Aidge::FCImpl_cuda::forward() {
    AIDGE_ASSERT(mOp.getRawInput(0), "missing input #0");
    AIDGE_ASSERT(mOp.getRawInput(1), "missing input #1");
    // input #2 (bias) is optional and handled below

    const auto& fcOp = static_cast<const FC_Op&>(mOp);
    std::size_t outChannels = fcOp.outChannels();

    const auto& input0 = fcOp.getInput(0)->refCastFrom(mInput0Fallback, *fcOp.getOutput(0));
    const auto& input1 = fcOp.getInput(1)->refCastFrom(mInput1Fallback, *fcOp.getOutput(0));
    const auto& input2 = (fcOp.getInput(2)) ? fcOp.getInput(2)->refCastFrom(mInput2Fallback, *fcOp.getOutput(0)) : Tensor();

    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
        case DataType::Float64:
            forward_<double>(input0, input1, input2, outChannels);
            break;
        case DataType::Float32:
            forward_<float>(input0, input1, input2, outChannels);
            break;
        case DataType::Float16:
            forward_<half>(input0, input1, input2, outChannels);
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by the CUDA backend");
    }
}

template<class T>
void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, std::size_t outChannels)
{
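    // Computes Y = X * T(W) + b with row-major shapes X [batch x inChannels],
    // W [outChannels x inChannels], Y [batch x outChannels]. Since cuBLAS is
    // column-major, every buffer below is handled through its transposed
    // (column-major) view.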
    const T * input = static_cast<const T*>(input0.getImpl()->rawPtr());
    const T * weights = static_cast<const T*>(input1.getImpl()->rawPtr());
    T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());

    // Performing output = T(weights) * input
    //            [n x m] = [n x k]   * [k x m]
    // cuBLAS is column-major: rather than transposing both inputs and then the
    // result, we directly compute the output in its column-major [n x m] view.
    int n = outChannels;
    int m = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->size()/n;  // batch size
    int k = input0.size()/m;  // inChannels
    int lda = k;  // leading dimension of weights (matrix A)
    int ldb = k;  // leading dimension of input (matrix B)
    int ldc = n;  // leading dimension of output (matrix C)
    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
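    // Worked layout example (illustrative values only): with batch m = 2,
    // inChannels k = 3 and outChannels n = 4, the row-major input [2 x 3] is,
    // seen column-major, a [k x m] = [3 x 2] matrix, and the row-major weights
    // [4 x 3] a [k x n] = [3 x 4] matrix. Applying OP_T to the weights gives
    // [n x k] * [k x m] = [n x m], which is exactly the column-major view of
    // the row-major [m x n] output, so no explicit transposition is needed.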
    CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(),
                                    CUBLAS_OP_T,
                                    CUBLAS_OP_N,
                                    n,
                                    m,
                                    k,
                                    reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                    weights,
                                    lda,
                                    input,
                                    ldb,
                                    reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&beta),
                                    output,
                                    ldc));

    if(!input2.empty()){
        T* onesVector;
        CHECK_CUDA_STATUS(cudaMalloc(reinterpret_cast<void**>(&onesVector), m * sizeof(T)));
        // Fill the device vector with ones (one entry per batch element)
        std::vector<T> onesVec(m, T(1.0));
        CHECK_CUDA_STATUS(cudaMemcpy(onesVector,
                                    onesVec.data(),
                                    m * sizeof(T),
                                    cudaMemcpyHostToDevice));
        const T * biases = static_cast<const T*>(input2.getImpl()->rawPtr());
        // Performing output = biases * onesVector + output
        //           [n x m] = [n x 1] * [1 x m]   + [n x m]
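        // This rank-1 GEMM adds biases[i] to every column j of the column-major
        // output, i.e. output(i, j) += biases[i] for each batch element j.
        // Note that alpha (1.0) is deliberately passed as the beta argument so
        // that the GEMM accumulates into the existing output instead of
        // overwriting it.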
        CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(),
                                       CUBLAS_OP_N,
                                       CUBLAS_OP_N,
                                       n,
                                       m,
                                       1,
                                       reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                       biases,
                                       n,
                                       onesVector,
                                       1,
                                       reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                       output,
                                       n));

        CHECK_CUDA_STATUS(cudaFree(onesVector));
    }
}

void Aidge::FCImpl_cuda::backward() {

    AIDGE_ASSERT(mOp.getRawInput(0), "missing input #0");
    AIDGE_ASSERT(mOp.getRawInput(1), "missing input #1");
    // input #2 (bias) is optional and handled below

    const auto& fcOp = static_cast<const FC_Op&>(mOp);
    std::size_t outChannels = fcOp.outChannels();

    const auto& input0 = fcOp.getInput(0)->refCastFrom(mInput0Fallback, *fcOp.getOutput(0));
    const auto& input1 = fcOp.getInput(1)->refCastFrom(mInput1Fallback, *fcOp.getOutput(0));
    const auto& input2 = (fcOp.getInput(2)) ? fcOp.getInput(2)->refCastFrom(mInput2Fallback, *fcOp.getOutput(0)) : Tensor();

    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
        case DataType::Float64:
            backward_<double>(input0, input1, input2, outChannels);
            break;
        case DataType::Float32:
            backward_<float>(input0, input1, input2, outChannels);
            break;
        case DataType::Float16:
            backward_<half>(input0, input1, input2, outChannels);
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by the CUDA backend");
    }
}

template<class T>
void Aidge::FCImpl_cuda::backward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, std::size_t outChannels)
{
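    // Gradient formulas for Y = X * T(W) + b, with row-major shapes
    // X [batch x inChannels], W [outChannels x inChannels], Y [batch x outChannels]:
    //   dW = T(dY) * X     -> weightsGrad
    //   db = T(dY) * ones  -> biasGrad (sum of dY over the batch)
    //   dX = dY * W        -> inputGrad
    // As in forward_, all cuBLAS calls below operate on the column-major
    // (transposed) views of these row-major buffers.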
    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
    const typename Cuda::cudnn_scaling_type<T>::type betaData = 0.0f;
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
    const T * input = static_cast<const T*>(input0.getImpl()->rawPtr());
    const T * weights = static_cast<const T*>(input1.getImpl()->rawPtr());
    const T * outputGrad = static_cast<const T*>(op.getOutput(0)->grad()->getImpl()->rawPtr());
    T * weightsGrad = static_cast<T*>(op.getInput(1)->grad()->getImpl()->rawPtr());

    // Performing weightsGrad = input * T(outputGrad)
    //              [m x n]   = [m x k] *   [k x n]
    int m = input1.dims()[1];  // inChannels
    int k = input0.size()/m;   // batch size
    int n = input1.size()/m;   // outChannels
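    // Row-major derivation: dW = T(dY) * X, with dY [k x n] and X [k x m].
    // cuBLAS sees the column-major views X_cm = T(X) [m x k] and
    // dY_cm = T(dY) [n x k], so the column-major view of dW is
    // T(dW) [m x n] = X_cm * T(dY_cm), hence the OP_N / OP_T GEMM below.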
    CHECK_CUBLAS_STATUS(cublasGemm(
        CudaContext::cublasHandle(),
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        m,
        n,
        k,
        reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
        input,
        m,
        outputGrad,
        n,
        reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&beta),
        weightsGrad,
        m));

    if(!input2.empty()){
        T * biasGrad = static_cast<T*>(op.getInput(2)->grad()->getImpl()->rawPtr());
        T* onesVector;
        // The ones vector must have k entries (one per batch element), since
        // the GEMV below reduces outputGrad over the batch dimension.
        CHECK_CUDA_STATUS(cudaMalloc(reinterpret_cast<void**>(&onesVector), k * sizeof(T)));
        std::vector<T> onesVec(k, T(1.0));
        CHECK_CUDA_STATUS(cudaMemcpy(onesVector,
                                    onesVec.data(),
                                    k * sizeof(T),
                                    cudaMemcpyHostToDevice));
        // Performing biasGrad = outputGrad * onesVector
        //            [n x 1]  = [n x k]    * [k x 1]    (n = outChannels, k = batch size)
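        // Each bias gradient entry is the sum of the output gradient over the
        // batch: biasGrad[i] = sum_j outputGrad(i, j). Multiplying the
        // column-major [n x k] outputGrad by a ones vector of length k
        // performs exactly this row-wise reduction in a single GEMV.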
        CHECK_CUBLAS_STATUS(cublasGemv(CudaContext::cublasHandle(),
                                       CUBLAS_OP_N,
                                       outChannels,
                                       k,
                                       reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                       outputGrad,
                                       outChannels,
                                       onesVector,
                                       1,
                                       reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&beta),
                                       biasGrad,
                                       1));
        CHECK_CUDA_STATUS(cudaFree(onesVector));
    }

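    // Row-major derivation: dX = dY * W, with dY [k x n] and W [n x m].
    // In column-major views this reads T(dX) [m x k] = T(W) [m x n] * T(dY) [n x k],
    // i.e. a plain OP_N / OP_N GEMM on the buffers exactly as stored.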

    // Performing inputGrad = weights * outputGrad
    //             [m x k]  = [m x n] * [n x k]
    CHECK_CUBLAS_STATUS(cublasGemm(
        CudaContext::cublasHandle(),
        CUBLAS_OP_N,
        CUBLAS_OP_N,
        m,  // rows of inputGrad (inChannels)
        k,  // columns of inputGrad (batch size)
        n,  // inner dimension (outChannels)
        reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
        weights,     // W, column-major view [m x n]
        m,
        outputGrad,  // dY, column-major view [n x k]
        n,
        reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&betaData),
        static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr()),  // dX
        m));

}