diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 078ff0fde84f70a632967e0bc467af2db6575f49..373c7da29596b163eae21c1235d6db578b755a5f 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -238,19 +238,19 @@ public: * @param[in] alpha Scalar multiplier for the product of input tensors A * B. * @param[in] beta Scalar multiplier for the bias. * @param[in] noBias Flag indicating whether to use a bias term (default is `false`). - * @param[in] name Name of the operator (optional). * @param[in] transA Flag indicating whether input#0 needs to be transposed (default is `false`). * @param[in] transB Flag indicating whether input#1 needs to be transposed (default is `false`). + * @param[in] name Name of the operator (optional). * @return A shared pointer to the Node containing the FC operator. */ std::shared_ptr<Node> FC(const DimSize_t inChannels, const DimSize_t outChannels, - bool noBias = false, - const std::string& name = "", float alpha = 1.0f, float beta = 1.0f, + bool noBias = false, bool transA = false, - bool transB = false); + bool transB = false, + const std::string& name = ""); } // namespace Aidge diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp index 40433eb51b3a76b682f1503b271f334c130f77bf..dc3f738fbf5613edf4212c15e89df4624084cb39 100644 --- a/python_binding/operator/pybind_FC.cpp +++ b/python_binding/operator/pybind_FC.cpp @@ -47,7 +47,15 @@ void declare_FC(py::module &m) { declare_registrable<FC_Op>(m, "FCOp"); - m.def("FC", &FC, py::arg("in_channels"), py::arg("out_channels"), py::arg("no_bias") = false, py::arg("name") = "", py::arg("alpha")=1.0f, py::arg("beta")=1.0f, py::arg("transA") = false, py::arg("transB") = false, + m.def("FC", &FC, + py::arg("in_channels"), + py::arg("out_channels"), + py::arg("alpha")=1.0f, + py::arg("beta")=1.0f, + py::arg("no_bias") = false, + py::arg("transA") = false, + py::arg("transB") = false, + py::arg("name") = "", R"mydelimiter( Initialize a node containing a Fully Connected (FC) operator. diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp index 54f28507b906eb435dbdb9d2cec92b71c813b760..13da22423ec1c5be748461f4518db87dc11f4fa6 100644 --- a/src/operator/FC.cpp +++ b/src/operator/FC.cpp @@ -103,12 +103,12 @@ std::set<std::string> Aidge::FC_Op::getAvailableBackends() const { std::shared_ptr<Aidge::Node> Aidge::FC(const Aidge::DimSize_t inChannels, const Aidge::DimSize_t outChannels, - bool noBias, - const std::string& name, float alpha, float beta, + bool noBias, bool transA, - bool transB) { + bool transB, + const std::string& name) { // FIXME: properly handle default w&b initialization in every cases auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta, transA, transB), name); addProducer(fc, 1, {outChannels, inChannels}, "w"); diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index 582c73565a4ef7bfc96e493e1e6029b1683676ab..bd684f9ea1c951396cb186810e6adc388622e0a9 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -357,9 +357,9 @@ TEST_CASE("[core/graph] Matching") { ReLU("relu2"), Conv(4, 4, {5, 5}, "conv3"), BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"), - FC(4, 4, false, "fc1"), - FC(4, 4, false, "fc2"), - FC(4, 4, false, "fc3"), + FC(4, 4, 1.0, 1.0, false, false, false, "fc1"), + FC(4, 4, 1.0, 1.0, false, false, false, "fc2"), + FC(4, 4, 1.0, 1.0, false, false, false, "fc3"), ReLU("relu3"), Conv(1, 4, {5, 5}, "conv4") }); diff --git a/unit_tests/recipes/Test_ToGenericOp.cpp b/unit_tests/recipes/Test_ToGenericOp.cpp index cb75fdb1072dee476c88c1f6d502a792b2e6abd9..02d784385ee18ceb495fd1e8a2f25ed161b4fee0 100644 --- a/unit_tests/recipes/Test_ToGenericOp.cpp +++ b/unit_tests/recipes/Test_ToGenericOp.cpp @@ -32,9 +32,9 @@ TEST_CASE("[graph/convert] toGenericOp", "[toGenericOp][recipies]") { ReLU(), Conv(4, 3, {1, 1}, "conv3"), ReLU(), - FC(2028, 256, false, "fc1"), + FC(2028, 256, 1.0, 1.0, false, false, false, "fc1"), ReLU(), - FC(256, 10, false, "fc2")}); + FC(256, 10, 1.0, 1.0, false, false, false, "fc2")}); // NCHW - MNIST DATA like g->forwardDims({{5, 1, 28, 28}}); diff --git a/unit_tests/recipes/Test_removeFlatten.cpp b/unit_tests/recipes/Test_removeFlatten.cpp index 1b5e2783813da890b1e79744582f54bb5c932772..655f7c7f5992902f7d73dd310f4a323d0e1eadce 100644 --- a/unit_tests/recipes/Test_removeFlatten.cpp +++ b/unit_tests/recipes/Test_removeFlatten.cpp @@ -27,8 +27,8 @@ namespace Aidge { TEST_CASE("[cpu/recipes] RemoveFlatten", "[RemoveFlatten][recipes]") { std::shared_ptr<Node> flatten = GenericOperator("Flatten", 1, 0, 1, "myFlatten"); - std::shared_ptr<Node> fc0 = FC(10, 10, false, "FC_1"); - std::shared_ptr<Node> fc1 = FC(10, 10, false, "FC_2"); + std::shared_ptr<Node> fc0 = FC(10, 10, 1.0, 1.0, false, false, false, "FC_1"); + std::shared_ptr<Node> fc1 = FC(10, 10, 1.0, 1.0, false, false, false, "FC_2"); std::shared_ptr<Node> prod = Producer(std::array<DimSize_t, 10>(), "myProd"); SECTION("flatten last layer : nothing removed because pattern searched is "