Skip to content
Snippets Groups Projects
Commit 6fc040ea authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

fix : [Conv] added check to ensure that dilation & stride values were all >= 1

Also better warning message for conv operator constructor
parent db0b1f71
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !319. Comments created here will be created in the context of that merge request.
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#ifndef AIDGE_CORE_OPERATOR_CONV_H_ #ifndef AIDGE_CORE_OPERATOR_CONV_H_
#define AIDGE_CORE_OPERATOR_CONV_H_ #define AIDGE_CORE_OPERATOR_CONV_H_
#include <algorithm>
#include <array> #include <array>
#include <cmath> // std::floor #include <cmath> // std::floor
#include <cstddef> // std::size_t #include <cstddef> // std::size_t
...@@ -237,10 +238,16 @@ std::shared_ptr<Node> Conv(DimSize_t inChannels, ...@@ -237,10 +238,16 @@ std::shared_ptr<Node> Conv(DimSize_t inChannels,
bool noBias = false); bool noBias = false);
/** /**
* @brief Helper function for Conv with C-style arrays. * @brief Perform a convolution on the input Tensor.
* *
* This helper function allows automatic template deduction of the number of dimensions (DIM) * @tparam DIM Number of dimensions for the feature map.
* based on the kernel dimensions provided. * @param inChannels Number of input channels.
* @param outChannels Number of output channels.
* @param kernelDims Dimensions of the kernel. Must be the same number of dimensions as the feature map.
* @param name Name of the operator.
* @param strideDims Dimensions of the stride attribute. Must be the same number of dimensions as the feature map.
* @param dilationDims Dimensions of the dilation attribute. Must be the same number of dimensions as the feature map.
* @return std::shared_ptr<Node> A Node containing the operator.
*/ */
template <DimSize_t DIM> template <DimSize_t DIM>
inline std::shared_ptr<Node> Conv( inline std::shared_ptr<Node> Conv(
...@@ -251,8 +258,22 @@ inline std::shared_ptr<Node> Conv( ...@@ -251,8 +258,22 @@ inline std::shared_ptr<Node> Conv(
const std::array<DimSize_t, DIM> &strideDims = create_array<DimSize_t,DIM>(1), const std::array<DimSize_t, DIM> &strideDims = create_array<DimSize_t,DIM>(1),
const std::array<DimSize_t, DIM> &dilationDims = create_array<DimSize_t,DIM>(1), const std::array<DimSize_t, DIM> &dilationDims = create_array<DimSize_t,DIM>(1),
bool noBias = false) { bool noBias = false) {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported"); AIDGE_ASSERT(DIM<=MaxDim,"{}: Too many kernel dimensions required, maximum allowed : {} ", Conv_Op<DIM>::Type, MaxDim);
return Conv(inChannels, outChannels, to_array(kernelDims), name, strideDims, dilationDims, noBias); AIDGE_ASSERT(!std::any_of(dilationDims.cbegin(),
dilationDims.cend(),
[](DimSize_t val) { return val == 0; }),
"Conv : at least of of the dilation dimension is 0, expecting "
"strictly positive values. Got {}",
Conv_Op<DIM>::Type,
dilationDims);
AIDGE_ASSERT(!std::any_of(strideDims.cbegin(),
strideDims.cend(),
[](DimSize_t val) { return val == 0; }),
"{}: at least one of the stride dimension is 0, expecting "
"strictly positive values. Got {}.",
Conv_Op<DIM>::Type,
strideDims);
return Conv<DIM>(inChannels, outChannels, to_array(kernelDims), name, strideDims, dilationDims, noBias);
} }
} // namespace Aidge } // namespace Aidge
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment