diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index e9e1e9cb990603938d62bd06e05cc358d9325b88..a9f62a5fdc6d3540fa460fb556894ecba75a9735 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -145,12 +145,10 @@ public: } // padding is not a parameter of Conv_Op. It is handled in Pad_Op Operator - // Width - std::vector<DimSize_t> inputDims; - inputDims.push_back(outputDims[0]); // same batch value - inputDims.push_back(mInputs[0]->dims()[1]); // every input channel is used - - for (DimIdx_t i = 0; i < DIM; ++i) { + // Input + // same batch value, every input channel is used + std::vector<DimSize_t> inputDims{outputDims[0], mInputs[0]->dims()[1]}; + for (DimIdx_t i = 0; i < DIM; ++i) { inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1) * this->template getAttr<ConvAttr::StrideDims>()[static_cast<std::size_t>(i)] + 1 @@ -158,8 +156,23 @@ public: * this->template getAttr<ConvAttr::DilationDims>()[static_cast<std::size_t>(i)]); inputIdxDims[2+i] *= this->template getAttr<ConvAttr::StrideDims>()[static_cast<std::size_t>(i)]; } - std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res = std::vector<std::pair<std::size_t, std::vector<DimSize_t>>>(); + + // Weight + // same output value, every input channel is used + std::vector<DimSize_t> weightDims{outputDims[0], mInputs[0]->dims()[1]}; + weightDims.insert(weightDims.end(), this->template getAttr<ConvAttr::KernelDims>()[0], this->template getAttr<ConvAttr::KernelDims>()[static_cast<std::size_t>(DIM)]); + std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0); + weightIdxDims[0] = outputIdxDims[1]; + + // Bias + const std::vector<DimSize_t> biasDims{outputDims[0]}; + const std::vector<DimSize_t> biasIdxDims{outputIdxDims[1]}; + + // Result + std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res; res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(mInputs[0]->getIdx(inputIdxDims), inputDims)); + res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(weightIdxDims, weightDims)); + res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(biasIdxDims, biasDims)); return res; } AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index 63afa325b00c8e672963e2aec494cd7dc2b4e05c..2cc7ad86e012048e7bf0824cccd4f038d1ec4971 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -38,6 +38,9 @@ std::vector<std::pair<std::size_t, std::vector<Aidge::DimSize_t>>> Aidge::Operat if (outputIdx >= nbOutputs()) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator output index out of range."); } + if (nbInputs() != nbDataInputs()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator has attributes. Must be handled in an overrided function."); + } if (!outputDimsForwarded() || getOutput(0)->nbDims() != outputDims.size()) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); }