Skip to content
Snippets Groups Projects
Commit bb712294 authored by Olivier Antoni's avatar Olivier Antoni
Browse files

Add Tensor::sqrt()

parent 013d6349
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!142Add Tensor::sqrt()
Pipeline #48470 passed
......@@ -28,6 +28,7 @@
#include "aidge/operator/Div.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/operator/Sqrt.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ArrayHelpers.hpp"
......@@ -341,6 +342,21 @@ class Tensor : public Data,
// using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
return div_.getOutput(0)->clone();
}
/**
* @brief Element-wise sqrt operation for Tensor.
* @return Tensor
*/
Tensor sqrt() const {
AIDGE_ASSERT(hasImpl(), "Tensor has no implementation.");
auto sqrt_ = Sqrt_Op();
sqrt_.associateInput(0, std::make_shared<Tensor>(*this));
sqrt_.setDataType(dataType());
sqrt_.setDataFormat(dataFormat());
sqrt_.setBackend(mImpl->backend());
sqrt_.forward();
return sqrt_.getOutput(0)->clone();
}
~Tensor() noexcept;
......
......@@ -77,6 +77,7 @@ void init_Tensor(py::module& m){
.def(py::self - py::self)
.def(py::self * py::self)
.def(py::self / py::self)
.def("sqrt", &Tensor::sqrt)
.def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true)
.def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
.def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment