From dbf80ab9989cdf298b53f94b05ab2c6ec5bf349e Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 28 Mar 2025 12:09:30 +0100 Subject: [PATCH] Fixed 2nd output handling --- include/aidge/operator/TopK.hpp | 1 + src/operator/TopK.cpp | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/include/aidge/operator/TopK.hpp b/include/aidge/operator/TopK.hpp index 1b9a27851..e1aa193bb 100644 --- a/include/aidge/operator/TopK.hpp +++ b/include/aidge/operator/TopK.hpp @@ -94,6 +94,7 @@ public: bool forwardDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + void setDataType(const DataType& dataType) const override final; std::set<std::string> getAvailableBackends() const override; inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } diff --git a/src/operator/TopK.cpp b/src/operator/TopK.cpp index 6f942a639..52bb37554 100644 --- a/src/operator/TopK.cpp +++ b/src/operator/TopK.cpp @@ -78,6 +78,8 @@ bool Aidge::TopK_Op::forwardDims(bool allowDataDependency) { const auto kAxis = (axis() >= 0) ? axis() : axis() + static_cast<std::int8_t>(outDims.size()); outDims[kAxis] = k(); mOutputs[0]->resize(outDims); + mOutputs[1]->resize(outDims); + mOutputs[1]->setDataType(DataType::Int64); return true; } @@ -87,6 +89,12 @@ bool Aidge::TopK_Op::forwardDims(bool allowDataDependency) { void Aidge::TopK_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { SET_IMPL_MACRO(TopK_Op, *this, name); mOutputs[0]->setBackend(name, device); + mOutputs[1]->setBackend(name, device); +} + +void Aidge::TopK_Op::setDataType(const DataType& dataType) const { + mOutputs[0]->setDataType(dataType); + // mOutputs[1] data type is fixed (Int64) } std::set<std::string> Aidge::TopK_Op::getAvailableBackends() const { -- GitLab