/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/BitShiftImpl.hpp"
#include "aidge/backend/cuda/operator/BitShiftImpl_CUDA_kernels.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/BitShift.hpp"
#include "aidge/utils/Types.h"

void Aidge::BitShiftImpl_cuda::forward() {
    const BitShift_Op& op = static_cast<const BitShift_Op&>(mOp);

    // Check inputs
    AIDGE_ASSERT(op.getInput(0), "missing input in BitShift operator");
    AIDGE_ASSERT(op.getInput(0)->hasImpl(), "cannot run BitShift forward because the 0-th input has no implementation.");
    DataType datatypeFirstInput = op.getInput(0)->dataType();
    for (IOIndex_t i = 1; i < op.nbInputs(); ++i) {
        AIDGE_ASSERT(op.getInput(i), "missing input in BitShift operator");
        AIDGE_ASSERT(op.getInput(i)->hasImpl(), "cannot run BitShift forward because the {}-th input has no implementation.", i);
        AIDGE_ASSERT(op.getInput(i)->dataType() == datatypeFirstInput, "Cannot BitShift inputs with two different data types.");
    }
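
    // Prepare broadcast-compatible views of the inputs: cast each input to the
    // kernel's integer type when needed, left-pad its dims with 1s up to the
    // output rank, and derive the matching row-major strides.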
    std::vector<std::shared_ptr<Tensor>> inputFallbacks(op.nbInputs());
    std::vector<Tensor> inputs(op.nbInputs());
    std::vector<std::vector<int>> dims(op.nbInputs());    // For broadcasted dims
    std::vector<std::vector<int>> strides(op.nbInputs()); // For the corresponding strides
    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
        // TODO: remove the forced cast to Int32
        const auto dt = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
        if (dt == DataType::Float32 || dt == DataType::Float64) {
            // Bit shifts are only defined on integers, so floating-point
            // graphs fall back to an Int32 working copy of the inputs.
            inputs[i] = op.getInput(i)->refCast(inputFallbacks[i], DataType::Int32);
        }
        else {
            inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));
        }

        // Get tensor dims and broadcast them
        std::copy(inputs[i].dims().begin(), inputs[i].dims().end(), std::back_inserter(dims[i]));
        dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
        if (dims[i].size() < 4) {
            dims[i].resize(4, 1);
        }
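        // e.g. an input of dims (3, 4) against a 3-D output is first padded to
        // (1, 3, 4), then to (1, 3, 4, 1) so the kernel always sees 4-D shapes.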

        // Compute the corresponding strides
        std::vector<int> tensorStrides(dims[i].size());
        int product = 1;
        for (size_t j = dims[i].size(); j > 0; --j) {
            tensorStrides[j - 1] = product;
            product *= dims[i][j - 1];
        }
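        // e.g. dims (1, 3, 4, 1) yield row-major strides (12, 4, 1, 1)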
        strides[i] = tensorStrides;
    }

    const bool left = op.direction() == BitShift_Op::BitShiftDirection::left;
    switch (inputs[0].dataType()) {
        case DataType::Int64:
            forward_<int64_t>(inputs, dims, strides, left);
            break;
        case DataType::Int32:
            forward_<int32_t>(inputs, dims, strides, left);
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
    }
}

template <class T>
void Aidge::BitShiftImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides, bool left) {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
    const T* input1Ptr = static_cast<const T*>(inputs[0].getImpl()->rawPtr());
    const T* input2Ptr = static_cast<const T*>(inputs[1].getImpl()->rawPtr());
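
    // Mirror the input handling: when the graph's output tensor is floating
    // point, the kernel writes to an Int32 working buffer that is cast back
    // into the real output at the end of this function.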
    std::shared_ptr<Tensor> outputFallback;
    const auto dt = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
    Tensor outputCasted;
    if (dt == DataType::Float32 || dt == DataType::Float64) {
        outputCasted = op.getOutput(0)->refCastFrom(outputFallback, DataType::Int32, "cuda", op.getOutput(0)->device());
    } else {
        outputCasted = op.getOutput(0)->refCastFrom(outputFallback, *op.getOutput(0));
    }

    // Row-major strides of the output tensor
    std::vector<int> outputStrides(op.getOutput(0)->nbDims(), 1);
    if (op.getOutput(0)->nbDims() > 1) {
        for (int i = op.getOutput(0)->nbDims() - 2; i >= 0; i--) {
            outputStrides[i] = outputStrides[i + 1] * op.getOutput(0)->dims()[i + 1];
        }
    }

    // Pad the output dims to at least 4-D, as the kernel expects
    std::vector<int> outDims(std::max(op.getOutput(0)->nbDims(), std::size_t(4)), 1);
    for (std::size_t i = 0; i < op.getOutput(0)->nbDims(); i++) {
        outDims[i] = static_cast<int>(op.getOutput(0)->dims()[i]);
    }

    Aidge::bitShiftForward<T>(input1Ptr, input2Ptr, static_cast<T*>(outputCasted.getImpl()->rawPtr()),
                              inputsDims[0], inputsDims[1], outDims,
                              inputsStrides[0], inputsStrides[1], outputStrides,
                              static_cast<int>(op.getOutput(0)->size()), left);
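
    // Write the result back into the user-visible output tensor.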
    if (dt == DataType::Float32 || dt == DataType::Float64) {
        // Cast the Int32 working buffer back to the floating-point output
        op.getOutput(0)->getImpl()->copyCast(outputCasted.getImpl()->rawPtr(), DataType::Int32, outputCasted.size());
    } else {
        // Use sizeof(T) (not sizeof(int)) so Int64 outputs are copied in full
        CHECK_CUDA_STATUS(cudaMemcpy(op.getOutput(0)->getImpl()->rawPtr(), outputCasted.getImpl()->rawPtr(), outputCasted.size() * sizeof(T), cudaMemcpyDeviceToDevice));
    }
}