Skip to content
Snippets Groups Projects
Commit 05374c04 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

keep merged concat operator

parent d0b89c96
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
......@@ -12,7 +12,6 @@
#ifndef AIDGE_CORE_OPERATOR_CONCAT_H_
#define AIDGE_CORE_OPERATOR_CONCAT_H_
#include <cassert>
#include <numeric>
#include <vector>
#include <cmath>
......@@ -67,29 +66,41 @@ public:
return std::make_shared<Concat_Op>(*this);
}
// Data operator[](const char* inputName) override final {
// std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] :
// (strcmp(inputName, "weight") ? mInputs[1] :
// (strcmp(inputName, "bias") ? mInputs[2] :
// nullptr));
// assert((in!=nullptr) && "No such parameter");
// return *in;
// }
void computeOutputDims() override final {
if (!mInputs.empty() && !mInputs[0]->empty())
{
Concat_Op::Attrs attr = getStaticAttributes();
const int& axis = static_cast<const int&>(std::get<0>(attr));
std::size_t dimOnAxis = 0;
for(std::size_t i=0; i< mInputs.size(); ++i)
{
dimOnAxis += mInputs[i]->dims()[axis];
// Every input is non-empty with the same number of dimensions
bool associated = (getInput(0) != nullptr);
associated &= !(getInput(0)->empty()) && (getAttr<ConcatAttr::Axis>() < getInput(0)->nbDims()); // do not compute anything if no input
auto outputDims = getInput(0)->dims();
const auto firstInputNbDims = getInput(0) -> nbDims();
for (IOIndex_t i = 1; i < nbInputs(); ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
std::vector<DimSize_t> outputDims;
for (std::size_t i = 0; i < mInputs[0]->nbDims(); ++i) {
if(i==axis)
outputDims.push_back(dimOnAxis);
else
outputDims.push_back(mInputs[0]->dims()[i]);
associated &= (getInput(i)->nbDims() == firstInputNbDims);
for (DimSize_t dim = 0; dim < firstInputNbDims; ++dim) {
if (dim == getAttr<ConcatAttr::Axis>()) {
outputDims[dim] += getInput(i)->dims()[dim];
}
else {
associated &= (getInput(i)->dims()[dim] == outputDims[dim]);
}
}
mOutputs[0]->resize(outputDims);
}
if (associated) {
getOutput(0)->resize(outputDims);
}
}
void setBackend(const std::string& name) override {
mImpl = Registrar<Concat_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
......@@ -108,14 +119,16 @@ public:
}
};
inline std::shared_ptr<Node> Concat(IOIndex_t nbIn, int axis, const std::string& name = "") {
inline std::shared_ptr<Node> Concat(const IOIndex_t nbIn, const DimIdx_t axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Concat_Op>(nbIn, axis), name);
}
} // namespace Aidge
}
namespace {
template <>
const char *const EnumStrings<Aidge::ConcatAttr>::data[] = {"Axis"};
template <>
const char* const EnumStrings<Aidge::ConcatAttr>::data[] = {
"Axis"
};
}
#endif /* AIDGE_CORE_OPERATOR_CONCAT_H_ */
......@@ -23,6 +23,6 @@ void init_Concat(py::module& m) {
.def("get_inputs_name", &Concat_Op::getInputsName)
.def("get_outputs_name", &Concat_Op::getOutputsName);
m.def("Concat", &Concat, py::arg("axis"), py::arg("nb_in"), py::arg("name") = "");
m.def("Concat", &Concat, py::arg("nbIn"), py::arg("axis"), py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment