Skip to content
Snippets Groups Projects
Commit 18be0fc3 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add BatchNorm operator

parent c37a52bd
No related branches found
No related tags found
2 merge requests!32version 0.2.1,!14MobileNet operators
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_BACKEND_CUDA_OPERATOR_BATCHNORMIMPL_H_
#define AIDGE_BACKEND_CUDA_OPERATOR_BATCHNORMIMPL_H_
#include <array>
#include <memory>
#include <tuple>
#include <vector>
#include <cudnn.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
template <DimIdx_t DIM>
class BatchNormImpl_cuda : public OperatorImpl {
private:
// CuDNN specific variables
cudnnTensorDescriptor_t mBNDesc = nullptr;
cudnnBatchNormMode_t mMode;
double mEpsilon;
public:
BatchNormImpl_cuda(const BatchNorm_Op<DIM> &op) : OperatorImpl(op, "cuda") {}
static std::unique_ptr<BatchNormImpl_cuda> create(const BatchNorm_Op<DIM> &op) {
return std::make_unique<BatchNormImpl_cuda>(op);
}
public:
void forward();
~BatchNormImpl_cuda();
private:
template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, const Tensor& input3, const Tensor& input4);
};
namespace {
// add cuda backend to BatchNorm_Op<2> implementation registry
static Registrar<BatchNorm_Op<2>> registrarBatchNormImpl_cuda("cuda", Aidge::BatchNormImpl_cuda<2>::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_BACKEND_CUDA_OPERATOR_BATCHNORMIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include <iostream>
#include "aidge/utils/Types.h"
#include "aidge/operator/BatchNorm.hpp"
#include <cuda_runtime.h>
#include <cudnn.h>
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/BatchNormImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
template <Aidge::DimIdx_t DIM>
void Aidge::BatchNormImpl_cuda<DIM>::forward() {
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
assert(mOp.getRawInput(3) && "missing input #3");
assert(mOp.getRawInput(4) && "missing input #4");
std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback, input3Fallback, input4Fallback;
const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input3 = std::static_pointer_cast<Tensor>(mOp.getRawInput(3))->refCastFrom(input3Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input4 = std::static_pointer_cast<Tensor>(mOp.getRawInput(4))->refCastFrom(input4Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
if (mBNDesc == nullptr)
{
const BatchNorm_Op<DIM>& bnOp = static_cast<const BatchNorm_Op<DIM>&>(mOp);
mEpsilon = static_cast<double>(bnOp.template getAttr<BatchNormAttr::Epsilon>());
mMode = CUDNN_BATCHNORM_SPATIAL;
// CUDNN_BN_MIN_EPSILON is set to 0.0 since cuDNN 7.5.0
if (CUDNN_BN_MIN_EPSILON > 0.0 && mEpsilon < CUDNN_BN_MIN_EPSILON) {
mEpsilon = CUDNN_BN_MIN_EPSILON;
}
CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mBNDesc));
// auto tst = dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc();
CHECK_CUDNN_STATUS(cudnnDeriveBNTensorDescriptor(
mBNDesc, std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0), mMode));
cudnnDataType_t dataType;
const unsigned int nbDimsRequested = DIM;
std::vector<int> dims(nbDimsRequested);
std::vector<int> strides(nbDimsRequested);
int nbDims;
CHECK_CUDNN_STATUS(cudnnGetTensorNdDescriptor(mBNDesc,
nbDimsRequested,
&dataType,
&nbDims,
&dims[0],
&strides[0]));
dims.resize(nbDims);
strides.resize(nbDims);
}
if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
forward_<double>(input0, input1, input2, input3, input4);
}
else {
forward_<float>(input0, input1, input2, input3, input4);
}
}
template <Aidge::DimIdx_t DIM>
template <class T>
void Aidge::BatchNormImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, const Tensor& input3, const Tensor& input4) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
CHECK_CUDNN_STATUS(
cudnnBatchNormalizationForwardInference(
CudaContext::cudnnHandle(),
mMode,
&alpha,
&beta,
std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
input0.getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(input1.getImpl())->getCudnnTensorDesc(input1),//scaleBiasMeanVarDesc,
input1.getImpl()->rawPtr(),
input2.getImpl()->rawPtr(),
input3.getImpl()->rawPtr(),
input4.getImpl()->rawPtr(),
mEpsilon)
);
}
template <Aidge::DimIdx_t DIM>
Aidge::BatchNormImpl_cuda<DIM>::~BatchNormImpl_cuda() {
if(mBNDesc != nullptr)
{
cudnnDestroyTensorDescriptor(mBNDesc);
}
}
// Template declarations
template class Aidge::BatchNormImpl_cuda<2>;
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <array>
#include <catch2/catch_test_macros.hpp>
#include "Test_cuda.hpp"
#include <iostream>
#include "aidge/data/Tensor.hpp"
#include "aidge/backend/cpu.hpp"
#include "aidge/backend/cuda.hpp"
using namespace Aidge;
TEST_CASE("[gpu/operator] BatchNorm(forward)") {
std::shared_ptr<Node> myBatchNorm = BatchNorm<2>(3, 0.00001F, 0.1F, "mybatchnorm");
auto op = std::static_pointer_cast<OperatorTensor>(myBatchNorm -> getOperator());
op->setDataType(DataType::Float32);
op->setBackend("cuda");
std::shared_ptr<Tensor> myWeights= std::make_shared<Tensor>(Array2D<float,1,3> {{{0.9159252643585205, 0.18772238492965698, 0.4479946792125702}}});
std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array2D<float,1,3> {{{0.33898890018463135, 0.3167555630207062, 0.7047033309936523}}});
std::shared_ptr<Tensor> myMean = std::make_shared<Tensor>(Array2D<float,1,3> {{{0.45547693967819214, 0.22650663554668427, 0.6612948179244995}}});
std::shared_ptr<Tensor> myVar = std::make_shared<Tensor>(Array2D<float,1,3> {{{0.02570258639752865, 0.026536229997873306, 0.15111008286476135}}});
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,3,3,3> { //NCHW
{
{
{{0.12943482, 0.6451229 , 0.24979436},
{0.7551012, 0.32007095, 0.89463896},
{0.7087448, 0.6266124, 0.4782957 }},
{{0.13796203, 0.9950787, 0.71555305},
{0.01347321, 0.4395316, 0.43097174},
{0.6056306 , 0.9561122 , 0.5783939 }},
{{0.7174486 , 0.503465 , 0.23695093},
{0.5145477, 0.39576462, 0.02779444},
{0.60789394 ,0.14119725 ,0.20753163}}
},
{{{0.74452287, 0.5354875 , 0.8148496 },
{0.73356223, 0.4304034 , 0.11783765},
{0.8966221, 0.41049036, 0.95982736}},
{{0.03161403, 0.71250844, 0.14337301},
{0.5338889 , 0.13484782, 0.8055851 },
{0.71784616 ,0.8349626 , 0.10107189}},
{{0.85701346, 0.58286697, 0.9836816 },
{0.36061534, 0.03660944, 0.7375317 },
{0.6977233, 0.51965624, 0.29440993}}
}
}
});
std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,2,3,3,3> {
{
{
{{-1.5233592, 1.4222438, -0.83586717},
{ 2.0504384, -0.43444824, 2.847476 },
{ 1.7856512, 1.3165123, 0.46932936}},
{{ 0.21473758 , 1.2022772, 0.8802177 },
{ 0.07130594 , 0.5621954, 0.55233306},
{ 0.7535689 , 1.1573814, 0.72218764}},
{{ 0.7694162 , 0.52281666, 0.2156798 },
{ 0.5355886 , 0.3987003, -0.02535689},
{ 0.6431629 , 0.10533108 , 0.18177633}}},
{{{ 1.990015, 0.7960079, 2.3917203 },
{ 1.9274082, 0.19576907, -1.5896021 },
{ 2.8588037 , 0.08202624 , 3.2198315 }},
{{ 0.09220716, 0.8767097, 0.22097193},
{ 0.6709106 , 0.2111495, 0.9839494 },
{ 0.8828597 , 1.0177971 , 0.17223406}},
{{ 0.9302539 , 0.6143213 , 1.0762292 },
{ 0.35819346, -0.01519828, 0.79256046},
{ 0.7466844 , 0.5414758 , 0.28189686}}
}
}
});
myInput->setBackend("cuda");
myWeights->setBackend("cuda");
myBias->setBackend("cuda");
myMean->setBackend("cuda");
myVar->setBackend("cuda");
op->associateInput(0,myInput);
op->associateInput(1,myWeights);
op->associateInput(2,myBias);
op->associateInput(3,myMean);
op->associateInput(4,myVar);
op->computeOutputDims();
op->forward();
float* computedOutput = new float[myOutput->size()]();
cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
for(int i = 0; i < myOutput->size(); i++){
const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
std::cout << "Computed : " << computedOutput[i] << " , target: " << targetOutput << std::endl;
REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-5);
}
delete[] computedOutput;
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment