diff --git a/include/aidge/operator/TopK.hpp b/include/aidge/operator/TopK.hpp index 67d40ace4a8fa54cf07d8def8942294af4b12f21..1b9a2785155879b5222795a82b2f27f23450d3b3 100644 --- a/include/aidge/operator/TopK.hpp +++ b/include/aidge/operator/TopK.hpp @@ -108,6 +108,14 @@ public: static const std::vector<std::string> getOutputsName(){ return {"values", "indices"}; } + + /** + * @brief Retrieves the names of the attributes for the operator. + * @return A vector containing the attributes name. + */ + static constexpr const char* const* attributesName(){ + return EnumStrings<Aidge::TopKAttr>::data; + } }; std::shared_ptr<Node> TopK(const std::string& name = ""); diff --git a/python_binding/operator/pybind_TopK.cpp b/python_binding/operator/pybind_TopK.cpp new file mode 100644 index 0000000000000000000000000000000000000000..314a3283baf251171904c497aa93cc9da282d0d0 --- /dev/null +++ b/python_binding/operator/pybind_TopK.cpp @@ -0,0 +1,39 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/TopK.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_TopK(py::module& m) { + py::class_<TopK_Op, std::shared_ptr<TopK_Op>, OperatorTensor>(m, "TopKOp", py::multiple_inheritance()) + .def(py::init<int64_t, bool, bool, IOIndex_t>(), py::arg("axis") = -1, py::arg("largest") = true, py::arg("sorted") = true, py::arg("k") = 0) + .def_static("get_inputs_name", &TopK_Op::getInputsName) + .def_static("get_outputs_name", &TopK_Op::getOutputsName) + .def_static("attributes_name", []() { + std::vector<std::string> result; + auto attributes = TopK_Op::attributesName(); + for (size_t i = 0; i < size(EnumStrings<TopKAttr>::data); ++i) { + result.emplace_back(attributes[i]); + } + return result; + }) + .def_readonly_static("Type", &TopK_Op::Type); + + m.def("TopK", &TopK, py::arg("name") = ""); +} + +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index de045a4c928299200089adf060d0195ee7b59c60..61a0a271c6dd23f30065f31a711d0383395f5d9d 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -98,6 +98,7 @@ void init_Squeeze(py::module&); void init_Stack(py::module&); void init_Sub(py::module&); void init_Tanh(py::module&); +void init_TopK(py::module&); void init_Transpose(py::module&); void init_Unfold(py::module&); void init_Unsqueeze(py::module&); @@ -207,6 +208,7 @@ void init_Aidge(py::module& m) { init_Stack(m); init_Sub(m); init_Tanh(m); + init_TopK(m); init_Transpose(m); init_Unfold(m); init_Unsqueeze(m); diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 8a3975ece23e7951f51be455a13a0460813f1b73..74e0cab37489c275512f5ba53290bdb5eac065e0 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -233,6 +233,8 @@ std::string Aidge::Node::outputName(Aidge::IOIndex_t outID) const { } std::string Aidge::Node::outputName(Aidge::IOIndex_t outID, std::string newName) { + AIDGE_ASSERT(outID < mIdInChildren.size(), "Output index out of bound."); + this->mOutputNames[outID] = newName; for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) { if (std::shared_ptr<Node> child = mChildren[outID][i].lock()) {