Skip to content
Snippets Groups Projects
Commit db0b1f71 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

chore : updated conv forward dims attribute accession functions

parent dc105c9b
No related branches found
No related tags found
No related merge requests found
...@@ -57,14 +57,12 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { ...@@ -57,14 +57,12 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) {
std::array<DimSize_t, DIM + 2> outputDims{}; std::array<DimSize_t, DIM + 2> outputDims{};
const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>()); const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>());
for (std::size_t dim = 0; dim < mAttributes->template getAttr<ConvAttr::KernelDims>().size() ; ++dim) { for (std::size_t dim = 0; dim < kernelDims().size() ; ++dim) {
const DimSize_t kernelExtent = mAttributes->template getAttr<ConvAttr::DilationDims>()[dim] * const DimSize_t kernelExtent = dilationDims()[dim] * (kernelDims()[dim] - 1) + 1;
(mAttributes->template getAttr<ConvAttr::KernelDims>()[dim] - 1) +
1;
outputDims[dim+2] = 1 + static_cast<DimSize_t>( outputDims[dim+2] = 1 + static_cast<DimSize_t>(
floor(static_cast<float>(inputDims[dim+2] - kernelExtent) / floor(static_cast<float>(inputDims[dim+2] - kernelExtent) /
static_cast<float>(mAttributes->template getAttr<ConvAttr::StrideDims>()[dim]))); static_cast<float>(strideDims()[dim])));
} }
outputDims[1] = outChannels(); outputDims[1] = outChannels();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment