Skip to content
Snippets Groups Projects
Commit 32b1a744 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added binding for TopK

parent 6a0808ed
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!377Add TopK operator
Pipeline #69116 failed
...@@ -108,6 +108,14 @@ public: ...@@ -108,6 +108,14 @@ public:
static const std::vector<std::string> getOutputsName(){ static const std::vector<std::string> getOutputsName(){
return {"values", "indices"}; return {"values", "indices"};
} }
/**
* @brief Retrieves the names of the attributes for the operator.
* @return A vector containing the attributes name.
*/
static constexpr const char* const* attributesName(){
return EnumStrings<Aidge::TopKAttr>::data;
}
}; };
std::shared_ptr<Node> TopK(const std::string& name = ""); std::shared_ptr<Node> TopK(const std::string& name = "");
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/TopK.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_TopK(py::module& m) {
py::class_<TopK_Op, std::shared_ptr<TopK_Op>, OperatorTensor>(m, "TopKOp", py::multiple_inheritance())
.def(py::init<int64_t, bool, bool, IOIndex_t>(), py::arg("axis") = -1, py::arg("largest") = true, py::arg("sorted") = true, py::arg("k") = 0)
.def_static("get_inputs_name", &TopK_Op::getInputsName)
.def_static("get_outputs_name", &TopK_Op::getOutputsName)
.def_static("attributes_name", []() {
std::vector<std::string> result;
auto attributes = TopK_Op::attributesName();
for (size_t i = 0; i < size(EnumStrings<TopKAttr>::data); ++i) {
result.emplace_back(attributes[i]);
}
return result;
})
.def_readonly_static("Type", &TopK_Op::Type);
m.def("TopK", &TopK, py::arg("name") = "");
}
} // namespace Aidge
...@@ -98,6 +98,7 @@ void init_Squeeze(py::module&); ...@@ -98,6 +98,7 @@ void init_Squeeze(py::module&);
void init_Stack(py::module&); void init_Stack(py::module&);
void init_Sub(py::module&); void init_Sub(py::module&);
void init_Tanh(py::module&); void init_Tanh(py::module&);
void init_TopK(py::module&);
void init_Transpose(py::module&); void init_Transpose(py::module&);
void init_Unfold(py::module&); void init_Unfold(py::module&);
void init_Unsqueeze(py::module&); void init_Unsqueeze(py::module&);
...@@ -207,6 +208,7 @@ void init_Aidge(py::module& m) { ...@@ -207,6 +208,7 @@ void init_Aidge(py::module& m) {
init_Stack(m); init_Stack(m);
init_Sub(m); init_Sub(m);
init_Tanh(m); init_Tanh(m);
init_TopK(m);
init_Transpose(m); init_Transpose(m);
init_Unfold(m); init_Unfold(m);
init_Unsqueeze(m); init_Unsqueeze(m);
......
...@@ -233,6 +233,8 @@ std::string Aidge::Node::outputName(Aidge::IOIndex_t outID) const { ...@@ -233,6 +233,8 @@ std::string Aidge::Node::outputName(Aidge::IOIndex_t outID) const {
} }
std::string Aidge::Node::outputName(Aidge::IOIndex_t outID, std::string newName) { std::string Aidge::Node::outputName(Aidge::IOIndex_t outID, std::string newName) {
AIDGE_ASSERT(outID < mIdInChildren.size(), "Output index out of bound.");
this->mOutputNames[outID] = newName; this->mOutputNames[outID] = newName;
for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) { for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) {
if (std::shared_ptr<Node> child = mChildren[outID][i].lock()) { if (std::shared_ptr<Node> child = mChildren[outID][i].lock()) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment