diff --git a/src/operator/Conv.cpp b/src/operator/Conv.cpp index d69aad616bcdaedd7ffa9cdb04d02802bb998f5a..8d5f3332218c9fb0dbdb9398de254cfcccae9d7e 100644 --- a/src/operator/Conv.cpp +++ b/src/operator/Conv.cpp @@ -70,14 +70,14 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { unsigned int in_dims_index = (getInput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2; unsigned int out_dims_index = (getOutput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2; - for (std::size_t dim = 0; dim < mAttributes->template getAttr<Attr::KernelDims>().size(); ++dim) { - const DimSize_t kernelExtent = mAttributes->template getAttr<Attr::DilationDims>()[dim] * - (mAttributes->template getAttr<Attr::KernelDims>()[dim] - 1) + + for (std::size_t dim = 0; dim < kernelDims().size(); ++dim) { + const DimSize_t kernelExtent = dilationDims()[dim] * + (kernelDims()[dim] - 1) + 1; outputDims[dim + out_dims_index] = 1 + static_cast<DimSize_t>( floor(static_cast<float>(inputDims[dim + in_dims_index] - kernelExtent) / - static_cast<float>(mAttributes->template getAttr<Attr::StrideDims>()[dim])) + static_cast<float>(strideDims()[dim])) ); } @@ -123,18 +123,18 @@ Aidge::Conv_Op<DIM>::computeReceptiveField( std::vector<DimSize_t> inputDims{outputDims[0], getInput(0)->dims()[1]}; for (DimIdx_t i = 0; i < DIM; ++i) { inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1) - * mAttributes->template getAttr<Attr::StrideDims>()[static_cast<std::size_t>(i)] + * strideDims()[static_cast<std::size_t>(i)] + 1 - + (mAttributes->template getAttr<Attr::KernelDims>()[static_cast<std::size_t>(i)] - 1) - * mAttributes->template getAttr<Attr::DilationDims>()[static_cast<std::size_t>(i)]); - inputIdxDims[2+i] *= mAttributes->template getAttr<Attr::StrideDims>()[static_cast<std::size_t>(i)]; + + (kernelDims()[static_cast<std::size_t>(i)] - 1) + * dilationDims()[static_cast<std::size_t>(i)]); + inputIdxDims[2+i] *= strideDims()[static_cast<std::size_t>(i)]; } // Weight // same output value, every input channel is used std::vector<DimSize_t> weightDims{outputDims[1], getInput(0)->dims()[1]}; for (std::size_t i = 0; i < DIM; ++i) { - weightDims.push_back(mAttributes->template getAttr<Attr::KernelDims>()[i]); + weightDims.push_back(kernelDims()[i]); } std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0); weightIdxDims[0] = firstEltDims[1]; @@ -242,4 +242,4 @@ std::shared_ptr<Aidge::Node> Aidge::Conv( } template std::shared_ptr<Aidge::Node> Aidge::Conv<1>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[1], const std::string&, const std::array<Aidge::DimSize_t, 1>&, const std::array<Aidge::DimSize_t, 1>&, bool); -template std::shared_ptr<Aidge::Node> Aidge::Conv<2>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[2], const std::string&, const std::array<Aidge::DimSize_t, 2>&, const std::array<Aidge::DimSize_t, 2>&, bool); \ No newline at end of file +template std::shared_ptr<Aidge::Node> Aidge::Conv<2>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[2], const std::string&, const std::array<Aidge::DimSize_t, 2>&, const std::array<Aidge::DimSize_t, 2>&, bool);