Commit 3b6e6292 authored by Houssem ROUIS

fix python binding of concat by adding nb_in attr

parent fd1f62bb
2 merge requests: !50 version 0.2.0, !20 Vit operators
@@ -24,10 +24,10 @@ namespace Aidge {
// compute kernel registry for forward and backward
class ConcatImplForward_cpu
-    : public Registrable<ConcatImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<DimSize_t>, const std::vector<void*>, void*)> {
+    : public Registrable<ConcatImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const std::vector<void*>, void*)> {
};
class ConcatImplBackward_cpu
-    : public Registrable<ConcatImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<DimSize_t>, const std::vector<void*>, void*)> {
+    : public Registrable<ConcatImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const std::vector<void*>, void*)> {
};
class ConcatImpl_cpu : public OperatorImpl {
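The signature change adds a second dims vector: `inputDims` is the full shape of input #0, while the new `dimsOnAxis` carries each input's size on the concatenation axis, the only dimension on which the inputs may legally differ. A minimal sketch of the shape rule this encodes, using plain std::size_t in place of Aidge::DimSize_t (the helper name is illustrative, not Aidge API):

#include <cstddef>
#include <numeric>
#include <vector>

// Illustrative helper (not Aidge API): the output shape is the shape of
// input #0 with the axis entry replaced by the sum of all per-input axis
// sizes, which is exactly what inputDims + dimsOnAxis encode together.
std::vector<std::size_t> concatOutputDims(std::vector<std::size_t> inputDims,
                                          const std::vector<std::size_t>& dimsOnAxis,
                                          std::size_t axisIdx) {
    inputDims[axisIdx] = std::accumulate(dimsOnAxis.cbegin(), dimsOnAxis.cend(),
                                         std::size_t{0});
    return inputDims;
}

// e.g. concatOutputDims({2, 1, 3}, {1, 2, 1}, 1) yields {2, 4, 3},
// matching the 3D test case added at the end of this commit.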
@@ -22,7 +22,7 @@
namespace Aidge {
template <class I, class O>
-void ConcatImpl_cpu_forward_kernel(const int& axisIdx, std::vector<DimSize_t> arraysDims, const std::vector<void*> input_, void* output_)
+void ConcatImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims, const std::vector<DimSize_t>& dimsOnAxis, const std::vector<void*> input_, void* output_)
{
O* output = static_cast<O*>(output_);
std::vector<I*> input;
@@ -31,27 +31,25 @@ void ConcatImpl_cpu_forward_kernel(const int& axisIdx, std::vector<DimSize_t> ar
input.emplace_back(static_cast<I*>(elem));
}
-// compute length of chunks to copy from each input tensor
-size_t chunkSize = 1;
-size_t totalTensorSize = 1;
-for(size_t i=arraysDims.size()-1; i>0 ; --i)
-{
-if(i >= axisIdx)
-chunkSize *= arraysDims[i];
-totalTensorSize *= arraysDims[i];
-}
+std::size_t postAxisElems = 1;
+for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
+postAxisElems *= inputDims[i];
+}
+std::size_t preAxisElems = 1;
+for (std::size_t i = 0; i < axisIdx; ++i) {
+preAxisElems *= inputDims[i];
+}
-size_t iterationsCount = totalTensorSize / chunkSize;
-for(size_t i=0; i<iterationsCount ; ++i)
+for(std::size_t i=0; i<preAxisElems ; ++i)
{
-for(size_t j=0; j < input.size(); ++j)
+for(std::size_t j=0; j < input.size(); ++j)
{
-I* copyPtr = std::next(input[j], i * chunkSize);
-std::copy_n(copyPtr, chunkSize, output);
-output += chunkSize;
-}
+std::size_t strideOnAxis = postAxisElems * dimsOnAxis[j];
+const I* copyPtr = std::next(input[j], i * strideOnAxis);
+std::copy_n(copyPtr, strideOnAxis, output);
+output += strideOnAxis;
+}
}
}
namespace {
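The rewritten kernel makes the memory layout explicit: it assembles preAxisElems outer slices, each built by copying postAxisElems * dimsOnAxis[j] contiguous elements from every input in turn. A self-contained sketch of the same copy scheme on hypothetical 2x2 and 2x1 int tensors concatenated on axis 1 (plain int standing in for the templated I/O types):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Stand-alone paraphrase of the copy loop above on hypothetical data:
// concatenating a 2x2 and a 2x1 int tensor on axis 1 into a 2x3 output.
int main() {
    const std::size_t axisIdx = 1;
    const std::vector<std::size_t> inputDims{2, 2};   // shape of input #0
    const std::vector<std::size_t> dimsOnAxis{2, 1};  // per-input size on the axis

    const std::vector<int> in0{1, 2, 3, 4};  // 2x2, row-major
    const std::vector<int> in1{10, 20};      // 2x1
    const std::vector<const int*> input{in0.data(), in1.data()};

    std::size_t postAxisElems = 1;  // elements per index step on the axis
    for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i)
        postAxisElems *= inputDims[i];
    std::size_t preAxisElems = 1;   // number of outer slices to assemble
    for (std::size_t i = 0; i < axisIdx; ++i)
        preAxisElems *= inputDims[i];

    std::vector<int> out(2 * 3);
    int* output = out.data();
    for (std::size_t i = 0; i < preAxisElems; ++i) {
        for (std::size_t j = 0; j < input.size(); ++j) {
            const std::size_t strideOnAxis = postAxisElems * dimsOnAxis[j];
            const int* copyPtr = input[j] + i * strideOnAxis;
            std::copy_n(copyPtr, strideOnAxis, output);
            output += strideOnAxis;
        }
    }
    for (int v : out) std::cout << v << ' ';  // prints: 1 2 10 3 4 20
    std::cout << '\n';
    return 0;
}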
@@ -21,31 +21,33 @@
#include "aidge/backend/cpu/operator/ConcatImpl.hpp"
#include "aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp"
#include <iostream>
Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return 0;
}
void Aidge::ConcatImpl_cpu::forward() {
-assert(mOp.getInput(0) && "missing input #0");
+for (std::size_t i = 0; i < dynamic_cast<const Concat_Op&>(mOp).mNbIn; ++i) {
+assert(mOp.getInput(i) && ("missing input #"+std::to_string(i)).c_str());
+}
Concat_Op::Attrs attr = dynamic_cast<const Concat_Op&>(mOp).getStaticAttributes();
-const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
-assert(mOp.getInput(0)->nbDims() > 1);// > axisIdx && "input dim must be bigger than "+std::to_strint(axisIdx)
+std::size_t axisIdx = static_cast<const int&>(std::get<0>(attr));
+assert(mOp.getInput(0)->nbDims() > axisIdx && ("input dim must be bigger than "+std::to_string(axisIdx)).c_str());
auto kernelFunc = Registrar<ConcatImplForward_cpu>::create({
mOp.getInput(0)->dataType(),
mOp.getOutput(0)->dataType()});
-// Call kernel
std::vector<void*> inputTensors;
+std::vector<std::size_t> dimsOnAxis;
+for (std::size_t i = 0; i < dynamic_cast<const Concat_Op&>(mOp).mNbIn; ++i) {
inputTensors.push_back(mOp.getInput(i)->getImpl()->rawPtr());
+dimsOnAxis.push_back(mOp.getInput(i)->dims()[axisIdx]);
}
+// Call kernel
kernelFunc(axisIdx,
mOp.getInput(0)->dims(),
+dimsOnAxis,
inputTensors,
mOp.getOutput(0)->getImpl()->rawPtr());
}
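forward() resolves its kernel through Registrar<ConcatImplForward_cpu>::create, keyed on the (input, output) data-type pair. A toy analogue of that dispatch pattern, assuming nothing about Aidge's actual Registrable internals (the enum, key, and map below are illustrative only):

#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <tuple>

// Toy analogue (NOT Aidge's actual Registrable/Registrar implementation):
// kernels are registered under an (input DataType, output DataType) key and
// looked up when the operator runs, so supporting a new type pair is just
// one more registration.
enum class DataType { Float32, Int32 };
using Key = std::tuple<DataType, DataType>;
using Kernel = std::function<void(const void* in, void* out, std::size_t n)>;

std::map<Key, Kernel>& registry() {
    static std::map<Key, Kernel> r;
    return r;
}

// Templated kernel, mirroring how ConcatImpl_cpu_forward_kernel<I, O> is
// instantiated per type pair (here it just casts and copies n values).
template <class I, class O>
void copyKernel(const void* in, void* out, std::size_t n) {
    const I* src = static_cast<const I*>(in);
    O* dst = static_cast<O*>(out);
    for (std::size_t i = 0; i < n; ++i) dst[i] = static_cast<O>(src[i]);
}

int main() {
    registry()[{DataType::Int32, DataType::Int32}] = copyKernel<int, int>;
    int in[3] = {1, 2, 3};
    int out[3] = {};
    registry().at({DataType::Int32, DataType::Int32})(in, out, 3);
    std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // 1 2 3
    return 0;
}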
@@ -21,7 +21,7 @@
using namespace Aidge;
TEST_CASE("[cpu/operator] Concat(forward)") {
SECTION("2D Tensor") {
SECTION("2D Tensors") {
std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array2D<float,2,2> {
{
{0.00543531, 0.53726782},
@@ -41,7 +41,7 @@ TEST_CASE("[cpu/operator] Concat(forward)") {
}
});
-std::shared_ptr<Node> myConcat = Concat(1);
+std::shared_ptr<Node> myConcat = Concat(1, 2);
myConcat->getOperator()->setDatatype(DataType::Float32);
myConcat->getOperator()->setBackend("cpu");
myConcat->getOperator()->associateInput(0,input1);
@@ -49,11 +49,57 @@
myConcat->getOperator()->computeOutputDims();
myConcat->forward();
-float* resPtr = static_cast<float*>(myConcat->getOperator()->getOutput(0)->getImpl()->rawPtr());
-float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-for (std::size_t i = 0; i< 3; ++i) {
-REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
-}
+REQUIRE(*(myConcat->getOperator()->getOutput(0)) == *expectedOutput);
}
SECTION("3D Tensors") {
std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array3D<int,2,1,3> {
{
{
{1, 2, 3}
},
{
{4, 5, 6}
}
}
});
std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array3D<int,2,2,3> {
{
{
{10, 11, 12},
{13, 14, 15}
},
{
{16, 17, 18},
{19, 20, 21}
}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,2,4,3> {
{
{
{ 1, 2, 3 },
{ 10, 11, 12 },
{ 13, 14, 15 },
{ 1, 2, 3 }
},
{
{ 4, 5, 6 },
{ 16, 17, 18 },
{ 19, 20, 21 },
{ 4, 5, 6 }
}
}
});
std::shared_ptr<Node> myConcat = Concat(1, 3);
myConcat->getOperator()->setDatatype(DataType::Int32);
myConcat->getOperator()->setBackend("cpu");
myConcat->getOperator()->associateInput(0,input1);
myConcat->getOperator()->associateInput(1,input2);
myConcat->getOperator()->associateInput(2,input1);
myConcat->getOperator()->computeOutputDims();
myConcat->forward();
REQUIRE(*(myConcat->getOperator()->getOutput(0)) == *expectedOutput);
}
}
\ No newline at end of file
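As a sanity check on the new 3D test (this is not code from the commit), the kernel's stride arithmetic for inputs of dims {2,1,3}, {2,2,3}, {2,1,3} concatenated on axis 1 can be replayed standalone:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const std::size_t axis = 1;
    const std::vector<std::size_t> inputDims{2, 1, 3};   // dims of input #0
    const std::vector<std::size_t> dimsOnAxis{1, 2, 1};  // per-input size on the axis

    // Same products the kernel computes.
    std::size_t postAxisElems = 1;
    for (std::size_t i = axis + 1; i < inputDims.size(); ++i) postAxisElems *= inputDims[i];
    std::size_t preAxisElems = 1;
    for (std::size_t i = 0; i < axis; ++i) preAxisElems *= inputDims[i];

    // Each outer iteration copies 1*3 + 2*3 + 1*3 = 12 contiguous values,
    // i.e. one {4,3} slice of the expected {2,4,3} output.
    std::size_t perIteration = 0;
    for (std::size_t d : dimsOnAxis) perIteration += d * postAxisElems;

    assert(preAxisElems == 2 && postAxisElems == 3 && perIteration == 12);
    return 0;
}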