diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 9031d19f30e9f634269e71c0ced9d7782d801054..627a5a4784b4e6546cdfc96b65acbe2a39ee119c 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -23,6 +23,8 @@ #include <type_traits> // std::is_arithmetic #include <vector> +#include <fmt/core.h> + #include "aidge/backend/TensorImpl.hpp" #include "aidge/data/Data.hpp" #include "aidge/utils/ArrayHelpers.hpp" @@ -272,6 +274,17 @@ class Tensor : public Data, * @return Tensor */ Tensor operator+(const Tensor& other) const; + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + Tensor operator+(T val) const { return *this + Tensor(val); } + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + friend Tensor operator+(T val, const Tensor& other) { return other + val; } + + Tensor& operator+=(const Tensor& other); + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + Tensor& operator+=(T val) {return *this += Tensor(val); } /** * @brief Element-wise subtraction operation for two ``Tensor``s. @@ -283,6 +296,17 @@ class Tensor : public Data, * @return Tensor */ Tensor operator-(const Tensor& other) const; + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + inline Tensor operator-(T val) const { return *this - Tensor(val); } + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + friend inline Tensor operator-(T val, const Tensor& other) { return other - val; } + + Tensor& operator-=(const Tensor& other); + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + inline Tensor& operator-=(T val) {return *this -= Tensor(val); } /** * @brief Element-wise multiplication operation for two ``Tensor``s. @@ -294,6 +318,17 @@ class Tensor : public Data, * @return Tensor */ Tensor operator*(const Tensor& other) const; + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + inline Tensor operator*(T val) const { return *this * Tensor(val); } + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + friend inline Tensor operator*(T val, const Tensor& other) { return other * val; } + + Tensor& operator*=(const Tensor& other); + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + inline Tensor& operator*=(T val) {return *this *= Tensor(val); } /** * @brief Element-wise division operation for two ``Tensor``s. @@ -305,6 +340,14 @@ class Tensor : public Data, * @return Tensor */ Tensor operator/(const Tensor& other) const; + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + inline Tensor operator/(T val) const { return *this / Tensor(val); } + + Tensor& operator/=(const Tensor& other); + template<typename T, + typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>> + inline Tensor& operator/=(T val) {return *this /= Tensor(val); } /** * @brief Element-wise sqrt operation for Tensor. @@ -927,4 +970,17 @@ private: }; } // namespace Aidge +template<> +struct fmt::formatter<Aidge::Tensor> { + template<typename ParseContext> + inline constexpr auto parse(ParseContext& ctx) { + return ctx.begin(); + } + + template<typename FormatContext> + inline auto format(Aidge::Tensor const& t, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "{}", t.toString()); + } +}; + #endif /* AIDGE_CORE_DATA_TENSOR_H_ */ diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index d54806948246ec97384ffcb4a43cc3c3428ce352..c834167abe15fb8a7ce96053a87a958b7515fe17 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -44,7 +44,24 @@ Tensor Tensor::operator+(const Tensor& other) const { add_.setBackend(mImpl->backend()); add_.forward(); // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>; - return add_.getOutput(0)->clone(); + return *add_.getOutput(0); +} + +Tensor& Tensor::operator+=(const Tensor& other) { + AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation."); + AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend"); + AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type"); + AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format"); + auto add_ = Add_Op(); + const auto thisPtr = std::make_shared<Tensor>(*this); + add_.associateInput(0, thisPtr); + add_.associateInput(1, std::make_shared<Tensor>(other)); + add_.setOutput(0, thisPtr); + add_.setDataType(dataType()); + add_.setDataFormat(dataFormat()); + add_.setBackend(mImpl->backend()); + add_.forward(); + return *this; } @@ -61,7 +78,25 @@ Tensor Tensor::operator-(const Tensor& other) const { sub_.setBackend(mImpl->backend()); sub_.forward(); // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>; - return sub_.getOutput(0)->clone(); + return *sub_.getOutput(0); +} + +Tensor& Tensor::operator-=(const Tensor& other) { + AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation."); + AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend"); + AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type"); + AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format"); + auto sub_ = Sub_Op(); + const auto thisPtr = std::make_shared<Tensor>(*this); + sub_.associateInput(0, thisPtr); + sub_.associateInput(1, std::make_shared<Tensor>(other)); + sub_.setOutput(0, thisPtr); + sub_.setDataType(dataType()); + sub_.setDataFormat(dataFormat()); + sub_.setBackend(mImpl->backend()); + sub_.forward(); + // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>; + return *this; } @@ -81,6 +116,24 @@ Tensor Tensor::operator*(const Tensor& other) const { return mul_.getOutput(0)->clone(); } +Tensor& Tensor::operator*=(const Tensor& other) { + AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation."); + AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend"); + AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type"); + AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format"); + auto mul_ = Mul_Op(); + const auto thisPtr = std::make_shared<Tensor>(*this); + mul_.associateInput(0, thisPtr); + mul_.associateInput(1, std::make_shared<Tensor>(other)); + mul_.setOutput(0, thisPtr); + mul_.setDataType(dataType()); + mul_.setDataFormat(dataFormat()); + mul_.setBackend(mImpl->backend()); + mul_.forward(); + // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>; + return *this; +} + Tensor Tensor::operator/(const Tensor& other) const { AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation."); @@ -98,6 +151,24 @@ Tensor Tensor::operator/(const Tensor& other) const { return div_.getOutput(0)->clone(); } +Tensor& Tensor::operator/=(const Tensor& other) { + AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation."); + AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend"); + AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type"); + AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format"); + auto div_ = Div_Op(); + const auto thisPtr = std::make_shared<Tensor>(*this); + div_.associateInput(0, thisPtr); + div_.associateInput(1, std::make_shared<Tensor>(other)); + div_.setOutput(0, thisPtr); + div_.setDataType(dataType()); + div_.setDataFormat(dataFormat()); + div_.setBackend(mImpl->backend()); + div_.forward(); + // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>; + return *this; +} + Tensor Tensor::sqrt() const { AIDGE_ASSERT(hasImpl(), "Tensor has no implementation."); auto sqrt_ = Sqrt_Op();