Skip to content
Snippets Groups Projects
Commit 2c27c7ac authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

add broadcasting for Add operator

parent 2be78815
No related branches found
No related tags found
2 merge requests: !50 "version 0.2.0", !30 "add broadcasting for Arithmetic operators"
......@@ -25,10 +25,10 @@ namespace Aidge {
// compute kernel registry for forward and backward
class AddImplForward_cpu
: public Registrable<AddImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<const void*>, void*)> {};
: public Registrable<AddImplForward_cpu, std::tuple<DataType, DataType>, void(const std::vector<const void*>, const std::vector<std::vector<std::size_t>>&, const std::size_t, const std::vector<std::size_t>&, void*)> {};
class AddImplBackward_cpu
: public Registrable<AddImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<const void*>, void*)> {};
: public Registrable<AddImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::vector<const void*>, const std::vector<std::vector<std::size_t>>&, const std::size_t, const std::vector<std::size_t>&, void*)> {};
class AddImpl_cpu : public OperatorImpl {
......
......@@ -18,8 +18,33 @@
namespace Aidge {
// Function to get multi-dimensional indices from a flattened index
// Convert a flattened (linear, row-major / C-order) index into one index per
// dimension for a tensor with the given dimensions.
//
// @param dimensions size of each dimension, outermost first; each must be > 0
// @param idx        flattened index, expected in [0, product(dimensions))
// @return           multi-dimensional indices, same length as `dimensions`
std::vector<size_t> getMultiDimIndices(const std::vector<size_t>& dimensions, size_t idx) {
    std::vector<size_t> indices(dimensions.size(), 0);
    // Walk from the innermost (fastest-varying) dimension outwards, peeling
    // off one coordinate per dimension with modulo/divide.
    // Unsigned countdown (`i-- > 0`) avoids the size_t -> int narrowing of
    // `for (int i = dimensions.size() - 1; i >= 0; --i)`.
    for (std::size_t i = dimensions.size(); i-- > 0;) {
        indices[i] = idx % dimensions[i];
        idx /= dimensions[i];
    }
    return indices;
}
// Function to get a flattened index from multi-dimensional indices
// Convert multi-dimensional indices into a flattened (row-major / C-order)
// index for a tensor with the given dimensions, applying broadcasting:
// any dimension of size <= 1 contributes index 0 regardless of the requested
// coordinate, so output coordinates can be mapped onto a smaller input.
//
// @param dimensions size of each dimension of the (possibly broadcast) tensor
// @param indices    one coordinate per dimension (same length as `dimensions`)
// @return           flattened index into the tensor's storage
std::size_t getFlattenedIndex(const std::vector<size_t>& dimensions, const std::vector<size_t>& indices) {
    std::size_t flattenedIdx = 0;
    std::size_t stride = 1;
    // Accumulate from innermost to outermost dimension; unsigned countdown
    // (`i-- > 0`) avoids the size_t -> int narrowing of the signed loop form.
    for (std::size_t i = dimensions.size(); i-- > 0;) {
        // Broadcasting: a size-1 dimension is indexed at 0 whatever the
        // requested coordinate is.
        const std::size_t idx = dimensions[i] > 1 ? indices[i] : 0;
        flattenedIdx += idx * stride;
        stride *= dimensions[i];
    }
    return flattenedIdx;
}
template <class I, class O>
void AddImpl_cpu_forward_kernel(const std::size_t inputLength, const std::vector<const void*> inputs_, void* output_) {
void AddImpl_cpu_forward_kernel(const std::vector<const void*> inputs_, const std::vector<std::vector<std::size_t>>& inputDims, const std::size_t outputLength, const std::vector<std::size_t>& outDims, void* output_) {
// FIXME: missing Add attributes as arguments
std::vector<const I*> inputs;
for (const auto& input_ : inputs_) {
......@@ -27,12 +52,15 @@ void AddImpl_cpu_forward_kernel(const std::size_t inputLength, const std::vector
}
O* output = static_cast<O*>(output_);
for (std::size_t oIndex = 0; oIndex < inputLength; ++oIndex) {
output[oIndex] = 0;
for (std::size_t iIndex = 0; iIndex < inputs.size(); ++iIndex) {
output[oIndex] += inputs[iIndex][oIndex];
}
}
for (std::size_t oIndex = 0; oIndex < outputLength; ++oIndex)
{
std::vector<size_t> indexes = getMultiDimIndices(outDims, oIndex);
for(std::size_t iIndex = 0; iIndex < inputs.size(); ++iIndex) {
std::size_t idx = getFlattenedIndex(inputDims[iIndex], indexes);
output[oIndex] += inputs[iIndex][idx];
}
}
}
namespace {
......
......@@ -55,15 +55,25 @@ void Aidge::AddImpl_cpu::forward() {
// TODO: right now, if needed, memory will be allocated/deallocated at each
// call to forward(). We might put the following shared_ptr as members of
// this class to avoid that.
std::size_t nbDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->nbDims();
std::vector<std::vector<std::size_t>> inputsDims;
std::vector<const void*> opInputs;
std::vector<std::shared_ptr<Tensor>> inputsFallback(mOp.nbInputs());
for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->refCastFrom(inputsFallback[i], *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
opInputs.push_back(input.getImpl()->rawPtr());
std::vector<std::size_t> inputDims(nbDims, 1);
auto dims = std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->dims();
for(std::size_t j=dims.size()-1; j+1>0; --j)
{
std::size_t idx = nbDims - (dims.size()-j);
inputDims[idx] = dims[j];
}
inputsDims.push_back(inputDims);
opInputs.push_back(getCPUPtr(mOp.getRawInput(i)));
}
// Call kernel
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
opInputs,
kernelFunc(opInputs,
inputsDims,
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
getCPUPtr(mOp.getRawOutput(0)));
}
......@@ -117,4 +117,63 @@ TEST_CASE("[cpu/operator] Add(forward)", "[Add][CPU]") {
REQUIRE(*op->getOutput(0) == *expectedOutput);
}
SECTION("Broadcasting") {
std::shared_ptr<Tensor> input_0 = std::make_shared<Tensor>(Array4D<int,3,1,3,2> {
{ //
{ //
{{0, 1},{2, 3},{4, 5}} //
}, //
{ //
{{6, 7},{8, 9},{10, 11}} //
}, //
{ //
{{12, 13},{14, 15},{16, 17}} //
} //
} //
}); //
std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
{ //
{ //
{{20, 21},{22, 23},{24, 25}}, //
{{26, 27},{28, 29},{30, 31}}, //
{{32, 33},{34, 35},{36, 37}} //
} //
} //
}); //
std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array1D<int,2> {{100,200}});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> {
{ //
{ //
{{ 120, 222},{ 124, 226},{ 128, 230}}, //
{{ 126, 228},{ 130, 232},{ 134, 236}}, //
{{ 132, 234},{ 136, 238},{ 140, 242}} //
}, //
{ //
{{ 126, 228},{ 130, 232},{ 134, 236}}, //
{{ 132, 234},{ 136, 238},{ 140, 242}}, //
{{ 138, 240},{ 142, 244},{ 146, 248}} //
}, //
{ //
{{ 132, 234},{ 136, 238},{140, 242}}, //
{{ 138, 240},{ 142, 244},{146, 248}}, //
{{ 144, 246},{ 148, 250},{152, 254}} //
} //
} //
}); //
std::shared_ptr<Node> myAdd = Add(3);
auto op = std::static_pointer_cast<OperatorTensor>(myAdd -> getOperator());
op->associateInput(0, input_0);
op->associateInput(1, input_1);
op->associateInput(2, input_2);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
op->computeOutputDims();
myAdd->forward();
op->getOutput(0)->print();
expectedOutput->print();
REQUIRE(*op->getOutput(0) == *expectedOutput);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment