Skip to content
Snippets Groups Projects
Commit 1b16db67 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] Scalar constructor for Tensor and int8 support

parent 5cf910bf
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!65[Add] broadcasting for Arithmetic Operators
...@@ -185,6 +185,8 @@ static Registrar<Tensor> registrarTensorImpl_cpu_Float16(
    {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int32(
    {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int8(
    {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<std::int8_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int64(
    {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<long>::create);
} // namespace
...
...@@ -17,6 +17,7 @@
#include <memory>
#include <numeric>     // std::accumulate
#include <string>
#include <type_traits> // std::is_arithmetic
#include <vector>

#include "aidge/backend/TensorImpl.hpp"
...@@ -63,7 +64,7 @@ class Tensor : public Data,
     * @brief Construct a new Tensor object from another one (shallow copy).
     * Data memory is not copied, but shared between the new Tensor and the
     * initial one.
     *
     * @param otherTensor
     */
    Tensor(const Tensor&) = default;
...@@ -85,6 +86,17 @@ class Tensor : public Data,
        return newTensor;
    }
/**
 * @brief Construct a 0-dimension (scalar) Tensor holding a single arithmetic value.
 *
 * The second template parameter is SFINAE: this overload only participates in
 * overload resolution when T is an arithmetic type (std::is_arithmetic), so it
 * cannot hijack the copy/move or Array constructors. VT is the decayed element
 * type actually stored.
 *
 * @tparam T  type of the given value (must be arithmetic)
 * @tparam VT std::decay_t<T>; the element type used for storage and DataType lookup
 * @param val scalar value copied into the tensor's single element
 */
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
Tensor(T val)
: Data(Type),
// map the C++ element type to the corresponding Aidge DataType enum
mDataType(NativeType<VT>::type),
// rank-0 tensor: empty dims, single unit stride, exactly one element
mDims({}), mStrides({1}),
// NOTE(review): requires a "cpu" TensorImpl registered for VT (see the
// registrars in TensorImpl_cpu) — presumably (0, 1) means device 0 and a
// length of 1 element; confirm against the TensorImpl create() signature.
mImpl(Registrar<Tensor>::create({"cpu", NativeType<VT>::type})(0, 1)),
mSize(1) {
// write the scalar into the freshly allocated single-element storage
*static_cast<VT*>(mImpl->rawPtr()) = static_cast<VT>(val);
}
    /**
     * @brief Construct a new Tensor object from the 1-dimension Array helper.
     * @tparam T datatype
...@@ -306,7 +318,7 @@ class Tensor : public Data,
    /**
     * @brief Set the Impl object
     *
     * @param impl New impl shared pointer
     * @param implOffset Storage offset in this new impl for this Tensor
     */
...@@ -631,7 +643,7 @@ class Tensor : public Data,
     * tensor is returned.
     * If the current tensor was contiguous, the returned tensor is guaranteed to be
     * contiguous as well.
     *
     * @param coordIdx Coordinates of the sub-tensor to extract
     * @return Tensor Sub-tensor.
     */
...@@ -639,7 +651,7 @@ class Tensor : public Data,
    /**
     * Returns a sub-tensor at some coordinate and with some dimension.
     *
     * @param coordIdx First coordinates of the sub-tensor to extract
     * @param dims Dimensions of the sub-tensor to extract
     * @return Tensor Sub-tensor.
* The data type, backend and device stay the same. * The data type, backend and device stay the same.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation * The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right * will occur if fallback has already been allocated with the right
* type/size/device. * type/size/device.
* @return Reference to either itself or to fallback. * @return Reference to either itself or to fallback.
*/ */
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment