diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 393e640d60934059a9c216a9335a7018388fe9da..39d2765c316445f28a1daf1687f3daefaad5a802 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -29,7 +29,7 @@ namespace Aidge { * @brief Description of a Fully Connected (FC) operation on an input Tensor. * * The Fully Connected (FC) operation applies a linear transformation to the input Tensor - * by multiplying it with a weight matrix and optionally adding a bias vector: + * by multiplying it with a weight matrix and optionally adding a bias vector: * - If `bias` is included: * f(x) = x × weights^T + bias * - If `bias` is omitted: @@ -74,7 +74,7 @@ public: * * Copies the attributes and output tensor(s) of the operator, but does not * copy input tensors. The new operator instance has no associated inputs. - * + * * @param op The `FC_Op` instance to copy. */ FC_Op(const FC_Op& op) @@ -114,6 +114,12 @@ public: */ bool forwardDims(bool allowDataDependency = false) override final; + /** + * @brief Forward the data type. + * @return True if successful, false otherwise. + */ + bool forwardDType() override final; + /** * @brief Sets the backend for the operator. * diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp index dd3ed7aba65cf1875d691d9bc2c8c94bb03856c7..07208b5221326eaf1c0cfd8829c97dc4543c659b 100644 --- a/src/operator/FC.cpp +++ b/src/operator/FC.cpp @@ -40,6 +40,19 @@ void Aidge::FC_Op::associateInput(const Aidge::IOIndex_t inputIdx, const std::sh mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()}); } +bool Aidge::FC_Op::forwardDType(){ + // Current naive forwarDType based on bias. + // Bias is optional so this will not always work + // But is good enough for now. + // Feel free to upgrade the function! + if (getInput(2)) { + mOutputs[0]->setDataType(getInput(2)->dataType()); + return true; + } + Log::notice("FC_Op: No bias associated, failed to forward data type."); + return false; +} + bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { if (inputsAssociated()) { // first check weight since it defines inChannels and outChannels