Skip to content
Snippets Groups Projects
Commit ce73448a authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

change gather input into attr

parent 36d8eb39
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!20Vit operators
......@@ -24,10 +24,10 @@ namespace Aidge {
// compute kernel registry for forward and backward
class GatherImplForward_cpu
: public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const void*, const void*, void*)> {
: public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(const typename Gather_Op::Attrs&, const std::vector<DimSize_t>&, const void*, void*)> {
};
class GatherImplBackward_cpu
: public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const void*, const void*, void*)> {
: public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(const typename Gather_Op::Attrs&, const std::vector<DimSize_t>&, const void*, void*)> {
};
class GatherImpl_cpu : public OperatorImpl {
......
......@@ -22,12 +22,13 @@
namespace Aidge {
template <class I, class O>
void GatherImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims, const std::vector<DimSize_t>& indicesDims, const void* input_, const void* indexes_, void* output_)
void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const std::vector<DimSize_t>& inputDims, const void* input_, void* output_)
{
const I* input = static_cast<const I*>(input_);
const int* indexes = static_cast<const int*>(indexes_);
O* output = static_cast<O*>(output_);
std::size_t axisIdx = std::get<2>(attrs)>=0 ? std::get<2>(attrs) : static_cast<std::size_t>(std::get<2>(attrs)) + inputDims.size();
std::size_t postAxisElems = 1;
for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
postAxisElems *= inputDims[i];
......@@ -37,17 +38,15 @@ void GatherImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSiz
preAxisElems *= inputDims[i];
}
std::vector<std::int64_t> indices = std::get<0>(attrs);
for (std::size_t i=0; i<preAxisElems; ++i)
{
for(std::size_t idxRow=0; idxRow<indicesDims[0]; ++idxRow)
for(std::size_t j=0; j<indices.size(); ++j)
{
for(std::size_t idxCol=0; idxCol<indicesDims[1]; ++idxCol)
{
std::size_t idx = indexes[indicesDims[1] * idxRow + idxCol];
const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems);
std::copy_n(startPtr, postAxisElems, output);
output += postAxisElems;
}
std::size_t idx = indices[j] >= 0 ? indices[j] : indices[j] + inputDims[axisIdx];
const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems);
std::copy_n(startPtr, postAxisElems, output);
output += postAxisElems;
}
}
}
......
......@@ -27,19 +27,14 @@ Aidge::NbElts_t Aidge::GatherImpl_cpu::getNbRequiredProtected(const Aidge::IOInd
}
void Aidge::GatherImpl_cpu::forward() {
Gather_Op::Attrs attr = dynamic_cast<const Gather_Op&>(mOp).getStaticAttributes();
const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->nbDims() > 1);// > axisIdx && "input dim must be bigger than "+std::to_strint(axisIdx)
auto kernelFunc = Registrar<GatherImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
// Call kernel
kernelFunc(axisIdx,
kernelFunc(dynamic_cast<const Gather_Op&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
......@@ -44,14 +44,16 @@ TEST_CASE("[cpu/operator] Gather(forward)") {
}
});
std::shared_ptr<Node> myGather = Gather(0);
std::shared_ptr<Node> myGather = Gather({1, 2}, {1, 2}, 0);
auto op = std::static_pointer_cast<OperatorTensor>(myGather -> getOperator());
op->associateInput(0,input);
op->associateInput(1,indexes);
// op->associateInput(1,indexes);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
op->computeOutputDims();
myGather->forward();
op->getOutput(0)->print();
expectedOutput->print();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
......@@ -83,10 +85,10 @@ TEST_CASE("[cpu/operator] Gather(forward)") {
}
});
std::shared_ptr<Node> myGather = Gather(1);
std::shared_ptr<Node> myGather = Gather({0, 2}, {1, 2}, 1);
auto op = std::static_pointer_cast<OperatorTensor>(myGather -> getOperator());
op->associateInput(0,input);
op->associateInput(1,indexes);
// op->associateInput(1,indexes);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
op->computeOutputDims();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment