Skip to content
Snippets Groups Projects

Lenet operators

Merged Houssem ROUIS requested to merge hrouis/aidge_backend_cuda:lenet_operators into dev
3 files
+ 23
14
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -24,14 +24,16 @@
template <Aidge::DimIdx_t DIM>
void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
assert(mOp.getRawInput(0) && "missing input #0");
std::shared_ptr<Tensor> inputFallback;
const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input = std::static_pointer_cast<Tensor>(op.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
// Lazy-initialize CuDNN AvgPooling descriptor
if (mAvgPoolingDesc == nullptr) {
const AvgPooling_Op<DIM>& avgPoolingOp = static_cast<const AvgPooling_Op<DIM>&>(mOp);
const AvgPooling_Op<DIM>& avgPoolingOp = static_cast<const AvgPooling_Op<DIM>&>(op);
const std::vector<int> strides(avgPoolingOp.template getAttr<AvgPoolingAttr::StrideDims>().begin(), avgPoolingOp.template getAttr<AvgPoolingAttr::StrideDims>().end());
const std::vector<int> paddings(DIM, 0);
const std::vector<int> window_dims(avgPoolingOp.template getAttr<AvgPoolingAttr::KernelDims>().begin(), avgPoolingOp.template getAttr<AvgPoolingAttr::KernelDims>().end());
@@ -58,6 +60,7 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
template <Aidge::DimIdx_t DIM>
template <class T>
void Aidge::AvgPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const T alpha = 1.0f;
const T beta = 0.0f;
CHECK_CUDNN_STATUS(
@@ -65,11 +68,11 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
CudaContext::cudnnHandle(),
mAvgPoolingDesc,
&alpha,
dynamic_cast<TensorImpl_cuda_*>(input.getImpl().get())->getCudnnTensorDesc(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
input.getImpl()->rawPtr(),
&beta,
dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()
)
);
}
Loading