Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
MaxPoolingImpl.cpp 3.46 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <cassert>
#include <chrono>  // std::chrono::milliseconds
#include <memory>  // std::shared_ptr, std::static_pointer_cast
#include <numeric> // std::accumulate
#include <stdexcept> // std::runtime_error
#include <thread>  // std::this_thread::sleep_for
#include <vector>

#include "aidge/utils/Types.h"
#include "aidge/operator/MaxPooling.hpp"

#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"

/**
 * @brief Run max pooling on input #0 and write the result into output #0.
 *
 * On the first call, creates and configures the cuDNN pooling descriptor from
 * the operator's KernelDims and StrideDims attributes. Padding is fixed to 0
 * in every dimension. Input #0 is passed through refCastFrom() against
 * output #0 before the computation. The cuDNN scaling factors are double when
 * output #0 is Float64 and float otherwise.
 *
 * @throws std::runtime_error if input #0 is not connected.
 */
template <Aidge::DimIdx_t DIM>
void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);

    // A plain assert() vanishes under NDEBUG, after which a missing input
    // would be dereferenced below; keep the check in every build type.
    if (!mOp.getRawInput(0)) {
        throw std::runtime_error("missing input #0");
    }

    // refCastFrom() may place a converted copy of input #0 in inputFallback,
    // and `input` may then refer to it, so it must outlive forward_() below.
    // NOTE(review): cudnnPoolingForward() is enqueued on a stream; confirm
    // that releasing inputFallback when this function returns cannot free
    // device memory the kernel is still reading.
    std::shared_ptr<Tensor> inputFallback;
    const auto& input = std::static_pointer_cast<Tensor>(op.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(op.getRawOutput(0)));

    // Lazy-initialize CuDNN MaxPooling descriptor (only on the first call)
    if (mMaxPoolingDesc == nullptr) {
        const MaxPooling_Op<DIM>& maxPoolingOp = static_cast<const MaxPooling_Op<DIM>&>(op);
        const auto& strideAttr = maxPoolingOp.template getAttr<MaxPoolingAttr::StrideDims>();
        const auto& kernelAttr = maxPoolingOp.template getAttr<MaxPoolingAttr::KernelDims>();

        // cuDNN expects int arrays of length DIM.
        const std::vector<int> strides(strideAttr.begin(), strideAttr.end());
        const std::vector<int> paddings(DIM, 0);
        const std::vector<int> window_dims(kernelAttr.begin(), kernelAttr.end());

        CHECK_CUDNN_STATUS(cudnnCreatePoolingDescriptor(&mMaxPoolingDesc));
        CHECK_CUDNN_STATUS(
            cudnnSetPoolingNdDescriptor(mMaxPoolingDesc,
                                        mMode,
                                        CUDNN_NOT_PROPAGATE_NAN,
                                        DIM,
                                        window_dims.data(),
                                        paddings.data(),
                                        strides.data()));
    }

    // The template argument only selects the type of the alpha/beta scaling
    // factors: double for Float64 outputs, float for every other data type.
    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
        forward_<double>(input);
    }
    else {
        forward_<float>(input);
    }
}

/**
 * @brief Enqueue the cuDNN max-pooling computation for one scaling precision.
 *
 * @tparam T Type of the alpha/beta scaling factors handed to
 *           cudnnPoolingForward(). forward() picks double when output #0 is
 *           Float64 and float otherwise.
 * @param input Input #0 as prepared (refCastFrom) by forward().
 *
 * The factors are alpha = 1 and beta = 0. Under cuDNN's blending rule
 * y = alpha * pool(x) + beta * y, output #0 is overwritten rather than
 * accumulated into.
 */
template <Aidge::DimIdx_t DIM>
template <class T>
void Aidge::MaxPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
    const auto& opTensor = static_cast<const OperatorTensor&>(mOp);

    const T one = static_cast<T>(1);
    const T zero = static_cast<T>(0);

    // Source side: descriptor + device pointer of the (possibly cast) input.
    const auto inputDesc = std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input);
    const auto inputData = input.getImpl()->rawPtr();

    // Destination side: descriptor + device pointer of output #0.
    const auto& outputPtr = opTensor.getOutput(0);
    const auto outputDesc = std::dynamic_pointer_cast<TensorImpl_cuda_>(outputPtr->getImpl())->getCudnnTensorDesc(*outputPtr);
    const auto outputData = std::static_pointer_cast<Tensor>(opTensor.getRawOutput(0))->getImpl()->rawPtr();

    CHECK_CUDNN_STATUS(cudnnPoolingForward(CudaContext::cudnnHandle(),
                                           mMaxPoolingDesc,
                                           &one, inputDesc, inputData,
                                           &zero, outputDesc, outputData));
}

/**
 * @brief Release the cuDNN pooling descriptor, if one was ever created.
 *
 * In this file the descriptor is only created lazily by forward(), so it is
 * still null when forward() never ran. The status returned by
 * cudnnDestroyPoolingDescriptor() is not checked here.
 */
template <Aidge::DimIdx_t DIM>
Aidge::MaxPoolingImpl_cuda<DIM>::~MaxPoolingImpl_cuda() {
    if (mMaxPoolingDesc == nullptr) {
        return;
    }
    cudnnDestroyPoolingDescriptor(mMaxPoolingDesc);
}


// Explicit template instantiation: the member function templates above are
// defined only in this .cpp file, so this line emits them for DIM == 2.
// Only DIM == 2 is instantiated in this file.
template class Aidge::MaxPoolingImpl_cuda<2>;