Skip to content
Snippets Groups Projects
Commit 71223519 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix python binding of concat by adding nb_in attr

parent 988ef2f4
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
...@@ -48,11 +48,15 @@ public: ...@@ -48,11 +48,15 @@ public:
using Attributes_ = StaticAttributes<ConcatAttr, int>; using Attributes_ = StaticAttributes<ConcatAttr, int>;
template <ConcatAttr e> using attr = typename Attributes_::template attr<e>; template <ConcatAttr e> using attr = typename Attributes_::template attr<e>;
Concat_Op(int axis) Concat_Op(int axis, IOIndex_t nbIn)
: Operator(Type), : Operator(Type),
Attributes_( mNbIn(nbIn),
attr<ConcatAttr::Axis>(axis)) Attributes_(attr<ConcatAttr::Axis>(axis))
{ {
mInputs = std::vector<std::shared_ptr<Tensor>>(nbIn);
for (std::size_t i = 0; i < nbIn; ++i) {
mInputs[i] = std::make_shared<Tensor>();
}
setDatatype(DataType::Float32); setDatatype(DataType::Float32);
} }
...@@ -67,12 +71,12 @@ public: ...@@ -67,12 +71,12 @@ public:
mOutput(std::make_shared<Tensor>(*op.mOutput)) mOutput(std::make_shared<Tensor>(*op.mOutput))
{ {
// cpy-ctor // cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
mInputs = std::vector<std::shared_ptr<Tensor>>(mNbIn); mInputs = std::vector<std::shared_ptr<Tensor>>(op.mNbIn);
for (std::size_t i = 0; i < mNbIn; ++i) { for (std::size_t i = 0; i < op.mNbIn; ++i) {
mInputs[i] = std::make_shared<Tensor>(); mInputs[i] = std::make_shared<Tensor>();
} }
setDatatype(op.mOutput->dataType());
} }
/** /**
...@@ -84,30 +88,25 @@ public: ...@@ -84,30 +88,25 @@ public:
} }
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
// assert(inputIdx < mNbIn && "operators supports only x inputs"); assert(inputIdx < mNbIn && "index out of bound");
assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
if (strcmp(data->type(), Tensor::Type) == 0) { mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
// TODO: associate input only if of type Tensor, otherwise do nothing
if(inputIdx<mInputs.size())
mInputs.insert( mInputs.begin() + inputIdx, std::dynamic_pointer_cast<Tensor>(data));
else
mInputs.emplace_back(std::dynamic_pointer_cast<Tensor>(data));
mNbIn = mInputs.size();
}
} }
void computeOutputDims() override final { void computeOutputDims() override final {
if (!mInputs.empty() && !mInputs[0]->empty()) if (!mInputs.empty() && !mInputs[0]->empty())
{ {
// mOutput->resize(mInputs[0]->dims());
Concat_Op::Attrs attr = getStaticAttributes(); Concat_Op::Attrs attr = getStaticAttributes();
const int& axis = static_cast<const int&>(std::get<0>(attr)); const int& axis = static_cast<const int&>(std::get<0>(attr));
std::size_t dimOnAxis = 0;
for(std::size_t i=0; i<mNbIn; ++i)
{
dimOnAxis += mInputs[i]->dims()[axis];
}
std::vector<DimSize_t> outputDims; std::vector<DimSize_t> outputDims;
for (std::size_t i = 0; i < mInputs[0]->nbDims(); ++i) { for (std::size_t i = 0; i < mInputs[0]->nbDims(); ++i) {
if(i==axis) if(i==axis)
outputDims.push_back(mInputs.size() * mInputs[0]->dims()[i]); outputDims.push_back(dimOnAxis);
else else
outputDims.push_back(mInputs[0]->dims()[i]); outputDims.push_back(mInputs[0]->dims()[i]);
} }
...@@ -121,8 +120,7 @@ public: ...@@ -121,8 +120,7 @@ public:
inline Tensor& input(const IOIndex_t inputIdx) const override final { inline Tensor& input(const IOIndex_t inputIdx) const override final {
assert((inputIdx < mNbIn) && "input index out of range for this instance of GenericOperator"); assert((inputIdx < mNbIn) && "input index out of range for this instance of Concat operator");
printf("Info: using input() on a GenericOperator.\n");
return *mInputs[inputIdx]; return *mInputs[inputIdx];
} }
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
...@@ -133,7 +131,7 @@ public: ...@@ -133,7 +131,7 @@ public:
return mInputs[inputIdx]; return mInputs[inputIdx];
} }
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Concat Operator has only 1 output"); assert((outputIdx == 0) && "Concat operator has only 1 output");
(void) outputIdx; // avoid unused warning (void) outputIdx; // avoid unused warning
return mOutput; return mOutput;
} }
...@@ -143,7 +141,7 @@ public: ...@@ -143,7 +141,7 @@ public:
return std::static_pointer_cast<Data>(mInputs[inputIdx]); return std::static_pointer_cast<Data>(mInputs[inputIdx]);
} }
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output"); assert(outputIdx == 0 && "Concat operator supports only 1 output");
(void) outputIdx; // avoid unused warning (void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput); return std::static_pointer_cast<Data>(mOutput);
} }
...@@ -172,15 +170,15 @@ public: ...@@ -172,15 +170,15 @@ public:
inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; } inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
return {"data_input"}; return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type
} }
static const std::vector<std::string> getOutputsName(){ static const std::vector<std::string> getOutputsName(){
return {"data_output"}; return {"data_output"};
} }
}; };
inline std::shared_ptr<Node> Concat(int axis, const std::string& name = "") { inline std::shared_ptr<Node> Concat(int axis, IOIndex_t nbIn, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Concat_Op>(axis), name); return std::make_shared<Node>(std::make_shared<Concat_Op>(axis, nbIn), name);
} }
} // namespace Aidge } // namespace Aidge
......
...@@ -23,6 +23,6 @@ void init_Concat(py::module& m) { ...@@ -23,6 +23,6 @@ void init_Concat(py::module& m) {
.def("get_inputs_name", &Concat_Op::getInputsName) .def("get_inputs_name", &Concat_Op::getInputsName)
.def("get_outputs_name", &Concat_Op::getOutputsName); .def("get_outputs_name", &Concat_Op::getOutputsName);
m.def("Concat", &Concat, py::arg("axis"), py::arg("name") = ""); m.def("Concat", &Concat, py::arg("axis"), py::arg("nb_in"), py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment