Skip to content
Snippets Groups Projects
Commit 14453dbf authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

update gather to accept either indices inputs or attributes

parent 8c16072a
No related branches found
No related tags found
No related merge requests found
......@@ -31,24 +31,26 @@ public:
void forward() override;
};
enum class GatherAttr { Axis };
enum class GatherAttr { Axis, Indices, GatheredShape };
class Gather_Op : public OperatorTensor,
public Registrable<Gather_Op,
std::string,
std::shared_ptr<OperatorImpl>(const Gather_Op&)>,
public StaticAttributes<GatherAttr, std::int8_t> {
public StaticAttributes<GatherAttr, std::int8_t, std::vector<int64_t>, std::vector<DimSize_t>> {
public:
static const std::string Type;
Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, std::int8_t>;
using Attributes_ = StaticAttributes<GatherAttr, std::int8_t, std::vector<int64_t>, std::vector<DimSize_t>>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(std::int8_t axis)
Gather_Op(std::int8_t axis, const std::vector<int64_t>& indices, const std::vector<DimSize_t>& gatheredShape)
: OperatorTensor(Type, 2, 0, 1),
Attributes_(attr<GatherAttr::Axis>(axis))
Attributes_(attr<GatherAttr::Axis>(axis),
attr<GatherAttr::Indices>(indices),
attr<GatherAttr::GatheredShape>(gatheredShape))
{
mImpl = std::make_shared<Gather_OpImpl>(*this);
}
......@@ -89,14 +91,14 @@ public:
}
};
inline std::shared_ptr<Node> Gather(std::int8_t axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name);
inline std::shared_ptr<Node> Gather(std::int8_t axis = 0, const std::vector<int64_t>& indices = {}, const std::vector<DimSize_t>& gatheredShape = {}, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(axis, indices, gatheredShape), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis"};
const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis", "Indices", "GatheredShape"};
}
#endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */
......@@ -11,6 +11,7 @@
#include <pybind11/pybind11.h>
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Gather.hpp"
......@@ -26,6 +27,6 @@ void init_Gather(py::module& m) {
.def("attributes_name", &Gather_Op::staticGetAttrsName);
declare_registrable<Gather_Op>(m, "GatherOp");
m.def("Gather", &Gather, py::arg("axis") = 0, py::arg("name") = "");
m.def("Gather", &Gather, py::arg("axis") = 0, py::arg("indices") = std::vector<std::int64_t>(), py::arg("gathered_shape") = std::vector<std::size_t>(), py::arg("name") = "");
}
} // namespace Aidge
......@@ -35,15 +35,14 @@ void Aidge::Gather_OpImpl::forward() {
preAxisElems *= op.getInput(0)->dims()[i];
}
const auto indices = static_cast<const int*>(op.getInput(1)->getImpl()->rawPtr());
std::size_t outputOffset = 0;
for (std::size_t i=0; i<preAxisElems; ++i)
{
for(std::size_t j=0; j<op.getInput(1)->size(); ++j)
for(std::size_t j=0; j<op.template getAttr<std::vector<int64_t>>("Indices").size(); ++j)
{
const std::size_t idx = indices[j] >= 0 ?
static_cast<std::size_t>(indices[j]) :
static_cast<std::size_t>(indices[j] + static_cast<int>(op.getInput(0)->dims()[axisIdx]));
const std::size_t idx = op.template getAttr<std::vector<int64_t>>("Indices")[j] >= 0 ?
static_cast<std::size_t>(op.template getAttr<std::vector<int64_t>>("Indices")[j]) :
static_cast<std::size_t>(op.template getAttr<std::vector<int64_t>>("Indices")[j] + static_cast<int>(op.getInput(0)->dims()[axisIdx]));
op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(i * postAxisElems * op.getInput(0)->dims()[axisIdx] + idx * postAxisElems), postAxisElems, outputOffset);
outputOffset += postAxisElems;
}
......@@ -53,25 +52,57 @@ void Aidge::Gather_OpImpl::forward() {
const std::string Aidge::Gather_Op::Type = "Gather";
bool Aidge::Gather_Op::forwardDims(bool /*allowDataDependency*/) {
// check inputs have been associated
for(int i=0; i<2; ++i){
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i);
}
// check data input has been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
}
if (!getInput(0)->empty() && !getInput(1)->empty()) {
if (!getInput(0)->empty()) {
if (this->template getAttr<GatherAttr::Indices>().empty())
{
if(getInput(1)->empty()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Either indices input or attribute must be provided", type());
}
this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims();
this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs
this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
break;
}
}
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> indicesDims = getInput(1)->dims();
std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
this->template getAttr<GatherAttr::Axis>():
this->template getAttr<GatherAttr::Axis>()+outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
if( indicesDims[0]>0 ) // In case indices is a scalar indicesDims is a 0
if( !this->template getAttr<GatherAttr::GatheredShape>().empty())
{
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indicesDims.begin(),indicesDims.end());
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx),
this->template getAttr<GatherAttr::GatheredShape>().begin(),
this->template getAttr<GatherAttr::GatheredShape>().end());
}
mOutputs[0]->resize(outDims);
return true;
......
......@@ -94,4 +94,36 @@ TEST_CASE("[cpu/operator] Gather(forward)", "[Gather][CPU]") {
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
}
SECTION("Init with attributes") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<int,3,3> {
{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,3,1,2> {
{
{
{1, 3}
},
{
{4, 6}
},
{
{7, 9}
}
}
});
std::shared_ptr<Node> myGather = Gather(1, {0, 2}, {1, 2});
auto op = std::static_pointer_cast<OperatorTensor>(myGather -> getOperator());
op->associateInput(0,input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myGather->forward();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment