Skip to content
Snippets Groups Projects
Commit ab8c1701 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] Remove 'NbInputs' for Add and Concat attributes

parent ba4f14bc
No related branches found
No related tags found
No related merge requests found
...@@ -32,15 +32,13 @@ private: ...@@ -32,15 +32,13 @@ private:
// FIXME: change accessibility // FIXME: change accessibility
std::vector<std::shared_ptr<Tensor>> mInputs; std::vector<std::shared_ptr<Tensor>> mInputs;
const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
const IOIndex_t mNbInputs;
public: public:
static constexpr const char* Type = "Add"; static constexpr const char* Type = "Add";
Add_Op(const IOIndex_t nbIn) Add_Op(const IOIndex_t nbIn)
: Operator(Type), : Operator(Type),
mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())), mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>()))
mNbInputs(nbIn)
{ {
assert(nbIn > 0 && "Add should have at least one input"); assert(nbIn > 0 && "Add should have at least one input");
setDatatype(DataType::Float32); setDatatype(DataType::Float32);
...@@ -52,12 +50,11 @@ public: ...@@ -52,12 +50,11 @@ public:
*/ */
Add_Op(const Add_Op& op) Add_Op(const Add_Op& op)
: Operator(Type), : Operator(Type),
mInputs(op.mInputs), mInputs(std::vector<std::shared_ptr<Tensor>>(op.nbInputs())),
mNbInputs(op.mNbInputs),
mOutput(std::make_shared<Tensor>(*op.mOutput)) mOutput(std::make_shared<Tensor>(*op.mOutput))
{ {
// cpy-ctor // cpy-ctor
assert(mNbInputs > 0 && "Add should have at least one input"); assert(op.nbInputs() > 0 && "Add should have at least one input");
setDatatype(op.mOutput->dataType()); setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Add_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; mImpl = op.mImpl ? Registrar<Add_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
} }
...@@ -80,7 +77,7 @@ public: ...@@ -80,7 +77,7 @@ public:
// } // }
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator.");
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
...@@ -90,10 +87,10 @@ public: ...@@ -90,10 +87,10 @@ public:
if (!mInputs[0]->empty()) { if (!mInputs[0]->empty()) {
const auto expectedDims = mInputs[0]->dims(); const auto expectedDims = mInputs[0]->dims();
std::size_t nonEmptyInputTensor = 1; std::size_t nonEmptyInputTensor = 1;
for (; nonEmptyInputTensor < mNbInputs && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) { for (; nonEmptyInputTensor < nbInputs() && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) {
assert(expectedDims == mInputs[nonEmptyInputTensor]->dims()); assert(expectedDims == mInputs[nonEmptyInputTensor]->dims());
} }
if (nonEmptyInputTensor == mNbInputs) { if (nonEmptyInputTensor == nbInputs()) {
mOutput->resize(expectedDims); mOutput->resize(expectedDims);
} }
} }
...@@ -101,8 +98,8 @@ public: ...@@ -101,8 +98,8 @@ public:
bool outputDimsForwarded() const override final { bool outputDimsForwarded() const override final {
std::size_t forwarded = 0; std::size_t forwarded = 0;
for (; forwarded < mNbInputs && (!mInputs[forwarded]->empty()); ++forwarded) {} for (; forwarded < nbInputs() && (!mInputs[forwarded]->empty()); ++forwarded) {}
return ((forwarded==mNbInputs) && !(mOutput->empty())); return ((forwarded==nbInputs()) && !(mOutput->empty()));
} }
// void checkDims() const override final { // void checkDims() const override final {
...@@ -112,13 +109,13 @@ public: ...@@ -112,13 +109,13 @@ public:
// } // }
// } // }
/// @brief Mutable reference to the input Tensor at @p inputIdx (index must be in range).
inline Tensor& input(const IOIndex_t inputIdx) const override final {
    assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator.");
    return *mInputs[inputIdx];
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator.");
return mInputs[inputIdx]; return mInputs[inputIdx];
} }
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
...@@ -128,7 +125,7 @@ public: ...@@ -128,7 +125,7 @@ public:
} }
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator.");
return std::static_pointer_cast<Data>(mInputs[inputIdx]); return std::static_pointer_cast<Data>(mInputs[inputIdx]);
} }
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
...@@ -143,7 +140,7 @@ public: ...@@ -143,7 +140,7 @@ public:
mOutput->setBackend(name); mOutput->setBackend(name);
// FIXME: temporary workaround // FIXME: temporary workaround
for (std::size_t i = 0; i < mNbInputs; ++i) { for (std::size_t i = 0; i < nbInputs(); ++i) {
mInputs[i]->setBackend(name); mInputs[i]->setBackend(name);
} }
} }
...@@ -152,13 +149,13 @@ public: ...@@ -152,13 +149,13 @@ public:
mOutput->setDatatype(datatype); mOutput->setDatatype(datatype);
// FIXME: temporary workaround // FIXME: temporary workaround
for (std::size_t i = 0; i < mNbInputs; ++i) { for (std::size_t i = 0; i < nbInputs(); ++i) {
mInputs[i]->setDatatype(datatype); mInputs[i]->setDatatype(datatype);
} }
} }
/// @brief Number of inputs, derived from the size of the input vector.
/// Explicit cast: size() is std::size_t, implicitly narrowing to IOIndex_t.
inline IOIndex_t nbInputs() const noexcept override final { return static_cast<IOIndex_t>(mInputs.size()); }
/// @brief Every input of Add is a data input (same count as nbInputs()).
inline IOIndex_t nbDataInputs() const noexcept override final { return static_cast<IOIndex_t>(mInputs.size()); }
/// @brief Add always produces a single output Tensor.
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
......
...@@ -26,11 +26,11 @@ ...@@ -26,11 +26,11 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
namespace Aidge { namespace Aidge {
enum class ConcatAttr { NbInputs, Axis }; enum class ConcatAttr { Axis };
class Concat_Op : public Operator, class Concat_Op : public Operator,
public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>, public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>,
public StaticAttributes<ConcatAttr, IOIndex_t, DimSize_t> { public StaticAttributes<ConcatAttr, DimSize_t> {
private: private:
// FIXME: change accessibility // FIXME: change accessibility
std::vector<std::shared_ptr<Tensor>> mInputs; std::vector<std::shared_ptr<Tensor>> mInputs;
...@@ -39,15 +39,14 @@ private: ...@@ -39,15 +39,14 @@ private:
public: public:
static constexpr const char* Type = "Concat"; static constexpr const char* Type = "Concat";
using Attributes_ = StaticAttributes<ConcatAttr, IOIndex_t, DimSize_t>; using Attributes_ = StaticAttributes<ConcatAttr, DimSize_t>;
template <ConcatAttr e> template <ConcatAttr e>
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
Concat_Op(const IOIndex_t nbIn, const DimSize_t axis) Concat_Op(const IOIndex_t nbIn, const DimSize_t axis)
: Operator(Type), : Operator(Type),
mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())), mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())),
Attributes_(attr<ConcatAttr::NbInputs>(nbIn), Attributes_(attr<ConcatAttr::Axis>(axis))
attr<ConcatAttr::Axis>(axis))
{ {
assert(nbIn > 0 && "Concat should have at least one input"); assert(nbIn > 0 && "Concat should have at least one input");
setDatatype(DataType::Float32); setDatatype(DataType::Float32);
...@@ -60,11 +59,11 @@ public: ...@@ -60,11 +59,11 @@ public:
Concat_Op(const Concat_Op& op) Concat_Op(const Concat_Op& op)
: Operator(Type), : Operator(Type),
Attributes_(op), Attributes_(op),
mInputs(std::vector<std::shared_ptr<Tensor>>(op.getAttr<ConcatAttr::NbInputs>())), mInputs(std::vector<std::shared_ptr<Tensor>>(op.nbInputs(), std::make_shared<Tensor>())),
mOutput(std::make_shared<Tensor>(*op.mOutput)) mOutput(std::make_shared<Tensor>(*op.mOutput))
{ {
// cpy-ctor // cpy-ctor
assert(op.getAttr<ConcatAttr::NbInputs>() > 0 && "Concat should have at least one input"); assert(op.nbInputs() > 0 && "Concat should have at least one input");
setDatatype(op.mOutput->dataType()); setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
} }
...@@ -87,7 +86,7 @@ public: ...@@ -87,7 +86,7 @@ public:
// } // }
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator.");
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
...@@ -103,7 +102,7 @@ public: ...@@ -103,7 +102,7 @@ public:
if (computable) { if (computable) {
auto outputDims = mInputs[0]->dims(); auto outputDims = mInputs[0]->dims();
for (std::size_t i = 1; i < getAttr<ConcatAttr::NbInputs>(); ++i) { for (std::size_t i = 1; i < nbInputs(); ++i) {
outputDims[getAttr<ConcatAttr::Axis>()] += mInputs[i]->dims()[getAttr<ConcatAttr::Axis>()]; outputDims[getAttr<ConcatAttr::Axis>()] += mInputs[i]->dims()[getAttr<ConcatAttr::Axis>()];
} }
mOutput->resize(outputDims); mOutput->resize(outputDims);
...@@ -121,13 +120,13 @@ public: ...@@ -121,13 +120,13 @@ public:
// } // }
// } // }
/// @brief Mutable reference to the input Tensor at @p inputIdx (index must be in range).
inline Tensor& input(const IOIndex_t inputIdx) const override final {
    assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator.");
    return *mInputs[inputIdx];
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator.");
return mInputs[inputIdx]; return mInputs[inputIdx];
} }
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
...@@ -137,7 +136,7 @@ public: ...@@ -137,7 +136,7 @@ public:
} }
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator.");
return std::static_pointer_cast<Data>(mInputs[inputIdx]); return std::static_pointer_cast<Data>(mInputs[inputIdx]);
} }
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
...@@ -152,7 +151,7 @@ public: ...@@ -152,7 +151,7 @@ public:
mOutput->setBackend(name); mOutput->setBackend(name);
// FIXME: temporary workaround // FIXME: temporary workaround
for (std::size_t i = 0; i < getAttr<ConcatAttr::NbInputs>(); ++i) { for (std::size_t i = 0; i < nbInputs(); ++i) {
mInputs[i]->setBackend(name); mInputs[i]->setBackend(name);
} }
} }
...@@ -161,13 +160,13 @@ public: ...@@ -161,13 +160,13 @@ public:
mOutput->setDatatype(datatype); mOutput->setDatatype(datatype);
// FIXME: temporary workaround // FIXME: temporary workaround
for (std::size_t i = 0; i < getAttr<ConcatAttr::NbInputs>(); ++i) { for (std::size_t i = 0; i < nbInputs(); ++i) {
mInputs[i]->setDatatype(datatype); mInputs[i]->setDatatype(datatype);
} }
} }
/// @brief Number of inputs, derived from the size of the input vector.
/// Explicit cast: size() is std::size_t, implicitly narrowing to IOIndex_t.
inline IOIndex_t nbInputs() const noexcept override final { return static_cast<IOIndex_t>(mInputs.size()); }
/// @brief Every input of Concat is a data input (same count as nbInputs()).
inline IOIndex_t nbDataInputs() const noexcept override final { return static_cast<IOIndex_t>(mInputs.size()); }
/// @brief Concat always produces a single output Tensor.
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
...@@ -186,7 +185,6 @@ inline std::shared_ptr<Node> Concat(const IOIndex_t nbIn, const DimIdx_t axis = ...@@ -186,7 +185,6 @@ inline std::shared_ptr<Node> Concat(const IOIndex_t nbIn, const DimIdx_t axis =
namespace { namespace {
template <> template <>
const char* const EnumStrings<Aidge::ConcatAttr>::data[] = { const char* const EnumStrings<Aidge::ConcatAttr>::data[] = {
"NbInputs",
"Axis" "Axis"
}; };
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment