diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 5f6be6045167f6ff523876aaa309a536683810de..0b04789ed5c78dbcd424bf5cd135ce16e12cb50e 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -28,6 +28,7 @@ #include "aidge/operator/Div.hpp" #include "aidge/operator/Mul.hpp" #include "aidge/operator/Sub.hpp" +#include "aidge/operator/Sqrt.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" #include "aidge/utils/ArrayHelpers.hpp" @@ -341,6 +342,21 @@ class Tensor : public Data, // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>; return div_.getOutput(0)->clone(); } + + /** + * @brief Element-wise sqrt operation for Tensor. + * @return Tensor + */ + Tensor sqrt() const { + AIDGE_ASSERT(hasImpl(), "Tensor has no implementation."); + auto sqrt_ = Sqrt_Op(); + sqrt_.associateInput(0, std::make_shared<Tensor>(*this)); + sqrt_.setDataType(dataType()); + sqrt_.setDataFormat(dataFormat()); + sqrt_.setBackend(mImpl->backend()); + sqrt_.forward(); + return sqrt_.getOutput(0)->clone(); + } ~Tensor() noexcept; diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 005175ab613594c48959073c4674e6d69b60b29f..72057910cddebe5d93baf27b4b768350b9cff02e 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -77,6 +77,7 @@ void init_Tensor(py::module& m){ .def(py::self - py::self) .def(py::self * py::self) .def(py::self / py::self) + .def("sqrt", &Tensor::sqrt) .def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true) .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true) .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)