Skip to content
Snippets Groups Projects
Commit 9ab835a7 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add forwardDType to FC operator.

parent 0437b28a
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!363Add first version of forwardDType.
...@@ -29,7 +29,7 @@ namespace Aidge { ...@@ -29,7 +29,7 @@ namespace Aidge {
* @brief Description of a Fully Connected (FC) operation on an input Tensor. * @brief Description of a Fully Connected (FC) operation on an input Tensor.
* *
* The Fully Connected (FC) operation applies a linear transformation to the input Tensor * The Fully Connected (FC) operation applies a linear transformation to the input Tensor
* by multiplying it with a weight matrix and optionally adding a bias vector: * by multiplying it with a weight matrix and optionally adding a bias vector:
* - If `bias` is included: * - If `bias` is included:
* f(x) = x × weights^T + bias * f(x) = x × weights^T + bias
* - If `bias` is omitted: * - If `bias` is omitted:
...@@ -74,7 +74,7 @@ public: ...@@ -74,7 +74,7 @@ public:
* *
* Copies the attributes and output tensor(s) of the operator, but does not * Copies the attributes and output tensor(s) of the operator, but does not
* copy input tensors. The new operator instance has no associated inputs. * copy input tensors. The new operator instance has no associated inputs.
* *
* @param op The `FC_Op` instance to copy. * @param op The `FC_Op` instance to copy.
*/ */
FC_Op(const FC_Op& op) FC_Op(const FC_Op& op)
...@@ -114,6 +114,12 @@ public: ...@@ -114,6 +114,12 @@ public:
*/ */
bool forwardDims(bool allowDataDependency = false) override final; bool forwardDims(bool allowDataDependency = false) override final;
/**
* @brief Forward the data type.
* @return True if successful, false otherwise.
*/
bool forwardDType() override final;
/** /**
* @brief Sets the backend for the operator. * @brief Sets the backend for the operator.
* *
......
...@@ -40,6 +40,19 @@ void Aidge::FC_Op::associateInput(const Aidge::IOIndex_t inputIdx, const std::sh ...@@ -40,6 +40,19 @@ void Aidge::FC_Op::associateInput(const Aidge::IOIndex_t inputIdx, const std::sh
mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()}); mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()});
} }
bool Aidge::FC_Op::forwardDType(){
// Current naive forwarDType based on bias.
// Bias is optional so this will not always work
// But is good enough for now.
// Feel free to upgrade the function!
if (getInput(2)) {
mOutputs[0]->setDataType(getInput(2)->dataType());
return true;
}
Log::notice("FC_Op: No bias associated, failed to forward data type.");
return false;
}
bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) { if (inputsAssociated()) {
// first check weight since it defines inChannels and outChannels // first check weight since it defines inChannels and outChannels
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment