Skip to content
Snippets Groups Projects
Commit 00400ceb authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add broadcasting functions

parent 19b436a6
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!30add broadcasting for Arithmetic operators
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_DATA_BROADCASTING_H_
#define AIDGE_CPU_DATA_BROADCASTING_H_
#include <vector>
namespace Aidge {
// Function to broadCast an input dims vector into the same size as an outputDims vector
std::vector<std::size_t> getBroadcastedDims(const std::vector<std::size_t>& outputDims, const std::vector<std::size_t>& dimsToBroadcast);
// Function to get multi-dimensional indices from a flattened index
std::vector<std::size_t> getMultiDimIndices(const std::vector<std::size_t>& dimensions, std::size_t idx);
// Function to get a flattened index from multi-dimensional indices
std::size_t getFlattenedIndex(const std::vector<std::size_t>& dimensions, const std::vector<std::size_t>& indices);
} // namespace Aidge
#endif // AIDGE_CPU_DATA_BROADCASTING_H_
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cpu/data/Broadcasting.hpp"
std::vector<std::size_t> Aidge::getBroadcastedDims(const std::vector<std::size_t>& outputDims, const std::vector<std::size_t>& dimsToBroadcast){
std::vector<std::size_t> broadcastedDims(outputDims.size(), 1);
for(std::size_t j=dimsToBroadcast.size()-1; j+1>0; --j)
{
std::size_t idx = outputDims.size() - (dimsToBroadcast.size()-j);
broadcastedDims[idx] = dimsToBroadcast[j];
}
return broadcastedDims;
}
std::vector<std::size_t> Aidge::getMultiDimIndices(const std::vector<size_t>& dimensions, std::size_t idx){
std::vector<std::size_t> indices(dimensions.size(), 0);
for (int i = dimensions.size() - 1; i >= 0; --i) {
indices[i] = idx % dimensions[i];
idx /= dimensions[i];
}
return indices;
}
std::size_t Aidge::getFlattenedIndex(const std::vector<std::size_t>& dimensions, const std::vector<std::size_t>& indices){
std::size_t flattenedIdx = 0;
std::size_t stride = 1;
for (int i = dimensions.size() - 1; i >= 0; --i) {
std::size_t idx = dimensions[i]>1 ? indices[i] : 0;
flattenedIdx += idx * stride;
stride *= dimensions[i];
}
return flattenedIdx;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment