Skip to content
Snippets Groups Projects
Commit 4f3d4fbc authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Tensor] Add method to convert 1D idx <-> coordinate + python unittest.

parent 6dbe2ca8
No related branches found
No related tags found
1 merge request!9Fuse bn
This commit is part of merge request !9. Comments created here will be created in the context of that merge request.
......@@ -11,9 +11,8 @@ SPDX-License-Identifier: EPL-2.0
import unittest
import aidge_core
class test_parameters(unittest.TestCase):
"""Very basic test to make sure the python APi is not broken.
Can be remove in later stage of the developpement.
class test_recipies(unittest.TestCase):
"""
"""
def setUp(self):
pass
......
"""
Copyright (c) 2023 CEA-List
This program and the accompanying materials are made available under the
terms of the Eclipse Public License 2.0 which is available at
http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0
"""
import unittest
import aidge_core
from functools import reduce
import numpy as np
class test_tesnor(unittest.TestCase):
"""
"""
def setUp(self):
pass
def tearDown(self):
pass
def test_remove_flatten(self):
dims = [2,2,2]
size = reduce((lambda x, y: x*y), dims)
np_array = np.arange(size).reshape(dims)
t = aidge_core.Tensor(np_array)
for i in range(size):
coord = t.get_coord(i)
idx = t.get_idx(coord)
self.assertEqual(idx, i)
if __name__ == '__main__':
unittest.main()
......@@ -559,6 +559,40 @@ class Tensor : public Data,
return mGrad;
}
/**
* @brief From the the 1D index, return the coordinate of an element in the tensor.
*
* @param flatIdx 1D index of the value considering a flatten tensor.
* @return std::vector<DimSize_t>
*/
std::vector<DimSize_t> getCoord(DimSize_t flatIdx){
std::vector<DimSize_t> coordIdx = {};
DimSize_t idx = flatIdx;
for (DimSize_t d: mDims){
coordIdx.push_back(idx % d);
idx/=d;
}
return coordIdx;
}
/**
* @brief From the coordinate returns the 1D index of an element in the tensor.
*
* @param coordIdx Coordinate to an element in the tensor
* @return DimSize_t
*/
DimSize_t getIdx(std::vector<DimSize_t> coordIdx){
DimSize_t flatIdx = 0;
DimSize_t stride = 1;
assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions");
for(std::size_t i=0; i< mDims.size(); ++i){
assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor");
flatIdx += (coordIdx[i] * stride);
stride *= mDims[i];
}
return flatIdx;
}
private:
///\bug not protected against overflow
std::size_t computeSize() {
......
......@@ -26,10 +26,10 @@ namespace Aidge {
template<typename T>
void addCtor(py::class_<Tensor,
std::shared_ptr<Tensor>,
Data,
std::shared_ptr<Tensor>,
Data,
Registrable<Tensor,
std::tuple<std::string, DataType>,
std::tuple<std::string, DataType>,
std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){
mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) {
/* Request a buffer descriptor from Python */
......@@ -46,7 +46,7 @@ void addCtor(py::class_<Tensor,
}else{
printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n");
}
return newTensor;
}));
}
......@@ -54,16 +54,16 @@ void addCtor(py::class_<Tensor,
void init_Tensor(py::module& m){
py::class_<Registrable<Tensor,
std::tuple<std::string, DataType>,
std::tuple<std::string, DataType>,
std::unique_ptr<TensorImpl>(const Tensor&)>,
std::shared_ptr<Registrable<Tensor,
std::tuple<std::string, DataType>,
std::tuple<std::string, DataType>,
std::unique_ptr<TensorImpl>(const Tensor&)>>>(m,"TensorRegistrable");
py::class_<Tensor, std::shared_ptr<Tensor>,
Data,
py::class_<Tensor, std::shared_ptr<Tensor>,
Data,
Registrable<Tensor,
std::tuple<std::string, DataType>,
std::tuple<std::string, DataType>,
std::unique_ptr<TensorImpl>(const Tensor&)>> pyClassTensor
(m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
......@@ -74,6 +74,8 @@ void init_Tensor(py::module& m){
.def("size", &Tensor::size)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&)) &Tensor::resize)
.def("has_impl", &Tensor::hasImpl)
.def("get_coord", &Tensor::getCoord)
.def("get_idx", &Tensor::getIdx)
.def_static("get_available_backends", &Tensor::getAvailableBackends)
.def("__str__", [](Tensor& b) {
return b.toString();
......@@ -142,6 +144,6 @@ void init_Tensor(py::module& m){
// #if SIZE_MAX != 0xFFFFFFFF
addCtor<double>(pyClassTensor);
// #endif
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment