/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cassert>
#include <memory>
#include <numeric>
#include <vector>
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/AddImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/utils/Types.h"
void Aidge::AddImpl_cuda::forward() {
    const Add_Op& op = static_cast<const Add_Op&>(mOp);

    // Check inputs: each input must exist, have an implementation and share
    // the first input's data type.
    assert(op.getInput(0) && "missing input in Add operator");
    AIDGE_ASSERT(op.getInput(0)->hasImpl(), "cannot run Add forward because the 0-th input has no implementation.");
    DataType datatypeFirstInput = op.getInput(0)->dataType();
    for (IOIndex_t i = 1; i < op.nbInputs(); ++i) {
        assert(op.getInput(i) && "missing input in Add operator");
        AIDGE_ASSERT(op.getInput(i)->hasImpl(), "cannot run Add forward because the {}-th input has no implementation.", i);
        assert(op.getInput(i)->dataType() == datatypeFirstInput && "Add operator inputs must have the same data type");
    }
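
    // Gather each input in the output's data type/backend, then compute its
    // dims and strides in broadcast form so it can be described to cuDNN
    // with a full-rank tensor descriptor.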
    std::vector<std::shared_ptr<Tensor>> inputFallbacks(op.nbInputs());
    std::vector<Tensor> inputs(op.nbInputs());
    std::vector<std::vector<int>> dims(op.nbInputs());    // broadcasted dims
    std::vector<std::vector<int>> strides(op.nbInputs()); // corresponding strides
    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
        inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));

        // Get the tensor dims and pad them on the left with 1s up to the
        // output's rank, following broadcasting rules
        std::copy(inputs[i].dims().begin(), inputs[i].dims().end(), std::back_inserter(dims[i]));
        dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
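
        // Example (hypothetical shapes): with a 4-D output, an input of dims
        // {3, 4} is padded to {1, 1, 3, 4}; the row-major strides computed
        // below are then {12, 12, 4, 1}.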
        // Compute the corresponding strides
        std::vector<int> tensorStrides(dims[i].size());
        int product = 1;
        for (size_t j = dims[i].size(); j > 0; --j) {
            tensorStrides[j - 1] = product;
            product *= dims[i][j - 1];
        }
        strides[i] = tensorStrides;
    }
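
    // Dispatch on the output's data type; the inputs were already cast to it
    // by refCastFrom() above.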
    switch (op.getOutput(0)->dataType()) {
        case DataType::Float64:
            forward_<double>(inputs, dims, strides);
            break;
        case DataType::Float32:
            forward_<float>(inputs, dims, strides);
            break;
        case DataType::Float16:
            forward_<half>(inputs, dims, strides);
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
    }
}

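// Typed implementation: computes out = in_0 + in_1 + ... + in_{n-1} through
// successive cudnnAddTensor calls.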
template <class T>
void Aidge::AddImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;

    // Create a tensor descriptor with the broadcasted dims and strides
    cudnnTensorDescriptor_t tensorDesc;
    CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc));
    CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, static_cast<int>(inputsDims[0].size()), inputsDims[0].data(), inputsStrides[0].data()));
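
    // cudnnAddTensor computes C = alpha*A + beta*C; with beta = 0, this first
    // call simply copies the (broadcast) first input into the output.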
    // Add first input
    CHECK_CUDNN_STATUS(
        cudnnAddTensor(CudaContext::cudnnHandle(),
                       &alpha,
                       tensorDesc,
                       inputs[0].getImpl()->rawPtr(),
                       &beta,
                       std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                       std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
    );
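
    // Note: in the loop below, &alpha (1.0) is also passed as the beta
    // scaling factor, so each cudnnAddTensor call accumulates into the
    // output: C = 1*A_i + 1*C.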
    // Add other inputs if there are any
    for (size_t i = 1; i < op.nbInputs(); ++i) {
        CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, static_cast<int>(inputsDims[i].size()), inputsDims[i].data(), inputsStrides[i].data()));
        CHECK_CUDNN_STATUS(
            cudnnAddTensor(CudaContext::cudnnHandle(),
                           &alpha,
                           tensorDesc,
                           inputs[i].getImpl()->rawPtr(),
                           &alpha,
                           std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                           std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
        );
    }
    CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc));
}