Commit 1520a9a0 authored by Maxence Naud

Propagate the new class hierarchy consequences to Conv and AvgPooling

parent 105f3960
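The commit replaces the Operator base class with the new OperatorTensor class for Conv_Op and AvgPooling_Op: the operators drop their private mInput/mInputs/mOutput tensor members and the associateInput/getInput/getOutput/getRawInput/getRawOutput/nbInputs/nbDataInputs/nbOutputs/setDatatype boilerplate, and instead declare their arity through the base-class constructor (for example OperatorTensor(Type, 1, 2, 1) for Conv: one data input, two parameter inputs, one output). Below is a minimal, self-contained sketch of that refactoring pattern; apart from the names visible in the diff (mInputs, mOutputs, associateInput), the types are simplified stand-ins, not the real Aidge API.

    // Illustrative sketch only: simplified stand-ins for the Aidge classes, showing
    // how shared tensor storage moves into a common base so that derived operators
    // keep only their own attributes and shape logic.
    #include <cassert>
    #include <cstddef>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    struct TensorStub {                               // stand-in for aidge/data/Tensor.hpp
        std::vector<std::size_t> mDims;
        bool empty() const { return mDims.empty(); }
        void resize(const std::vector<std::size_t>& dims) { mDims = dims; }
    };

    class OperatorTensorLike {                        // stand-in for the new base class
    public:
        // (type, #data inputs, #parameter inputs, #outputs), mirroring the
        // OperatorTensor(Type, nbData, nbParam, nbOut) calls in the diff below.
        OperatorTensorLike(std::string type, std::size_t nbData, std::size_t nbParam, std::size_t nbOut)
            : mType(std::move(type)), mInputs(nbData + nbParam), mOutputs(nbOut) {
            for (auto& in : mInputs) { in = std::make_shared<TensorStub>(); }
            for (auto& out : mOutputs) { out = std::make_shared<TensorStub>(); }
        }
        virtual ~OperatorTensorLike() = default;

        // Generic accessors the derived operators no longer re-implement.
        void associateInput(std::size_t idx, std::shared_ptr<TensorStub> data) {
            assert(idx < mInputs.size() && "input index out of range");
            mInputs[idx] = std::move(data);
        }
        std::size_t nbInputs() const { return mInputs.size(); }
        std::size_t nbOutputs() const { return mOutputs.size(); }

    protected:
        std::string mType;
        std::vector<std::shared_ptr<TensorStub>> mInputs;
        std::vector<std::shared_ptr<TensorStub>> mOutputs;
    };

    // A derived operator now only declares its arity and its own behaviour.
    class AvgPoolingLike : public OperatorTensorLike {
    public:
        AvgPoolingLike() : OperatorTensorLike("AvgPooling", 1, 0, 1) {}  // 1 data input, 0 params, 1 output
    };

    int main() {
        AvgPoolingLike op;
        auto input = std::make_shared<TensorStub>();
        input->resize({1, 3, 32, 32});
        op.associateInput(0, input);
        assert(op.nbInputs() == 1 && op.nbOutputs() == 1);
        return 0;
    }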
@@ -19,7 +19,7 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
-#include "aidge/operator/Operator.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Registrar.hpp"
@@ -29,15 +29,11 @@ namespace Aidge {
 enum class AvgPoolingAttr { StrideDims, KernelDims };
 template <DimIdx_t DIM>
-class AvgPooling_Op : public Operator,
+class AvgPooling_Op : public OperatorTensor,
                 public Registrable<AvgPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>,
                 public StaticAttributes<AvgPoolingAttr,
                                         std::array<DimSize_t, DIM>,
                                         std::array<DimSize_t, DIM>> {
-private:
-    // FIXME: change accessibility
-    std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
-    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
 public:
     static constexpr const char *Type = "AvgPooling";
@@ -52,10 +48,10 @@ public:
     constexpr AvgPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims,
                             const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1))
-        : Operator(Type),
+        : OperatorTensor(Type, 1, 0, 1),
           Attributes_(attr<AvgPoolingAttr::StrideDims>(stride_dims),
                       attr<AvgPoolingAttr::KernelDims>(kernel_dims)) {
-        setDatatype(DataType::Float32);
+        setDataType(DataType::Float32);
     }
     /**
@@ -63,12 +59,12 @@ public:
      * @param op Operator to copy.
      */
     AvgPooling_Op(const AvgPooling_Op<DIM>& op)
-        : Operator(Type),
+        : OperatorTensor(Type, 1, 0, 1),
           Attributes_(op),
           mOutput(std::make_shared<Tensor>(*op.mOutput))
     {
         // cpy-ctor
-        setDatatype(op.mOutput->dataType());
+        setDataType(op.mOutput->dataType());
         mImpl = op.mImpl ? Registrar<AvgPooling_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr;
     }
@@ -80,64 +76,23 @@ public:
         return std::make_shared<AvgPooling_Op<DIM>>(*this);
     }
-    void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
-        assert(inputIdx < 1 && "operators supports only 3 inputs");
-        (void) inputIdx; // avoid unused warning
-        assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
-        mInput = std::dynamic_pointer_cast<Tensor>(data);
-    }
     void computeOutputDims() override final {
-        if (!mInput->empty()) {
+        if (!mInputs[0]->empty()) {
             std::array<DimSize_t, DIM + 2> outputDims = {};
             for (std::size_t dim = 0; dim < this->template getAttr<AvgPoolingAttr::KernelDims>().size() ; ++dim) {
                 outputDims[dim+2] = 1 + static_cast<DimSize_t>(
-                            std::floor(static_cast<float>(mInput->dims()[dim+2] -
+                            std::floor(static_cast<float>(mInputs[0]->dims()[dim+2] -
                                             this->template getAttr<AvgPoolingAttr::KernelDims>()[dim]) /
                                        static_cast<float>(this->template getAttr<AvgPoolingAttr::StrideDims>()[dim])));
             }
-            outputDims[1] = mInput->dims()[1];
-            outputDims[0] = mInput->dims()[0];
-            mOutput->resize(outputDims);
+            outputDims[1] = mInputs[0]->dims()[1];
+            outputDims[0] = mInputs[0]->dims()[0];
+            mOutputs[0]->resize(outputDims);
         }
     }
-    bool outputDimsForwarded() const override final { return !(mOutput->empty()); }
-    inline Tensor& input(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx == 0 && "operators supports only 1 inputs");
-        (void) inputIdx; // avoid unused warning
-        return *(mInput.get());
-    }
-    inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
-    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx == 0 && "AvgPooling Operators supports only 1 inputs");
-        (void) inputIdx; // avoid unused warning
-        return mInput;
-    }
-    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
-        assert(outputIdx == 0 && "AvgPooling Operators has only 1 outputs");
-        (void) outputIdx; // avoid unused warning
-        return mOutput;
-    }
-    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx == 0 && "operators supports only 1 inputs");
-        (void) inputIdx; // avoid unused warning
-        return std::static_pointer_cast<Data>(mInput);
-    }
-    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
-        assert(outputIdx == 0 && "operator supports only 1 output");
-        (void) outputIdx; // avoid unused warning
-        return std::static_pointer_cast<Data>(mOutput);
-    }
     void setBackend(const std::string &name) override {
         mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this);
@@ -147,16 +102,6 @@ public:
         mInput->setBackend(name);
     }
-    void setDatatype(const DataType &datatype) override {
-        mOutput->setDatatype(datatype);
-        // FIXME: temporary workaround
-        mInput->setDatatype(datatype);
-    }
-    inline IOIndex_t nbInputs() const noexcept override final { return 1; }
-    inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
-    inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
     static const std::vector<std::string> getInputsName(){
         return {"data_input"};
     }
@@ -190,4 +135,4 @@ const char *const EnumStrings<Aidge::AvgPoolingAttr>::data[] = {"StrideDims",
                                                                 "KernelDims"};
 }
-#endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */
\ No newline at end of file
+#endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */
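For reference, the spatial size computed by AvgPooling_Op::computeOutputDims() above is out = 1 + floor((in - kernel) / stride) per spatial dimension (no padding), while batch and channel sizes are copied from the input. A quick standalone check of that arithmetic, independent of Aidge:

    // Checks the pooling output-size formula used by computeOutputDims() above:
    // out = 1 + floor((in - kernel) / stride), applied to each spatial dimension.
    #include <cmath>
    #include <cstdio>

    int main() {
        const float in = 32.0f, kernel = 2.0f, stride = 2.0f;
        const int out = 1 + static_cast<int>(std::floor((in - kernel) / stride));
        std::printf("32x32 input, 2x2 kernel, stride 2 -> %dx%d output\n", out, out); // 16x16
        return 0;
    }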
@@ -19,7 +19,7 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
-#include "aidge/operator/Operator.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Registrar.hpp"
@@ -29,17 +29,12 @@ namespace Aidge {
 enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelDims };
 template <DimIdx_t DIM>
-class Conv_Op : public Operator,
+class Conv_Op : public OperatorTensor,
                 public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
                 public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t,
                                         DimSize_t, std::array<DimSize_t, DIM>> {
-public:
-    // FIXME: change accessibility
-    std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(),
-                                                      std::make_shared<Tensor>()};
-    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
 public:
     static constexpr const char *Type = "Conv";
     Conv_Op() = delete;
@@ -54,13 +49,13 @@ public:
                       const std::array<DimSize_t, DIM> &kernel_dims,
                       const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
                       const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
-        : Operator(Type),
+        : OperatorTensor(Type, 1, 2, 1),
          Attributes_(attr<ConvAttr::StrideDims>(stride_dims),
                      attr<ConvAttr::DilationDims>(dilation_dims),
                      attr<ConvAttr::InChannels>(in_channels),
                      attr<ConvAttr::OutChannels>(out_channels),
                      attr<ConvAttr::KernelDims>(kernel_dims)) {
-        setDatatype(DataType::Float32);
+        setDataType(DataType::Float32);
     }
     /**
@@ -68,12 +63,12 @@ public:
      * @param op Operator to copy.
      */
     Conv_Op(const Conv_Op<DIM>& op)
-        : Operator(Type),
+        : OperatorTensor(Type, 1, 2, 1),
          Attributes_(op),
          mOutput(std::make_shared<Tensor>(*op.mOutput))
     {
         // cpy-ctor
-        setDatatype(op.mOutput->dataType());
+        setDataType(op.mOutput->dataType());
         mImpl = op.mImpl ? Registrar<Conv_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr;
     }
@@ -98,13 +93,6 @@ public:
     // }
-    void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
-        assert(inputIdx < 3 && "operators supports only 3 inputs");
-        assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
-        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
-    }
     void computeOutputDims() override final {
         if (!mInputs[0]->empty()) {
             std::array<DimSize_t, DIM + 2> outputDims = {};
@@ -125,37 +113,6 @@ public:
         }
     }
-    bool outputDimsForwarded() const override final { return !(mOutput->empty()); }
-    inline Tensor& input(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx < 3 && "operators supports only 3 inputs");
-        return *(mInputs[inputIdx].get()); }
-    inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
-    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx < 3 && "Conv Operators supports only 3 inputs");
-        return mInputs[inputIdx];
-    }
-    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
-        assert((outputIdx == 0) && "Conv Operator has only 1 output");
-        (void) outputIdx; // avoid unused warning
-        return mOutput;
-    }
-    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx < 3 && "operators supports only 3 inputs");
-        return std::static_pointer_cast<Data>(mInputs[inputIdx]);
-    }
-    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
-        assert(outputIdx == 0 && "operator supports only 1 output");
-        (void) outputIdx; // avoid unused warning
-        return std::static_pointer_cast<Data>(mOutput);
-    }
     void setBackend(const std::string &name) override {
         mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
         mOutput->setBackend(name);
@@ -165,18 +122,6 @@ public:
         mInputs[2]->setBackend(name);
     }
-    void setDatatype(const DataType &datatype) override {
-        mOutput->setDatatype(datatype);
-        // FIXME: temporary workaround
-        mInputs[0]->setDatatype(datatype);
-        mInputs[1]->setDatatype(datatype);
-        mInputs[2]->setDatatype(datatype);
-    }
-    inline IOIndex_t nbInputs() const noexcept override final { return 3; }
-    inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
-    inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
    }
......
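The body of Conv_Op::computeOutputDims() is collapsed in this view. Assuming it follows the usual no-padding dilated-convolution arithmetic, the spatial size would be out = 1 + floor((in - (dilation*(kernel-1) + 1)) / stride); the snippet below only checks that assumed formula and is not taken from the Aidge sources:

    // Assumed no-padding output size for a dilated convolution (the Conv_Op body
    // is not visible in the collapsed hunk above, so treat this as a sketch):
    // out = 1 + floor((in - (dilation*(kernel-1) + 1)) / stride)
    #include <cmath>
    #include <cstdio>

    int main() {
        const float in = 32.0f, kernel = 3.0f, stride = 1.0f, dilation = 2.0f;
        const float extent = dilation * (kernel - 1.0f) + 1.0f; // effective kernel span: 5
        const int out = 1 + static_cast<int>(std::floor((in - extent) / stride));
        std::printf("input 32, kernel 3, dilation 2, stride 1 -> output %d\n", out); // 28
        return 0;
    }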
@@ -25,25 +25,25 @@ Aidge::OperatorImpl::OperatorImpl(const Operator& op):
 }
 Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
-    assert(mOp.getInput(inputIdx) && "requires valid input");
+    assert(mOp.getRawInput(inputIdx) && "requires valid input");
     // Requires the whole tensor by default
-    return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size();
+    return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
 }
 Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
-    assert(mOp.getInput(inputIdx) && "requires valid input");
+    assert(mOp.getRawInput(inputIdx) && "requires valid input");
     // Protect the whole tensor by default
-    return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size();
+    return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
 }
 Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                        const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
-    assert(mOp.getOutput(outputIdx) && "requires valid output");
+    assert(mOp.getRawOutput(outputIdx) && "requires valid output");
     // Requires the whole tensor by default, regardless of available data on inputs
-    return std::static_pointer_cast<Tensor>(mOp.getOutput(outputIdx))->size();
+    return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size();
 }
 Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
......
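The OperatorImpl defaults above now go through getRawInput()/getRawOutput(): with tensor storage moved into OperatorTensor, the generic Operator interface only hands back std::shared_ptr<Data>, so implementation code recovers the concrete Tensor with a pointer cast before querying its size. A self-contained illustration of that pattern (stub types, not the real Aidge classes):

    // Illustrates the getRawInput() + static_pointer_cast pattern used above,
    // with stand-in Data/Tensor types rather than the real Aidge classes.
    #include <cassert>
    #include <cstddef>
    #include <memory>

    struct Data { virtual ~Data() = default; };
    struct Tensor : Data {
        std::size_t mSize = 0;
        std::size_t size() const { return mSize; }
    };

    // Mirrors the default getNbRequiredData() policy: require the whole tensor.
    std::size_t nbRequiredData(const std::shared_ptr<Data>& rawInput) {
        assert(rawInput && "requires valid input");
        return std::static_pointer_cast<Tensor>(rawInput)->size();
    }

    int main() {
        auto t = std::make_shared<Tensor>();
        t->mSize = 1024;
        assert(nbRequiredData(t) == 1024);
        return 0;
    }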