Skip to content
Snippets Groups Projects
Commit 1581a1e9 authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

add broadcasting functions

parent 2c27c7ac
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!30add broadcasting for Arithmetic operators
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_DATA_BROADCASTING_H_
#define AIDGE_CPU_DATA_BROADCASTING_H_
#include <vector>
namespace Aidge {
// Function to broadCast an input dims vector into the same size as an outputDims vector
std::vector<std::size_t> getBroadcastedDims(const std::vector<std::size_t>& outputDims, const std::vector<std::size_t>& dimsToBroadcast);
// Function to get multi-dimensional indices from a flattened index
std::vector<std::size_t> getMultiDimIndices(const std::vector<std::size_t>& dimensions, std::size_t idx);
// Function to get a flattened index from multi-dimensional indices
std::size_t getFlattenedIndex(const std::vector<std::size_t>& dimensions, const std::vector<std::size_t>& indices);
} // namespace Aidge
#endif // AIDGE_CPU_DATA_BROADCASTING_H_
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cpu/data/Broadcasting.hpp"
std::vector<std::size_t> Aidge::getBroadcastedDims(const std::vector<std::size_t>& outputDims, const std::vector<std::size_t>& dimsToBroadcast){
std::vector<std::size_t> broadcastedDims(outputDims.size(), 1);
for(std::size_t j=dimsToBroadcast.size()-1; j+1>0; --j)
{
std::size_t idx = outputDims.size() - (dimsToBroadcast.size()-j);
broadcastedDims[idx] = dimsToBroadcast[j];
}
return broadcastedDims;
}
std::vector<std::size_t> Aidge::getMultiDimIndices(const std::vector<size_t>& dimensions, std::size_t idx){
std::vector<std::size_t> indices(dimensions.size(), 0);
for (int i = dimensions.size() - 1; i >= 0; --i) {
indices[i] = idx % dimensions[i];
idx /= dimensions[i];
}
return indices;
}
std::size_t Aidge::getFlattenedIndex(const std::vector<std::size_t>& dimensions, const std::vector<std::size_t>& indices){
std::size_t flattenedIdx = 0;
std::size_t stride = 1;
for (int i = dimensions.size() - 1; i >= 0; --i) {
std::size_t idx = dimensions[i]>1 ? indices[i] : 0;
flattenedIdx += idx * stride;
stride *= dimensions[i];
}
return flattenedIdx;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment