-
Houssem ROUIS authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
MaxPoolingImpl.cpp 3.46 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
template <Aidge::DimIdx_t DIM>
/**
 * @brief Run max pooling on input #0 and write the result into output #0.
 *
 * Input #0 is first brought in line with output #0 via refCastFrom().
 * On the first call, the cuDNN pooling descriptor is built from the
 * operator's StrideDims / KernelDims attributes; padding is always 0.
 * The actual cuDNN call is delegated to forward_<T>(): T is double when
 * output #0 is Float64, and float for every other data type.
 */
void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);

    assert(mOp.getRawInput(0) && "missing input #0");

    // NOTE(review): refCastFrom presumably returns either input #0 itself or a
    // converted copy owned by inputFallback (matching output #0) — confirm in
    // the Tensor API. inputFallback must outlive every use of `input`.
    std::shared_ptr<Tensor> inputFallback;
    const auto& input = std::static_pointer_cast<Tensor>(op.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(op.getRawOutput(0)));

    // Lazy-initialize the cuDNN max-pooling descriptor on the first call.
    if (mMaxPoolingDesc == nullptr) {
        const MaxPooling_Op<DIM>& maxPoolingOp = static_cast<const MaxPooling_Op<DIM>&>(op);

        // Bind each attribute once so begin()/end() are taken from the same object.
        const auto& strideDims = maxPoolingOp.template getAttr<MaxPoolingAttr::StrideDims>();
        const auto& kernelDims = maxPoolingOp.template getAttr<MaxPoolingAttr::KernelDims>();

        // cuDNN expects plain int arrays; this implementation applies no padding.
        const std::vector<int> strides(strideDims.begin(), strideDims.end());
        const std::vector<int> paddings(DIM, 0);
        const std::vector<int> window_dims(kernelDims.begin(), kernelDims.end());

        // Configure a local handle and publish it only once it is fully set.
        // Otherwise a failing set call would leave a created-but-unconfigured
        // descriptor in mMaxPoolingDesc, and the nullptr guard above would
        // never retry the configuration.
        decltype(mMaxPoolingDesc) poolingDesc = nullptr;
        CHECK_CUDNN_STATUS(cudnnCreatePoolingDescriptor(&poolingDesc));
        try {
            CHECK_CUDNN_STATUS(
                cudnnSetPoolingNdDescriptor(poolingDesc,
                                            mMode,
                                            CUDNN_NOT_PROPAGATE_NAN,
                                            DIM,
                                            window_dims.data(),
                                            paddings.data(),
                                            strides.data()));
        } catch (...) {
            // Release the half-configured handle, then propagate the original error.
            cudnnDestroyPoolingDescriptor(poolingDesc);
            throw;
        }
        mMaxPoolingDesc = poolingDesc;
    }

    // Float64 outputs run in double precision; all other types use the float path.
    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
        forward_<double>(input);
    }
    else {
        forward_<float>(input);
    }
}
template <Aidge::DimIdx_t DIM>
template <class T>
/**
 * @brief Issue the cuDNN pooling forward call for scalar type T.
 *
 * @param input Tensor read as the pooling source (x); the destination (y)
 *              is output #0 of the operator.
 */
void Aidge::MaxPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
    const OperatorTensor& opTensor = static_cast<const OperatorTensor&>(mOp);

    // cuDNN blends as y = alpha * pool(x) + beta * y; (1, 0) overwrites y.
    const T alpha = static_cast<T>(1);
    const T beta = static_cast<T>(0);

    // Source side: cuDNN descriptor from the CUDA tensor impl, plus raw data.
    const auto inputCudaImpl = std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl());
    const auto inputDesc = inputCudaImpl->getCudnnTensorDesc(input);
    const auto inputData = input.getImpl()->rawPtr();

    // Destination side: the same pair, taken from output #0.
    const auto& outputTensor = opTensor.getOutput(0);
    const auto outputCudaImpl = std::dynamic_pointer_cast<TensorImpl_cuda_>(outputTensor->getImpl());
    const auto outputDesc = outputCudaImpl->getCudnnTensorDesc(*outputTensor);
    const auto outputData = std::static_pointer_cast<Tensor>(opTensor.getRawOutput(0))->getImpl()->rawPtr();

    CHECK_CUDNN_STATUS(cudnnPoolingForward(CudaContext::cudnnHandle(),
                                           mMaxPoolingDesc,
                                           &alpha,
                                           inputDesc,
                                           inputData,
                                           &beta,
                                           outputDesc,
                                           outputData));
}
template <Aidge::DimIdx_t DIM>
/**
 * @brief Release the cuDNN pooling descriptor, if forward() ever created one.
 *
 * The destroy call's status is intentionally not checked: a destructor must
 * not throw.
 */
Aidge::MaxPoolingImpl_cuda<DIM>::~MaxPoolingImpl_cuda() {
    // Nothing to release when forward() was never run.
    if (mMaxPoolingDesc == nullptr) {
        return;
    }
    cudnnDestroyPoolingDescriptor(mMaxPoolingDesc);
}
// Explicit instantiation definition: the member templates above are defined in
// this .cpp, so this file instantiates only the 2-D variant. Any other DIM value
// would need its own explicit instantiation line here.
template class Aidge::MaxPoolingImpl_cuda<2>;