Skip to content
Snippets Groups Projects
Commit 82c69826 authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

changed gather input to attr

parent 6424edc9
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!74Update vit operators
......@@ -27,25 +27,26 @@
#include "aidge/utils/Types.h"
namespace Aidge {
enum class GatherAttr { Axis };
enum class GatherAttr { Indices, GatheredShape, Axis };
class Gather_Op : public OperatorTensor,
public Registrable<Gather_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Gather_Op&)>,
public StaticAttributes<GatherAttr, int> {
public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> {
public:
static const std::string Type;
Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, int>;
using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(int axis)
: OperatorTensor(Type, 2, 0, 1),
Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(
attr<GatherAttr::Indices>(indices),
attr<GatherAttr::GatheredShape>(gatheredShape),
attr<GatherAttr::Axis>(axis))
{}
......@@ -76,21 +77,21 @@ public:
}
static const std::vector<std::string> getInputsName(){
return {"data_input", "indexes"};
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Gather(int axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name);
inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis"};
const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Indices", "GatheredShape", "Axis"};
}
#endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */
......@@ -23,6 +23,6 @@ void init_Gather(py::module& m) {
.def("get_inputs_name", &Gather_Op::getInputsName)
.def("get_outputs_name", &Gather_Op::getOutputsName);
m.def("Gather", &Gather, py::arg("axis"), py::arg("name") = "");
m.def("Gather", &Gather, py::arg("indices"), py::arg("gathered_shape"), py::arg("axis"), py::arg("name") = "");
}
} // namespace Aidge
......@@ -22,18 +22,22 @@ const std::string Aidge::Gather_Op::Type = "Gather";
void Aidge::Gather_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if (getInput(1)->nbDims()!=2){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor");
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
}
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> indexesDims = getInput(1)->dims();
int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size();
std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
// TODO: check indices and gatheredShape
std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
this->template getAttr<GatherAttr::Axis>() :
this->template getAttr<GatherAttr::Axis>() + outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end());
if (!gatheredShape.empty())
{
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), gatheredShape.begin(),gatheredShape.end());
}
mOutputs[0]->resize(outDims);
}
\ No newline at end of file
......@@ -34,17 +34,19 @@ void Aidge::Slice_Op::computeOutputDims() {
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : axis_ + getInput(0)->nbDims();
const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : start_ + getInput(0)->dims()[axis];
const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : end_ + getInput(0)->dims()[axis];
const std::int32_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
const std::int32_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
const std::int32_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims();
const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis];
const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis];
const std::size_t sliceLength = end - start + 1;
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
}
outDims[axis] = sliceLength;
}
mOutputs[0]->resize(outDims);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment