Skip to content
Snippets Groups Projects
Commit 764504df authored by Jerome Hue's avatar Jerome Hue Committed by Maxence Naud
Browse files

Add a Tensor.repeat() method

parent 76ccc349
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!351feat: add rate spikegen for snns
#ifndef AIDGE_CORE_DATA_SPIKEGEN_H_
#define AIDGE_CORE_DATA_SPIKEGEN_H_
// Spikegen algorithm (rate coding for spiking neural networks):
//
// time_data  = data.repeat(time_steps)   // replicate input along a new leading time axis
// spike_data = rate_conv(time_data)      // convert each value to a Bernoulli spike train
// return spike_data
//
// NOTE(review): header currently only documents the intended algorithm; no
// declarations yet. Relies on Tensor::repeat() (see Tensor.hpp).
#endif // AIDGE_CORE_DATA_SPIKEGEN_H_
...@@ -13,9 +13,7 @@ ...@@ -13,9 +13,7 @@
#define AIDGE_CORE_DATA_TENSOR_H_ #define AIDGE_CORE_DATA_TENSOR_H_
#include <algorithm> #include <algorithm>
#include <cstddef> // std::size_t #include <cstddef> // std::size_t #include <cstring> #include <functional> // std::multiplies
#include <cstring>
#include <functional> // std::multiplies
#include <set> #include <set>
#include <memory> #include <memory>
#include <numeric> // std::accumulate #include <numeric> // std::accumulate
...@@ -989,6 +987,18 @@ public: ...@@ -989,6 +987,18 @@ public:
return ref(fallback, targetReqs.dataType(), device.first, device.second); return ref(fallback, targetReqs.dataType(), device.first, device.second);
} }
/**
 * @brief Repeat the tensor along a new first dimension.
 * For example, if the current tensor has dimensions (n, m),
 * calling repeat(10) returns a tensor of shape (10, n, m)
 * with 10 copies of the original data.
 *
 * The source data is deep-copied into the result; the source tensor is
 * left unmodified (a non-contiguous source is cloned internally first).
 *
 * @param times number of repetitions (must be positive)
 * @return Tensor new tensor containing the repeated data.
 */
Tensor repeat(int times) const;
private: private:
/** /**
* @brief Compute the number of elements in the Tensor. * @brief Compute the number of elements in the Tensor.
......
...@@ -802,6 +802,43 @@ const Tensor& Tensor::ref(std::shared_ptr<Tensor>& fallback, ...@@ -802,6 +802,43 @@ const Tensor& Tensor::ref(std::shared_ptr<Tensor>& fallback,
} }
} }
/**
 * @brief Repeat the tensor along a new leading dimension.
 *
 * Produces a tensor of shape (times, d0, d1, ...) containing `times`
 * contiguous copies of this tensor's data. The source tensor is not
 * modified; a non-contiguous source is cloned and compacted first.
 *
 * @param times number of repetitions (must be strictly positive).
 * @return Tensor new tensor holding the repeated data. If this tensor has
 *         no implementation (no backend set), the result has the repeated
 *         dimensions but no backing data.
 */
Tensor Tensor::repeat(int times) const {
    AIDGE_ASSERT(times > 0, "repeat count must be positive");

    // Work on a contiguous view of the data so a single block copy per
    // repetition is valid.
    Tensor src = *this;
    if (not src.isContiguous()) {
        src = src.clone();
        src.makeContiguous();
    }

    // New dimensions: {times} followed by the current dims.
    std::vector<DimSize_t> newDims;
    newDims.reserve(dims().size() + 1);
    newDims.push_back(static_cast<DimSize_t>(times));
    newDims.insert(newDims.end(), dims().cbegin(), dims().cend());

    // Output tensor mirrors this tensor's type/format (and backend, if any).
    Tensor out(newDims);
    out.setDataType(dataType(), false);
    out.setDataFormat(dataFormat());
    if (hasImpl()) {
        out.setBackend(getImpl()->backend(), device());

        // Each "block" is one full copy of the source data. The copy loop
        // must stay inside the hasImpl() guard: without a backend,
        // out.getImpl() is null and dereferencing it would crash.
        const std::size_t block = src.size();
        for (std::size_t i = 0; i < static_cast<std::size_t>(times); ++i) {
            // copy(source pointer, number of elements, destination offset)
            out.getImpl()->copy(src.getImpl()->rawPtr(src.getImplOffset()),
                                block,
                                i * block);
        }
    }
    return out;
}
std::vector<std::size_t> std::vector<std::size_t>
Tensor::toCoord(const std::vector<DimSize_t>& dimensions, std::size_t index) { Tensor::toCoord(const std::vector<DimSize_t>& dimensions, std::size_t index) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment