Skip to content
Snippets Groups Projects
Commit 71223519 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix python binding of concat by adding nb_in attr

parent 988ef2f4
No related branches found
No related tags found
No related merge requests found
......@@ -48,11 +48,15 @@ public:
using Attributes_ = StaticAttributes<ConcatAttr, int>;
template <ConcatAttr e> using attr = typename Attributes_::template attr<e>;
Concat_Op(int axis)
Concat_Op(int axis, IOIndex_t nbIn)
: Operator(Type),
Attributes_(
attr<ConcatAttr::Axis>(axis))
mNbIn(nbIn),
Attributes_(attr<ConcatAttr::Axis>(axis))
{
mInputs = std::vector<std::shared_ptr<Tensor>>(nbIn);
for (std::size_t i = 0; i < nbIn; ++i) {
mInputs[i] = std::make_shared<Tensor>();
}
setDatatype(DataType::Float32);
}
......@@ -67,12 +71,12 @@ public:
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
mInputs = std::vector<std::shared_ptr<Tensor>>(mNbIn);
for (std::size_t i = 0; i < mNbIn; ++i) {
mInputs = std::vector<std::shared_ptr<Tensor>>(op.mNbIn);
for (std::size_t i = 0; i < op.mNbIn; ++i) {
mInputs[i] = std::make_shared<Tensor>();
}
setDatatype(op.mOutput->dataType());
}
/**
......@@ -84,30 +88,25 @@ public:
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
// assert(inputIdx < mNbIn && "operators supports only x inputs");
if (strcmp(data->type(), Tensor::Type) == 0) {
// TODO: associate input only if of type Tensor, otherwise do nothing
if(inputIdx<mInputs.size())
mInputs.insert( mInputs.begin() + inputIdx, std::dynamic_pointer_cast<Tensor>(data));
else
mInputs.emplace_back(std::dynamic_pointer_cast<Tensor>(data));
mNbIn = mInputs.size();
}
assert(inputIdx < mNbIn && "index out of bound");
assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void computeOutputDims() override final {
if (!mInputs.empty() && !mInputs[0]->empty())
{
// mOutput->resize(mInputs[0]->dims());
Concat_Op::Attrs attr = getStaticAttributes();
const int& axis = static_cast<const int&>(std::get<0>(attr));
std::size_t dimOnAxis = 0;
for(std::size_t i=0; i<mNbIn; ++i)
{
dimOnAxis += mInputs[i]->dims()[axis];
}
std::vector<DimSize_t> outputDims;
for (std::size_t i = 0; i < mInputs[0]->nbDims(); ++i) {
if(i==axis)
outputDims.push_back(mInputs.size() * mInputs[0]->dims()[i]);
outputDims.push_back(dimOnAxis);
else
outputDims.push_back(mInputs[0]->dims()[i]);
}
......@@ -121,8 +120,7 @@ public:
inline Tensor& input(const IOIndex_t inputIdx) const override final {
assert((inputIdx < mNbIn) && "input index out of range for this instance of GenericOperator");
printf("Info: using input() on a GenericOperator.\n");
assert((inputIdx < mNbIn) && "input index out of range for this instance of Concat operator");
return *mInputs[inputIdx];
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
......@@ -133,7 +131,7 @@ public:
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Concat Operator has only 1 output");
assert((outputIdx == 0) && "Concat operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
......@@ -143,7 +141,7 @@ public:
return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output");
assert(outputIdx == 0 && "Concat operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
......@@ -172,15 +170,15 @@ public:
inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Concat(int axis, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Concat_Op>(axis), name);
inline std::shared_ptr<Node> Concat(int axis, IOIndex_t nbIn, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Concat_Op>(axis, nbIn), name);
}
} // namespace Aidge
......
......@@ -23,6 +23,6 @@ void init_Concat(py::module& m) {
.def("get_inputs_name", &Concat_Op::getInputsName)
.def("get_outputs_name", &Concat_Op::getOutputsName);
m.def("Concat", &Concat, py::arg("axis"), py::arg("name") = "");
m.def("Concat", &Concat, py::arg("axis"), py::arg("nb_in"), py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment