Commit 4cd0fd80 authored by Maxence Naud

Upd: Tensor operators and fmt display

- add compound assignment operators (+=, -=, *=, /=) between a Tensor and either another Tensor or an arithmetic value
- add operators (+, -, *, /) taking an arithmetic value
- add friend symmetric operators (+, -, *) so the arithmetic value can appear on the left-hand side
- add an fmt::formatter specialization so a Tensor can be printed with fmt (see the usage sketch below)
parent 3f12c96d
Pipeline #60970 passed
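Below is a minimal usage sketch for the operators introduced by this commit. It assumes an Array1D-based Tensor constructor and a "cpu" backend (provided by aidge_backend_cpu) are available; neither assumption comes from this diff.

```cpp
#include <fmt/core.h>

#include "aidge/data/Tensor.hpp"

int main() {
    // Assumption: Array1D construction and setBackend("cpu") are provided by
    // aidge_core / aidge_backend_cpu; they are not part of this commit.
    Aidge::Tensor t(Aidge::Array1D<float, 3>{{1.0f, 2.0f, 3.0f}});
    t.setBackend("cpu");

    Aidge::Tensor a = t + 2.0f;   // member overload: Tensor (op) arithmetic value
    Aidge::Tensor b = 2.0f * t;   // friend overload: arithmetic value (op) Tensor
    t += 1.0f;                    // compound assignment with an arithmetic value
    t /= 2.0f;

    fmt::print("{}\n", t);        // relies on the new fmt::formatter<Aidge::Tensor>
    return 0;
}
```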
@@ -23,6 +23,8 @@
#include <type_traits> // std::is_arithmetic
#include <vector>
#include <fmt/core.h>
#include "aidge/backend/TensorImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/utils/ArrayHelpers.hpp"
@@ -272,6 +274,17 @@ class Tensor : public Data,
* @return Tensor
*/
Tensor operator+(const Tensor& other) const;
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
Tensor operator+(T val) const { return *this + Tensor(val); }
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
friend Tensor operator+(T val, const Tensor& other) { return other + val; }
Tensor& operator+=(const Tensor& other);
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
Tensor& operator+=(T val) { return *this += Tensor(val); }
/**
* @brief Element-wise subtraction operation for two ``Tensor``s.
@@ -283,6 +296,17 @@
* @return Tensor
*/
Tensor operator-(const Tensor& other) const;
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
inline Tensor operator-(T val) const { return *this - Tensor(val); }
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
friend inline Tensor operator-(T val, const Tensor& other) { return other - val; }
Tensor& operator-=(const Tensor& other);
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
inline Tensor& operator-=(T val) { return *this -= Tensor(val); }
/**
* @brief Element-wise multiplication operation for two ``Tensor``s.
@@ -294,6 +318,17 @@
* @return Tensor
*/
Tensor operator*(const Tensor& other) const;
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
inline Tensor operator*(T val) const { return *this * Tensor(val); }
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
friend inline Tensor operator*(T val, const Tensor& other) { return other * val; }
Tensor& operator*=(const Tensor& other);
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
inline Tensor& operator*=(T val) { return *this *= Tensor(val); }
/**
* @brief Element-wise division operation for two ``Tensor``s.
@@ -305,6 +340,14 @@
* @return Tensor
*/
Tensor operator/(const Tensor& other) const;
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
inline Tensor operator/(T val) const { return *this / Tensor(val); }
Tensor& operator/=(const Tensor& other);
template<typename T,
typename VT = std::enable_if_t<std::is_arithmetic<T>::value, std::decay_t<T>>>
inline Tensor& operator/=(T val) { return *this /= Tensor(val); }
/**
* @brief Element-wise sqrt operation for Tensor.
@@ -927,4 +970,17 @@ private:
};
} // namespace Aidge
template<>
struct fmt::formatter<Aidge::Tensor> {
template<typename ParseContext>
inline constexpr auto parse(ParseContext& ctx) {
return ctx.begin();
}
template<typename FormatContext>
inline auto format(Aidge::Tensor const& t, FormatContext& ctx) const {
return fmt::format_to(ctx.out(), "{}", t.toString());
}
};
#endif /* AIDGE_CORE_DATA_TENSOR_H_ */
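A short sketch of what the formatter specialization enables: a Tensor can be passed straight to fmt::format or fmt::print, and the plain "{}" placeholder expands to Tensor::toString(). Since parse() does not consume a format spec, only the empty spec is supported.

```cpp
#include <string>

#include <fmt/core.h>

#include "aidge/data/Tensor.hpp"

// Sketch only: 't' is assumed to be an already-filled Tensor with a backend.
void dump(const Aidge::Tensor& t) {
    const std::string s = fmt::format("tensor = {}", t); // same text as t.toString()
    fmt::print("{}\n", s);
}
```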
@@ -44,7 +44,24 @@ Tensor Tensor::operator+(const Tensor& other) const {
add_.setBackend(mImpl->backend());
add_.forward();
// using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
-    return add_.getOutput(0)->clone();
+    return *add_.getOutput(0);
}
Tensor& Tensor::operator+=(const Tensor& other) {
AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
auto add_ = Add_Op();
const auto thisPtr = std::make_shared<Tensor>(*this);
add_.associateInput(0, thisPtr);
add_.associateInput(1, std::make_shared<Tensor>(other));
add_.setOutput(0, thisPtr);
add_.setDataType(dataType());
add_.setDataFormat(dataFormat());
add_.setBackend(mImpl->backend());
add_.forward();
return *this;
}
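A hedged usage sketch for the compound assignment above: both operands must already have an implementation and share the same backend, data type and data format, otherwise the AIDGE_ASSERT checks fire. The Array1D initialization and setBackend("cpu") calls are assumptions, not part of this diff.

```cpp
Aidge::Tensor a(Aidge::Array1D<int, 2>{{1, 2}});
Aidge::Tensor b(Aidge::Array1D<int, 2>{{3, 4}});
a.setBackend("cpu");
b.setBackend("cpu");

a += b;   // element-wise Add_Op, intended to leave {4, 6} in 'a'
a += 10;  // scalar form from the header forwards to: a += Aidge::Tensor(10)
```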
@@ -61,7 +78,25 @@ Tensor Tensor::operator-(const Tensor& other) const {
sub_.setBackend(mImpl->backend());
sub_.forward();
// using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
-    return sub_.getOutput(0)->clone();
+    return *sub_.getOutput(0);
}
Tensor& Tensor::operator-=(const Tensor& other) {
AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
auto sub_ = Sub_Op();
const auto thisPtr = std::make_shared<Tensor>(*this);
sub_.associateInput(0, thisPtr);
sub_.associateInput(1, std::make_shared<Tensor>(other));
sub_.setOutput(0, thisPtr);
sub_.setDataType(dataType());
sub_.setDataFormat(dataFormat());
sub_.setBackend(mImpl->backend());
sub_.forward();
return *this;
}
@@ -81,6 +116,24 @@ Tensor Tensor::operator*(const Tensor& other) const {
return mul_.getOutput(0)->clone();
}
Tensor& Tensor::operator*=(const Tensor& other) {
AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
auto mul_ = Mul_Op();
const auto thisPtr = std::make_shared<Tensor>(*this);
mul_.associateInput(0, thisPtr);
mul_.associateInput(1, std::make_shared<Tensor>(other));
mul_.setOutput(0, thisPtr);
mul_.setDataType(dataType());
mul_.setDataFormat(dataFormat());
mul_.setBackend(mImpl->backend());
mul_.forward();
return *this;
}
Tensor Tensor::operator/(const Tensor& other) const {
AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
@@ -98,6 +151,24 @@ Tensor Tensor::operator/(const Tensor& other) const {
return div_.getOutput(0)->clone();
}
Tensor& Tensor::operator/=(const Tensor& other) {
AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
auto div_ = Div_Op();
const auto thisPtr = std::make_shared<Tensor>(*this);
div_.associateInput(0, thisPtr);
div_.associateInput(1, std::make_shared<Tensor>(other));
div_.setOutput(0, thisPtr);
div_.setDataType(dataType());
div_.setDataFormat(dataFormat());
div_.setBackend(mImpl->backend());
div_.forward();
return *this;
}
Tensor Tensor::sqrt() const {
AIDGE_ASSERT(hasImpl(), "Tensor has no implementation.");
auto sqrt_ = Sqrt_Op();