From 9ab835a7fab3ffe7eb98535d4d10efdde325f140 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Sun, 23 Mar 2025 19:11:17 +0000 Subject: [PATCH] Add forwardDType to FC operator. --- include/aidge/operator/FC.hpp | 10 ++++++++-- src/operator/FC.cpp | 13 +++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 393e640d6..39d2765c3 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -29,7 +29,7 @@ namespace Aidge { * @brief Description of a Fully Connected (FC) operation on an input Tensor. * * The Fully Connected (FC) operation applies a linear transformation to the input Tensor - * by multiplying it with a weight matrix and optionally adding a bias vector: + * by multiplying it with a weight matrix and optionally adding a bias vector: * - If `bias` is included: * f(x) = x × weights^T + bias * - If `bias` is omitted: @@ -74,7 +74,7 @@ public: * * Copies the attributes and output tensor(s) of the operator, but does not * copy input tensors. The new operator instance has no associated inputs. - * + * * @param op The `FC_Op` instance to copy. */ FC_Op(const FC_Op& op) @@ -114,6 +114,12 @@ public: */ bool forwardDims(bool allowDataDependency = false) override final; + /** + * @brief Forward the data type. + * @return True if successful, false otherwise. + */ + bool forwardDType() override final; + /** * @brief Sets the backend for the operator. * diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp index dd3ed7aba..07208b522 100644 --- a/src/operator/FC.cpp +++ b/src/operator/FC.cpp @@ -40,6 +40,19 @@ void Aidge::FC_Op::associateInput(const Aidge::IOIndex_t inputIdx, const std::sh mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()}); } +bool Aidge::FC_Op::forwardDType(){ + // Current naive forwarDType based on bias. + // Bias is optional so this will not always work + // But is good enough for now. + // Feel free to upgrade the function! + if (getInput(2)) { + mOutputs[0]->setDataType(getInput(2)->dataType()); + return true; + } + Log::notice("FC_Op: No bias associated, failed to forward data type."); + return false; +} + bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { if (inputsAssociated()) { // first check weight since it defines inChannels and outChannels -- GitLab