Skip to content
Snippets Groups Projects
Commit d339a1f5 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

chore : [Conv] updated the method for attribute accession to ease code reading

parent 3cd8d3fe
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !319. Comments created here will be created in the context of that merge request.
...@@ -70,14 +70,14 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { ...@@ -70,14 +70,14 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) {
unsigned int in_dims_index = (getInput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2; unsigned int in_dims_index = (getInput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2;
unsigned int out_dims_index = (getOutput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2; unsigned int out_dims_index = (getOutput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2;
for (std::size_t dim = 0; dim < mAttributes->template getAttr<Attr::KernelDims>().size(); ++dim) { for (std::size_t dim = 0; dim < kernelDims().size(); ++dim) {
const DimSize_t kernelExtent = mAttributes->template getAttr<Attr::DilationDims>()[dim] * const DimSize_t kernelExtent = dilationDims()[dim] *
(mAttributes->template getAttr<Attr::KernelDims>()[dim] - 1) + (kernelDims()[dim] - 1) +
1; 1;
outputDims[dim + out_dims_index] = 1 + static_cast<DimSize_t>( outputDims[dim + out_dims_index] = 1 + static_cast<DimSize_t>(
floor(static_cast<float>(inputDims[dim + in_dims_index] - kernelExtent) / floor(static_cast<float>(inputDims[dim + in_dims_index] - kernelExtent) /
static_cast<float>(mAttributes->template getAttr<Attr::StrideDims>()[dim])) static_cast<float>(strideDims()[dim]))
); );
} }
...@@ -123,18 +123,18 @@ Aidge::Conv_Op<DIM>::computeReceptiveField( ...@@ -123,18 +123,18 @@ Aidge::Conv_Op<DIM>::computeReceptiveField(
std::vector<DimSize_t> inputDims{outputDims[0], getInput(0)->dims()[1]}; std::vector<DimSize_t> inputDims{outputDims[0], getInput(0)->dims()[1]};
for (DimIdx_t i = 0; i < DIM; ++i) { for (DimIdx_t i = 0; i < DIM; ++i) {
inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1) inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1)
* mAttributes->template getAttr<Attr::StrideDims>()[static_cast<std::size_t>(i)] * strideDims()[static_cast<std::size_t>(i)]
+ 1 + 1
+ (mAttributes->template getAttr<Attr::KernelDims>()[static_cast<std::size_t>(i)] - 1) + (kernelDims()[static_cast<std::size_t>(i)] - 1)
* mAttributes->template getAttr<Attr::DilationDims>()[static_cast<std::size_t>(i)]); * dilationDims()[static_cast<std::size_t>(i)]);
inputIdxDims[2+i] *= mAttributes->template getAttr<Attr::StrideDims>()[static_cast<std::size_t>(i)]; inputIdxDims[2+i] *= strideDims()[static_cast<std::size_t>(i)];
} }
// Weight // Weight
// same output value, every input channel is used // same output value, every input channel is used
std::vector<DimSize_t> weightDims{outputDims[1], getInput(0)->dims()[1]}; std::vector<DimSize_t> weightDims{outputDims[1], getInput(0)->dims()[1]};
for (std::size_t i = 0; i < DIM; ++i) { for (std::size_t i = 0; i < DIM; ++i) {
weightDims.push_back(mAttributes->template getAttr<Attr::KernelDims>()[i]); weightDims.push_back(kernelDims()[i]);
} }
std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0); std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0);
weightIdxDims[0] = firstEltDims[1]; weightIdxDims[0] = firstEltDims[1];
...@@ -242,4 +242,4 @@ std::shared_ptr<Aidge::Node> Aidge::Conv( ...@@ -242,4 +242,4 @@ std::shared_ptr<Aidge::Node> Aidge::Conv(
} }
template std::shared_ptr<Aidge::Node> Aidge::Conv<1>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[1], const std::string&, const std::array<Aidge::DimSize_t, 1>&, const std::array<Aidge::DimSize_t, 1>&, bool); template std::shared_ptr<Aidge::Node> Aidge::Conv<1>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[1], const std::string&, const std::array<Aidge::DimSize_t, 1>&, const std::array<Aidge::DimSize_t, 1>&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Conv<2>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[2], const std::string&, const std::array<Aidge::DimSize_t, 2>&, const std::array<Aidge::DimSize_t, 2>&, bool); template std::shared_ptr<Aidge::Node> Aidge::Conv<2>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[2], const std::string&, const std::array<Aidge::DimSize_t, 2>&, const std::array<Aidge::DimSize_t, 2>&, bool);
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment