Skip to content
Snippets Groups Projects
Commit 1b16db67 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] Scalar constructor for Tensor and int8 support

parent 5cf910bf
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!65[Add] broadcasting for Arithmetic Operators
This commit is part of merge request !65. Comments created here will be created in the context of that merge request.
......@@ -185,6 +185,8 @@ static Registrar<Tensor> registrarTensorImpl_cpu_Float16(
{"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int32(
{"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int8(
{"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<std::int8_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int64(
{"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<long>::create);
} // namespace
......
......@@ -17,6 +17,7 @@
#include <memory>
#include <numeric> // std::accumulate
#include <string>
#include <type_traits> // std::is_arithmetic
#include <vector>
#include "aidge/backend/TensorImpl.hpp"
......@@ -63,7 +64,7 @@ class Tensor : public Data,
* @brief Construct a new Tensor object from another one (shallow copy).
* Data memory is not copied, but shared between the new Tensor and the
* initial one.
*
*
* @param otherTensor
*/
Tensor(const Tensor&) = default;
......@@ -85,6 +86,17 @@ class Tensor : public Data,
return newTensor;
}
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
Tensor(T val)
: Data(Type),
mDataType(NativeType<VT>::type),
mDims({}), mStrides({1}),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<VT>::type})(0, 1)),
mSize(1) {
*static_cast<VT*>(mImpl->rawPtr()) = static_cast<VT>(val);
}
/**
* @brief Construct a new Tensor object from the 1-dimension Array helper.
* @tparam T datatype
......@@ -306,7 +318,7 @@ class Tensor : public Data,
/**
* @brief Set the Impl object
*
*
* @param impl New impl shared pointer
* @param implOffset Storage offset in this new impl for this Tensor
*/
......@@ -631,7 +643,7 @@ class Tensor : public Data,
* tensor is returned.
* It current tensor was contiguous, the returned tensor is garanteed to be
* contiguous as well.
*
*
* @param coordIdx Coordinates of the sub-tensor to extract
* @return Tensor Sub-tensor.
*/
......@@ -639,7 +651,7 @@ class Tensor : public Data,
/**
* Returns a sub-tensor at some coordinate and with some dimension.
*
*
* @param coordIdx First coordinates of the sub-tensor to extract
* @param dims Dimensions of the sub-tensor to extract
* @return Tensor Sub-tensor.
......@@ -704,7 +716,7 @@ class Tensor : public Data,
* The data type, backend and device stay the same.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right
* will occur if fallback has already been allocated with the right
* type/size/device.
* @return Reference to either itself or to fallback.
*/
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment