diff --git a/include/aidge/operator/TopK.hpp b/include/aidge/operator/TopK.hpp index 1b9a2785155879b5222795a82b2f27f23450d3b3..e1aa193bb0b7720fce0d1161d3a352f2e8109324 100644 --- a/include/aidge/operator/TopK.hpp +++ b/include/aidge/operator/TopK.hpp @@ -94,6 +94,7 @@ public: bool forwardDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + void setDataType(const DataType& dataType) const override final; std::set<std::string> getAvailableBackends() const override; inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } diff --git a/src/operator/TopK.cpp b/src/operator/TopK.cpp index 6f942a639cb6890147bd1a03af5e47eb7686646c..52bb3755431b32cc3d30a85507e0e3fa22e0250c 100644 --- a/src/operator/TopK.cpp +++ b/src/operator/TopK.cpp @@ -78,6 +78,8 @@ bool Aidge::TopK_Op::forwardDims(bool allowDataDependency) { const auto kAxis = (axis() >= 0) ? axis() : axis() + static_cast<std::int8_t>(outDims.size()); outDims[kAxis] = k(); mOutputs[0]->resize(outDims); + mOutputs[1]->resize(outDims); + mOutputs[1]->setDataType(DataType::Int64); return true; } @@ -87,6 +89,12 @@ bool Aidge::TopK_Op::forwardDims(bool allowDataDependency) { void Aidge::TopK_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { SET_IMPL_MACRO(TopK_Op, *this, name); mOutputs[0]->setBackend(name, device); + mOutputs[1]->setBackend(name, device); +} + +void Aidge::TopK_Op::setDataType(const DataType& dataType) const { + mOutputs[0]->setDataType(dataType); + // mOutputs[1] data type is fixed (Int64) } std::set<std::string> Aidge::TopK_Op::getAvailableBackends() const {