Skip to content
Snippets Groups Projects

Draft: Fix import tests

Open Houssem ROUIS requested to merge fix_import_tests into dev
3 files
+ 53
8
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -25,6 +25,11 @@
namespace Aidge {
enum class FCAttr {
Alpha, // The scalar multiplier for the product of input tensors A * B.
Beta, // The scalar multiplier for the bias.
};
/**
* @brief Description of a Fully Connected (FC) operation on an input Tensor.
*
@@ -54,6 +59,15 @@ class FC_Op : public OperatorTensor,
public Registrable<FC_Op,
std::string,
std::function<std::shared_ptr<OperatorImpl>(const FC_Op &)>> {
private:
using Attributes_ = StaticAttributes<FCAttr,
float,
float>;
template <FCAttr e>
using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
/**
* @brief Static type identifier for the FC operator.
@@ -65,8 +79,11 @@ public:
*
* Initializes the operator with a type identifier and input categories.
*/
FC_Op()
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1)
FC_Op(float alpha = 1.0f, float beta = 1.0f)
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<FCAttr::Alpha>(alpha),
attr<FCAttr::Beta>(beta)))
{}
/**
@@ -160,6 +177,24 @@ public:
return getInput(1)->template dims<2>()[0];
}
/**
* @brief Get the attributes of the operator.
* @return A shared pointer to the operator's attributes.
*/
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
/**
* @brief Get the alpha coefficient.
* @return The alpha coefficient.
*/
inline float& alpha() const { return mAttributes->template getAttr<FCAttr::Alpha>(); }
/**
* @brief Get the beta coefficient.
* @return The beta coefficient.
*/
inline float& beta() const { return mAttributes->template getAttr<FCAttr::Beta>(); }
/**
* @brief Retrieves the input tensor names for the FC operator.
* @return A vector of input tensor names: `{"data_input", "weight", "bias"}`.
@@ -180,16 +215,20 @@ public:
/**
* @brief Creates a Fully Connected operation node.
*
* Constructs an FC operator node with the specified input and output channels.
*
* @param[in] inChannels Number of input channels.
* @param[in] outChannels Number of output channels.
* @param[in] alpha Scalar multiplier for the product of input tensors A * B.
* @param[in] beta Scalar multiplier for the bias.
* @param[in] noBias Flag indicating whether to use a bias term (default is `false`).
* @param[in] name Name of the operator (optional).
* @return A shared pointer to the Node containing the FC operator.
*/
std::shared_ptr<Node> FC(const DimSize_t inChannels, const DimSize_t outChannels, bool noBias = false, const std::string& name = "");
std::shared_ptr<Node> FC(const DimSize_t inChannels, const DimSize_t outChannels, bool noBias = false, const std::string& name = "", float alpha = 1.0f, float beta = 1.0f);
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::FCAttr>::data[] = {"alpha", "beta"};
}
#endif /* AIDGE_CORE_OPERATOR_FC_H_ */
Loading