/********************************************************************************
 * Copyright (c) 2024 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

 #include <algorithm>
 #include <cassert>
 #include <numeric>
 #include <vector>
 
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/BitShiftImpl.hpp"
 #include "aidge/backend/cuda/operator/BitShiftImpl_CUDA_kernels.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
 #include "aidge/operator/BitShift.hpp"
 #include "aidge/utils/Types.h"
 
 void Aidge::BitShiftImpl_cuda::forward() {
     const BitShift_Op& op = static_cast<const BitShift_Op&>(mOp);
     // Check inputs
     AIDGE_ASSERT(op.getInput(0), "missing input in BitShift operator");
     AIDGE_ASSERT(op.getInput(0)->hasImpl(), "cannot run BitShift forward because the 0-th input has no implementation.");
     DataType datatypeFirstInput = op.getInput(0)->dataType();
     for (IOIndex_t i = 1; i < op.nbInputs(); ++i) {
         AIDGE_ASSERT(op.getInput(i), "missing input in BitShift operator");
         AIDGE_ASSERT(op.getInput(i)->hasImpl(), "cannot run BitShift forward because the {}-th input has no implementation.", i);
        AIDGE_ASSERT(op.getInput(i)->dataType() == datatypeFirstInput, "Cannot BitShift inputs with two different data types.");
     }
 
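    // Prepare each input for the kernel: cast it to a supported integer type when
    // needed, then compute its broadcast dimensions and the matching strides.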
     std::vector<std::shared_ptr<Tensor>> inputFallbacks(op.nbInputs());
     std::vector<Tensor> inputs(op.nbInputs());
     std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
    std::vector<std::vector<int>> strides(op.nbInputs()); // For the corresponding strides
     for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
        // TODO: remove the forced cast to Int32
        const auto dt = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
        if (dt == DataType::Float32 || dt == DataType::Float64) {
            inputs[i] = op.getInput(i)->refCast(inputFallbacks[i], DataType::Int32);
        } else {
            inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));
        }
         // Get tensor dims and broadcast them
         std::copy(inputs[i].dims().begin(), inputs[i].dims().end(), std::back_inserter(dims[i]));
         dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
 
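        // Pad to at least 4 dimensions (the CUDA kernel appears to assume a 4-D layout)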
         if (dims[i].size() < 4) {
             dims[i].resize(4, 1);
         }
 
         // Compute the corresponding strides
         std::vector<int> tensorStrides(dims[i].size());
         int product = 1;
         for (size_t j = dims[i].size(); j > 0; --j) {
             tensorStrides[j - 1] = product;
             product *= dims[i][j - 1];
         }
         strides[i] = tensorStrides;
     }
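    // Select the shift direction and dispatch on the (possibly casted) input data type;
    // the kernel only supports integer shifts (Int32 and Int64).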
     bool left = op.direction() == BitShift_Op::BitShiftDirection::left;
     switch(inputs[0].dataType()) {
         case DataType::Int64:
             forward_<int64_t>(inputs, dims, strides, left);
             break;
         case DataType::Int32:
             forward_<int32_t>(inputs, dims, strides, left);
             break;
         default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "BitShift: data type is not supported by the CUDA backend.");
     }
 }
 
 template <class T>
 void Aidge::BitShiftImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides, bool left) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
    const T* input1Ptr = static_cast<const T*>(inputs[0].getImpl()->rawPtr());
    const T* input2Ptr = static_cast<const T*>(inputs[1].getImpl()->rawPtr());

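    // The kernel produces integer results; for floating-point outputs, run it on an
    // Int32 view of the output and cast back afterwards (mirroring the input casting above).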
    std::shared_ptr<Tensor> outputFallback;
    const auto dt = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
    Tensor outputCasted;
    if (dt == DataType::Float32 || dt == DataType::Float64) {
        outputCasted = op.getOutput(0)->refCastFrom(outputFallback, DataType::Int32, "cuda", op.getOutput(0)->device());
    } else {
        outputCasted = op.getOutput(0)->refCastFrom(outputFallback, *op.getOutput(0));
    }
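    // Compute row-major strides for the output (innermost dimension contiguous),
    // then its dimensions padded to at least 4, matching the input padding above.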
    std::vector<int> outputStrides(op.getOutput(0)->nbDims(), 1);
    if (op.getOutput(0)->nbDims() > 1) {
        for (int i = op.getOutput(0)->nbDims() - 2; i >= 0; i--) {
            outputStrides[i] = outputStrides[i + 1] * op.getOutput(0)->dims()[i + 1];
        }
    }
    std::vector<int> outDims(std::max(op.getOutput(0)->nbDims(), std::size_t(4)), 1);
    for (std::size_t i = 0; i < op.getOutput(0)->nbDims(); i++) {
        outDims[i] = static_cast<int>(op.getOutput(0)->dims()[i]);
    }

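    // Launch the element-wise bit-shift kernel with broadcast-aware dims and strides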
    Aidge::bitShiftForward<T>(input1Ptr, input2Ptr, static_cast<T*>(outputCasted.getImpl()->rawPtr()),
                inputsDims[0], inputsDims[1], outDims,
                inputsStrides[0], inputsStrides[1], outputStrides,
                static_cast<int>(op.getOutput(0)->size()), left);

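    // Copy the result back into the real output buffer, casting back to float when needed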
    if (dt == DataType::Float32 || dt == DataType::Float64) {
        op.getOutput(0)->getImpl()->copyCast(outputCasted.getImpl()->rawPtr(), DataType::Int32, outputCasted.size());
    } else {
        CHECK_CUDA_STATUS(cudaMemcpy(op.getOutput(0)->getImpl()->rawPtr(), outputCasted.getImpl()->rawPtr(), outputCasted.size() * sizeof(T), cudaMemcpyDeviceToDevice));
    }
}