Skip to content
Snippets Groups Projects
Commit f6b039fb authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

chore : [Conv] updated the method for attribute accession to ease code reading

parent cb45da44
No related branches found
No related tags found
No related merge requests found
...@@ -70,14 +70,14 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { ...@@ -70,14 +70,14 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) {
unsigned int in_dims_index = (getInput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2; unsigned int in_dims_index = (getInput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2;
unsigned int out_dims_index = (getOutput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2; unsigned int out_dims_index = (getOutput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2;
for (std::size_t dim = 0; dim < mAttributes->template getAttr<Attr::KernelDims>().size(); ++dim) { for (std::size_t dim = 0; dim < kernelDims().size(); ++dim) {
const DimSize_t kernelExtent = mAttributes->template getAttr<Attr::DilationDims>()[dim] * const DimSize_t kernelExtent = dilationDims()[dim] *
(mAttributes->template getAttr<Attr::KernelDims>()[dim] - 1) + (kernelDims()[dim] - 1) +
1; 1;
outputDims[dim + out_dims_index] = 1 + static_cast<DimSize_t>( outputDims[dim + out_dims_index] = 1 + static_cast<DimSize_t>(
floor(static_cast<float>(inputDims[dim + in_dims_index] - kernelExtent) / floor(static_cast<float>(inputDims[dim + in_dims_index] - kernelExtent) /
static_cast<float>(mAttributes->template getAttr<Attr::StrideDims>()[dim])) static_cast<float>(strideDims()[dim]))
); );
} }
...@@ -123,18 +123,18 @@ Aidge::Conv_Op<DIM>::computeReceptiveField( ...@@ -123,18 +123,18 @@ Aidge::Conv_Op<DIM>::computeReceptiveField(
std::vector<DimSize_t> inputDims{outputDims[0], getInput(0)->dims()[1]}; std::vector<DimSize_t> inputDims{outputDims[0], getInput(0)->dims()[1]};
for (DimIdx_t i = 0; i < DIM; ++i) { for (DimIdx_t i = 0; i < DIM; ++i) {
inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1) inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1)
* mAttributes->template getAttr<Attr::StrideDims>()[static_cast<std::size_t>(i)] * strideDims()[static_cast<std::size_t>(i)]
+ 1 + 1
+ (mAttributes->template getAttr<Attr::KernelDims>()[static_cast<std::size_t>(i)] - 1) + (kernelDims()[static_cast<std::size_t>(i)] - 1)
* mAttributes->template getAttr<Attr::DilationDims>()[static_cast<std::size_t>(i)]); * dilationDims()[static_cast<std::size_t>(i)]);
inputIdxDims[2+i] *= mAttributes->template getAttr<Attr::StrideDims>()[static_cast<std::size_t>(i)]; inputIdxDims[2+i] *= strideDims()[static_cast<std::size_t>(i)];
} }
// Weight // Weight
// same output value, every input channel is used // same output value, every input channel is used
std::vector<DimSize_t> weightDims{outputDims[1], getInput(0)->dims()[1]}; std::vector<DimSize_t> weightDims{outputDims[1], getInput(0)->dims()[1]};
for (std::size_t i = 0; i < DIM; ++i) { for (std::size_t i = 0; i < DIM; ++i) {
weightDims.push_back(mAttributes->template getAttr<Attr::KernelDims>()[i]); weightDims.push_back(kernelDims()[i]);
} }
std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0); std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0);
weightIdxDims[0] = firstEltDims[1]; weightIdxDims[0] = firstEltDims[1];
...@@ -242,4 +242,4 @@ std::shared_ptr<Aidge::Node> Aidge::Conv( ...@@ -242,4 +242,4 @@ std::shared_ptr<Aidge::Node> Aidge::Conv(
} }
template std::shared_ptr<Aidge::Node> Aidge::Conv<1>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[1], const std::string&, const std::array<Aidge::DimSize_t, 1>&, const std::array<Aidge::DimSize_t, 1>&, bool); template std::shared_ptr<Aidge::Node> Aidge::Conv<1>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[1], const std::string&, const std::array<Aidge::DimSize_t, 1>&, const std::array<Aidge::DimSize_t, 1>&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Conv<2>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[2], const std::string&, const std::array<Aidge::DimSize_t, 2>&, const std::array<Aidge::DimSize_t, 2>&, bool); template std::shared_ptr<Aidge::Node> Aidge::Conv<2>(Aidge::DimSize_t, Aidge::DimSize_t, Aidge::DimSize_t const (&)[2], const std::string&, const std::array<Aidge::DimSize_t, 2>&, const std::array<Aidge::DimSize_t, 2>&, bool);
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment