Skip to content
Snippets Groups Projects
Commit 85b20786 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add broadcasting for Add operator

parent 602a9337
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!30add broadcasting for Arithmetic operators
...@@ -25,10 +25,10 @@ namespace Aidge { ...@@ -25,10 +25,10 @@ namespace Aidge {
// compute kernel registry for forward and backward // compute kernel registry for forward and backward
class AddImplForward_cpu class AddImplForward_cpu
: public Registrable<AddImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<const void*>, void*)> {}; : public Registrable<AddImplForward_cpu, std::tuple<DataType, DataType>, void(const std::vector<const void*>, const std::vector<std::vector<std::size_t>>&, const std::size_t, const std::vector<std::size_t>&, void*)> {};
class AddImplBackward_cpu class AddImplBackward_cpu
: public Registrable<AddImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<const void*>, void*)> {}; : public Registrable<AddImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::vector<const void*>, const std::vector<std::vector<std::size_t>>&, const std::size_t, const std::vector<std::size_t>&, void*)> {};
class AddImpl_cpu : public OperatorImpl { class AddImpl_cpu : public OperatorImpl {
......
...@@ -18,8 +18,33 @@ ...@@ -18,8 +18,33 @@
namespace Aidge { namespace Aidge {
// Function to get multi-dimensional indices from a flattened index.
// Assumes row-major (C-style) layout: walks dimensions from innermost
// (last) to outermost, peeling off one coordinate per axis with
// modulo/division.
// An empty `dimensions` yields an empty index vector.
std::vector<size_t> getMultiDimIndices(const std::vector<size_t>& dimensions, size_t idx) {
    std::vector<size_t> indices(dimensions.size(), 0);
    // Count down with an unsigned index instead of `int i = size() - 1`,
    // which performs an implementation-defined unsigned->signed conversion.
    for (std::size_t i = dimensions.size(); i > 0; --i) {
        indices[i - 1] = idx % dimensions[i - 1];
        idx /= dimensions[i - 1];
    }
    return indices;
}
// Function to get a flattened (row-major) index from multi-dimensional
// indices, with broadcasting: an axis of size 1 always maps to offset 0,
// so a smaller input can be indexed with the output's coordinates.
// `indices` must have at least as many entries as `dimensions`
// (entries are read back-to-front).
std::size_t getFlattenedIndex(const std::vector<size_t>& dimensions, const std::vector<size_t>& indices) {
    std::size_t flattenedIdx = 0;
    std::size_t stride = 1;
    // Unsigned countdown avoids the implementation-defined
    // `int i = size() - 1` unsigned->signed conversion.
    for (std::size_t i = dimensions.size(); i > 0; --i) {
        const std::size_t dim = dimensions[i - 1];
        // Broadcast rule: size-1 axes contribute offset 0.
        const std::size_t idx = dim > 1 ? indices[i - 1] : 0;
        flattenedIdx += idx * stride;
        stride *= dim;
    }
    return flattenedIdx;
}
template <class I, class O> template <class I, class O>
void AddImpl_cpu_forward_kernel(const std::size_t inputLength, const std::vector<const void*> inputs_, void* output_) { void AddImpl_cpu_forward_kernel(const std::vector<const void*> inputs_, const std::vector<std::vector<std::size_t>>& inputDims, const std::size_t outputLength, const std::vector<std::size_t>& outDims, void* output_) {
// FIXME: missing Add attributes as arguments // FIXME: missing Add attributes as arguments
std::vector<const I*> inputs; std::vector<const I*> inputs;
for (const auto& input_ : inputs_) { for (const auto& input_ : inputs_) {
...@@ -27,9 +52,13 @@ void AddImpl_cpu_forward_kernel(const std::size_t inputLength, const std::vector ...@@ -27,9 +52,13 @@ void AddImpl_cpu_forward_kernel(const std::size_t inputLength, const std::vector
} }
O* output = static_cast<O*>(output_); O* output = static_cast<O*>(output_);
for (std::size_t iIndex = 0; iIndex < inputs.size(); ++iIndex) { for (std::size_t oIndex = 0; oIndex < outputLength; ++oIndex)
for (std::size_t oIndex = 0; oIndex < inputLength; ++oIndex) { {
output[oIndex] += inputs[iIndex][oIndex]; std::vector<size_t> indexes = getMultiDimIndices(outDims, oIndex);
for(std::size_t iIndex = 0; iIndex < inputs.size(); ++iIndex) {
std::size_t idx = getFlattenedIndex(inputDims[iIndex], indexes);
output[oIndex] += inputs[iIndex][idx];
} }
} }
} }
......
...@@ -73,12 +73,24 @@ void Aidge::AddImpl_cpu::forward() { ...@@ -73,12 +73,24 @@ void Aidge::AddImpl_cpu::forward() {
datatypeFirstInput, datatypeFirstInput,
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
std::size_t nbDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->nbDims();
std::vector<std::vector<std::size_t>> inputsDims;
std::vector<const void*> opInputs; std::vector<const void*> opInputs;
for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) { for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
std::vector<std::size_t> inputDims(nbDims, 1);
auto dims = std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->dims();
for(std::size_t j=dims.size()-1; j+1>0; --j)
{
std::size_t idx = nbDims - (dims.size()-j);
inputDims[idx] = dims[j];
}
inputsDims.push_back(inputDims);
opInputs.push_back(getCPUPtr(mOp.getRawInput(i))); opInputs.push_back(getCPUPtr(mOp.getRawInput(i)));
} }
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(), kernelFunc(opInputs,
opInputs, inputsDims,
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
getCPUPtr(mOp.getRawOutput(0))); getCPUPtr(mOp.getRawOutput(0)));
} }
\ No newline at end of file
...@@ -117,4 +117,63 @@ TEST_CASE("[cpu/operator] Add(forward)", "[Add][CPU]") { ...@@ -117,4 +117,63 @@ TEST_CASE("[cpu/operator] Add(forward)", "[Add][CPU]") {
REQUIRE(*op->getOutput(0) == *expectedOutput); REQUIRE(*op->getOutput(0) == *expectedOutput);
} }
// Verifies NumPy-style broadcasting through Add::forward():
// shapes (3,1,3,2) + (1,3,3,2) + (2) broadcast to output shape (3,3,3,2)
// (size-1 axes are stretched; the rank-1 input aligns with the last axis).
SECTION("Broadcasting") {
// First addend, shape (3,1,3,2): broadcast along the second axis.
std::shared_ptr<Tensor> input_0 = std::make_shared<Tensor>(Array4D<int,3,1,3,2> {
{ //
{ //
{{0, 1},{2, 3},{4, 5}} //
}, //
{ //
{{6, 7},{8, 9},{10, 11}} //
}, //
{ //
{{12, 13},{14, 15},{16, 17}} //
} //
} //
}); //
// Second addend, shape (1,3,3,2): broadcast along the first axis.
std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
{ //
{ //
{{20, 21},{22, 23},{24, 25}}, //
{{26, 27},{28, 29},{30, 31}}, //
{{32, 33},{34, 35},{36, 37}} //
} //
} //
}); //
// Third addend, shape (2): broadcast across all leading axes
// ({100,200} is added along the innermost dimension everywhere).
std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array1D<int,2> {{100,200}});
// Expected output, shape (3,3,3,2); e.g. element [0][0][0][0]
// = 0 + 20 + 100 = 120.
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> {
{ //
{ //
{{ 120, 222},{ 124, 226},{ 128, 230}}, //
{{ 126, 228},{ 130, 232},{ 134, 236}}, //
{{ 132, 234},{ 136, 238},{ 140, 242}} //
}, //
{ //
{{ 126, 228},{ 130, 232},{ 134, 236}}, //
{{ 132, 234},{ 136, 238},{ 140, 242}}, //
{{ 138, 240},{ 142, 244},{ 146, 248}} //
}, //
{ //
{{ 132, 234},{ 136, 238},{140, 242}}, //
{{ 138, 240},{ 142, 244},{146, 248}}, //
{{ 144, 246},{ 148, 250},{152, 254}} //
} //
} //
}); //
// Three-input Add node; output dims are inferred from the three inputs.
std::shared_ptr<Node> myAdd = Add(3);
auto op = std::static_pointer_cast<OperatorTensor>(myAdd -> getOperator());
op->associateInput(0, input_0);
op->associateInput(1, input_1);
op->associateInput(2, input_2);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
op->computeOutputDims();
myAdd->forward();
op->getOutput(0)->print();
expectedOutput->print();
REQUIRE(*op->getOutput(0) == *expectedOutput);
}
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment