Skip to content
Snippets Groups Projects
Commit 3fb21e1a authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] Attributes receptive fields

parent ed10a0f2
No related branches found
No related tags found
2 merge requests!46Remove Operator reference to Tensor,!20Draft: Introduction of Tiling
Pipeline #34233 failed
......@@ -145,12 +145,10 @@ public:
}
// padding is not a parameter of Conv_Op. It is handled in Pad_Op Operator
// Width
std::vector<DimSize_t> inputDims;
inputDims.push_back(outputDims[0]); // same batch value
inputDims.push_back(mInputs[0]->dims()[1]); // every input channel is used
for (DimIdx_t i = 0; i < DIM; ++i) {
// Input
// same batch value, every input channel is used
std::vector<DimSize_t> inputDims{outputDims[0], mInputs[0]->dims()[1]};
for (DimIdx_t i = 0; i < DIM; ++i) {
inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1)
* this->template getAttr<ConvAttr::StrideDims>()[static_cast<std::size_t>(i)]
+ 1
......@@ -158,8 +156,23 @@ public:
* this->template getAttr<ConvAttr::DilationDims>()[static_cast<std::size_t>(i)]);
inputIdxDims[2+i] *= this->template getAttr<ConvAttr::StrideDims>()[static_cast<std::size_t>(i)];
}
std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res = std::vector<std::pair<std::size_t, std::vector<DimSize_t>>>();
// Weight
// same output value, every input channel is used
std::vector<DimSize_t> weightDims{outputDims[0], mInputs[0]->dims()[1]};
weightDims.insert(weightDims.end(), this->template getAttr<ConvAttr::KernelDims>()[0], this->template getAttr<ConvAttr::KernelDims>()[static_cast<std::size_t>(DIM)]);
std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0);
weightIdxDims[0] = outputIdxDims[1];
// Bias
const std::vector<DimSize_t> biasDims{outputDims[0]};
const std::vector<DimSize_t> biasIdxDims{outputIdxDims[1]};
// Result
std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res;
res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(mInputs[0]->getIdx(inputIdxDims), inputDims));
res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(weightIdxDims, weightDims));
res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(biasIdxDims, biasDims));
return res;
}
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet.");
......
......@@ -38,6 +38,9 @@ std::vector<std::pair<std::size_t, std::vector<Aidge::DimSize_t>>> Aidge::Operat
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator output index out of range.");
}
if (nbInputs() != nbDataInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator has attributes. Must be handled in an overrided function.");
}
if (!outputDimsForwarded() || getOutput(0)->nbDims() != outputDims.size()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet.");
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment