Commit ab8c1701 authored by Maxence Naud

[Upd] Remove 'NbInputs' for Add and Concat attributes

parent ba4f14bc
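In both operators the input count was stored twice: once implicitly as the size of the mInputs vector, and once explicitly (a const mNbInputs member in Add_Op, a ConcatAttr::NbInputs static attribute in Concat_Op). The commit drops the explicit copy and derives nbInputs() from mInputs.size(), so the arity has a single source of truth that cannot drift out of sync with the actual inputs. A minimal standalone sketch of the resulting pattern (AddLike, Tensor and IOIndex_t are illustrative stand-ins, not the actual Aidge classes):

// Sketch only: arity derived from the input vector, never cached separately.
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

struct Tensor {};                  // stand-in for Aidge::Tensor
using IOIndex_t = std::size_t;     // stand-in for Aidge::IOIndex_t

class AddLike {
    std::vector<std::shared_ptr<Tensor>> mInputs;
public:
    explicit AddLike(const IOIndex_t nbIn)
        : mInputs(nbIn, std::make_shared<Tensor>()) {
        assert(nbIn > 0 && "should have at least one input");
    }
    // Single source of truth: the vector's size.
    IOIndex_t nbInputs() const noexcept { return mInputs.size(); }
};

int main() {
    AddLike add(3);
    assert(add.nbInputs() == 3);   // no separate mNbInputs to keep in sync
    return 0;
}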
@@ -32,15 +32,13 @@ private:
     // FIXME: change accessibility
     std::vector<std::shared_ptr<Tensor>> mInputs;
     const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
-    const IOIndex_t mNbInputs;

 public:
     static constexpr const char* Type = "Add";

     Add_Op(const IOIndex_t nbIn)
         : Operator(Type),
-          mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())),
-          mNbInputs(nbIn)
+          mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>()))
     {
         assert(nbIn > 0 && "Add should have at least one input");
         setDatatype(DataType::Float32);
@@ -52,12 +50,11 @@ public:
      */
     Add_Op(const Add_Op& op)
         : Operator(Type),
-          mInputs(op.mInputs),
-          mNbInputs(op.mNbInputs),
+          mInputs(std::vector<std::shared_ptr<Tensor>>(op.nbInputs())),
           mOutput(std::make_shared<Tensor>(*op.mOutput))
     {
         // cpy-ctor
-        assert(mNbInputs > 0 && "Add should have at least one input");
+        assert(op.nbInputs() > 0 && "Add should have at least one input");
         setDatatype(op.mOutput->dataType());
         mImpl = op.mImpl ? Registrar<Add_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
     }
@@ -80,7 +77,7 @@ public:
     // }

     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
-        assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator.");
+        assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator.");
         assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
         mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
@@ -90,10 +87,10 @@ public:
         if (!mInputs[0]->empty()) {
             const auto expectedDims = mInputs[0]->dims();
             std::size_t nonEmptyInputTensor = 1;
-            for (; nonEmptyInputTensor < mNbInputs && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) {
+            for (; nonEmptyInputTensor < nbInputs() && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) {
                 assert(expectedDims == mInputs[nonEmptyInputTensor]->dims());
             }
-            if (nonEmptyInputTensor == mNbInputs) {
+            if (nonEmptyInputTensor == nbInputs()) {
                 mOutput->resize(expectedDims);
             }
         }
@@ -101,8 +98,8 @@ public:
     bool outputDimsForwarded() const override final {
         std::size_t forwarded = 0;
-        for (; forwarded < mNbInputs && (!mInputs[forwarded]->empty()); ++forwarded) {}
-        return ((forwarded==mNbInputs) && !(mOutput->empty()));
+        for (; forwarded < nbInputs() && (!mInputs[forwarded]->empty()); ++forwarded) {}
+        return ((forwarded==nbInputs()) && !(mOutput->empty()));
     }

     // void checkDims() const override final {
@@ -112,13 +109,13 @@ public:
     //     }
     // }

     inline Tensor& input(const IOIndex_t inputIdx) const override final {
-        assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator.");
+        assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator.");
         return *(mInputs[inputIdx].get());
     }
     inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }

     inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
-        assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator.");
+        assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator.");
         return mInputs[inputIdx];
     }
     inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
@@ -128,7 +125,7 @@ public:
     }

     std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
-        assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator.");
+        assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator.");
         return std::static_pointer_cast<Data>(mInputs[inputIdx]);
     }
     std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
@@ -143,7 +140,7 @@ public:
         mOutput->setBackend(name);

         // FIXME: temporary workaround
-        for (std::size_t i = 0; i < mNbInputs; ++i) {
+        for (std::size_t i = 0; i < nbInputs(); ++i) {
             mInputs[i]->setBackend(name);
         }
     }
@@ -152,13 +149,13 @@ public:
         mOutput->setDatatype(datatype);

         // FIXME: temporary workaround
-        for (std::size_t i = 0; i < mNbInputs; ++i) {
+        for (std::size_t i = 0; i < nbInputs(); ++i) {
             mInputs[i]->setDatatype(datatype);
         }
     }

-    inline IOIndex_t nbInputs() const noexcept override final { return mNbInputs; }
-    inline IOIndex_t nbDataInputs() const noexcept override final { return mNbInputs; }
+    inline IOIndex_t nbInputs() const noexcept override final { return mInputs.size(); }
+    inline IOIndex_t nbDataInputs() const noexcept override final { return mInputs.size(); }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }

     static const std::vector<std::string> getInputsName(){
......
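The Concat hunks below mirror the Add changes, with one extra step: because NbInputs lived in the StaticAttributes pack, removing it also shrinks the attribute template parameters from <ConcatAttr, IOIndex_t, DimSize_t> to <ConcatAttr, DimSize_t> and drops the "NbInputs" entry from EnumStrings. The axis-sum rule in computeOutputDims is untouched by the refactor; as a worked example of that rule (a standalone sketch, not Aidge code), concatenating dims {2,3}, {2,5} and {2,1} along axis 1 yields {2,9}:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    // Same accumulation as the loop in Concat_Op::computeOutputDims:
    // the axis dimension of every further input is added onto the first.
    std::vector<std::vector<std::size_t>> inputDims{{2, 3}, {2, 5}, {2, 1}};
    const std::size_t axis = 1;
    std::vector<std::size_t> outputDims = inputDims[0];
    for (std::size_t i = 1; i < inputDims.size(); ++i) {
        outputDims[axis] += inputDims[i][axis];   // 3 + 5 + 1 = 9
    }
    std::cout << outputDims[0] << " x " << outputDims[1] << std::endl;  // 2 x 9
    return 0;
}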
@@ -26,11 +26,11 @@
 #include "aidge/utils/Types.h"

 namespace Aidge {
-enum class ConcatAttr { NbInputs, Axis };
+enum class ConcatAttr { Axis };

 class Concat_Op : public Operator,
     public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>,
-    public StaticAttributes<ConcatAttr, IOIndex_t, DimSize_t> {
+    public StaticAttributes<ConcatAttr, DimSize_t> {
 private:
     // FIXME: change accessibility
     std::vector<std::shared_ptr<Tensor>> mInputs;
@@ -39,15 +39,14 @@ private:
 public:
     static constexpr const char* Type = "Concat";

-    using Attributes_ = StaticAttributes<ConcatAttr, IOIndex_t, DimSize_t>;
+    using Attributes_ = StaticAttributes<ConcatAttr, DimSize_t>;
     template <ConcatAttr e>
     using attr = typename Attributes_::template attr<e>;

     Concat_Op(const IOIndex_t nbIn, const DimSize_t axis)
         : Operator(Type),
           mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())),
-          Attributes_(attr<ConcatAttr::NbInputs>(nbIn),
-                      attr<ConcatAttr::Axis>(axis))
+          Attributes_(attr<ConcatAttr::Axis>(axis))
     {
         assert(nbIn > 0 && "Concat should have at least one input");
         setDatatype(DataType::Float32);
@@ -60,11 +59,11 @@ public:
     Concat_Op(const Concat_Op& op)
         : Operator(Type),
           Attributes_(op),
-          mInputs(std::vector<std::shared_ptr<Tensor>>(op.getAttr<ConcatAttr::NbInputs>())),
+          mInputs(std::vector<std::shared_ptr<Tensor>>(op.nbInputs(), std::make_shared<Tensor>())),
           mOutput(std::make_shared<Tensor>(*op.mOutput))
     {
         // cpy-ctor
-        assert(op.getAttr<ConcatAttr::NbInputs>() > 0 && "Concat should have at least one input");
+        assert(op.nbInputs() > 0 && "Concat should have at least one input");
         setDatatype(op.mOutput->dataType());
         mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
     }
@@ -87,7 +86,7 @@ public:
     // }

     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
-        assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator.");
+        assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator.");
         assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
         mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
@@ -103,7 +102,7 @@ public:
         if (computable) {
             auto outputDims = mInputs[0]->dims();

-            for (std::size_t i = 1; i < getAttr<ConcatAttr::NbInputs>(); ++i) {
+            for (std::size_t i = 1; i < nbInputs(); ++i) {
                 outputDims[getAttr<ConcatAttr::Axis>()] += mInputs[i]->dims()[getAttr<ConcatAttr::Axis>()];
             }
             mOutput->resize(outputDims);
@@ -121,13 +120,13 @@ public:
     //     }
     // }

     inline Tensor& input(const IOIndex_t inputIdx) const override final {
-        assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator.");
+        assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator.");
         return *(mInputs[inputIdx].get());
     }
     inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }

     inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
-        assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator.");
+        assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator.");
         return mInputs[inputIdx];
     }
     inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
@@ -137,7 +136,7 @@ public:
     }

     std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
-        assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator.");
+        assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator.");
         return std::static_pointer_cast<Data>(mInputs[inputIdx]);
     }
     std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
@@ -152,7 +151,7 @@ public:
         mOutput->setBackend(name);

         // FIXME: temporary workaround
-        for (std::size_t i = 0; i < getAttr<ConcatAttr::NbInputs>(); ++i) {
+        for (std::size_t i = 0; i < nbInputs(); ++i) {
             mInputs[i]->setBackend(name);
         }
     }
@@ -161,13 +160,13 @@ public:
         mOutput->setDatatype(datatype);

         // FIXME: temporary workaround
-        for (std::size_t i = 0; i < getAttr<ConcatAttr::NbInputs>(); ++i) {
+        for (std::size_t i = 0; i < nbInputs(); ++i) {
             mInputs[i]->setDatatype(datatype);
         }
     }

-    inline IOIndex_t nbInputs() const noexcept override final { return getAttr<ConcatAttr::NbInputs>(); }
-    inline IOIndex_t nbDataInputs() const noexcept override final { return getAttr<ConcatAttr::NbInputs>(); }
+    inline IOIndex_t nbInputs() const noexcept override final { return mInputs.size(); }
+    inline IOIndex_t nbDataInputs() const noexcept override final { return mInputs.size(); }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }

     static const std::vector<std::string> getInputsName(){
@@ -186,7 +185,6 @@ inline std::shared_ptr<Node> Concat(const IOIndex_t nbIn, const DimIdx_t axis =
 namespace {
 template <>
 const char* const EnumStrings<Aidge::ConcatAttr>::data[] = {
-    "NbInputs",
     "Axis"
 };
 }
......
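Call sites should be unaffected: the factory helpers still take the input count, which now only sizes mInputs. A hedged usage sketch; the Add factory signature is assumed here by analogy with the Concat factory visible in the last hunk:

#include "aidge/operator/Add.hpp"      // header paths assumed from the Aidge layout
#include "aidge/operator/Concat.hpp"

int main() {
    auto add    = Aidge::Add(3);       // assumed Add(nbIn) factory
    auto concat = Aidge::Concat(2, 1); // Concat(nbIn, axis), as in the last hunk
    // Both operators now report their arity from mInputs.size(), e.g.
    // add->getOperator()->nbInputs() == 3 and concat->getOperator()->nbInputs() == 2.
    return 0;
}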