From bb712294469f5707787fd4211d27f6cb33e5aa7f Mon Sep 17 00:00:00 2001
From: Antoni Olivier <olivier.antoni@cea.fr>
Date: Tue, 18 Jun 2024 16:02:28 +0200
Subject: [PATCH] Add Tensor::sqrt()

---
 include/aidge/data/Tensor.hpp         | 16 ++++++++++++++++
 python_binding/data/pybind_Tensor.cpp |  1 +
 2 files changed, 17 insertions(+)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 5f6be6045..0b04789ed 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -28,6 +28,7 @@
 #include "aidge/operator/Div.hpp"
 #include "aidge/operator/Mul.hpp"
 #include "aidge/operator/Sub.hpp"
+#include "aidge/operator/Sqrt.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ArrayHelpers.hpp"
@@ -341,6 +342,21 @@ class Tensor : public Data,
         // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
         return div_.getOutput(0)->clone();
     }
+	
+	/**
+     * @brief Element-wise sqrt operation for Tensor.
+     * @return Tensor
+     */	
+	Tensor sqrt() const {
+        AIDGE_ASSERT(hasImpl(), "Tensor has no implementation.");
+        auto sqrt_ = Sqrt_Op();
+        sqrt_.associateInput(0, std::make_shared<Tensor>(*this));
+        sqrt_.setDataType(dataType());
+        sqrt_.setDataFormat(dataFormat());
+        sqrt_.setBackend(mImpl->backend());
+        sqrt_.forward();
+        return sqrt_.getOutput(0)->clone();		
+	}
 
     ~Tensor() noexcept;
 
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 005175ab6..72057910c 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -77,6 +77,7 @@ void init_Tensor(py::module& m){
     .def(py::self - py::self)
     .def(py::self * py::self)
     .def(py::self / py::self)
+	.def("sqrt", &Tensor::sqrt)
     .def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true)
     .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
     .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
-- 
GitLab