From 0ba1a128dcc2f1bd5df2196d3c3d20bbe341f751 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 20 Feb 2024 11:47:50 +0100 Subject: [PATCH] Fixed binding issues --- python_binding/graph/pybind_GraphView.cpp | 5 ++++ python_binding/graph/pybind_Node.cpp | 1 + python_binding/operator/pybind_BatchNorm.cpp | 3 +++ .../operator/pybind_MetaOperatorDefs.cpp | 3 +++ python_binding/operator/pybind_Pop.cpp | 27 +++++++++++++++++++ python_binding/pybind_core.cpp | 2 ++ python_binding/scheduler/pybind_Scheduler.cpp | 2 +- src/scheduler/Scheduler.cpp | 2 +- unit_tests/scheduler/Test_Scheduler.cpp | 5 ++-- 9 files changed, 45 insertions(+), 5 deletions(-) create mode 100644 python_binding/operator/pybind_Pop.cpp diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index eb26538a5..9d7229994 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -45,6 +45,9 @@ void init_GraphView(py::module& m) { :rtype: list[Node] )mydelimiter") + .def("set_ordered_inputs", &GraphView::setOrderedInputs, py::arg("inputs")) + .def("set_ordered_outputs", &GraphView::setOrderedOutputs, py::arg("outputs")) + .def("add", (void (GraphView::*)(std::shared_ptr<Node>, bool)) & GraphView::add, py::arg("other_node"), py::arg("include_learnable_parameters") = true, R"mydelimiter( @@ -118,5 +121,7 @@ void init_GraphView(py::module& m) { // } // }) ; + + m.def("get_connected_graph_view", &getConnectedGraphView); } } // namespace Aidge diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index 83f5688fa..71f5b368b 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -23,6 +23,7 @@ namespace py = pybind11; namespace Aidge { void init_Node(py::module& m) { py::class_<Node, std::shared_ptr<Node>>(m, "Node") + .def(py::init<std::shared_ptr<Operator>, const std::string&>(), py::arg("op"), py::arg("name") = "") .def("name", &Node::name, R"mydelimiter( Name of the Node. diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index 411a2e1b6..1da480856 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -22,6 +22,9 @@ namespace Aidge { template <DimSize_t DIM> void declare_BatchNormOp(py::module& m) { py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, OperatorTensor, Attributes>(m, ("BatchNormOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) + .def(py::init<float, float>(), + py::arg("epsilon"), + py::arg("momentum")) .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index 11c3db681..98db4652a 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -139,6 +139,9 @@ void init_MetaOperatorDefs(py::module &m) { declare_LSTMOp(m); py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance()) + .def(py::init<const char *, const std::shared_ptr<GraphView>&>(), + py::arg("type"), + py::arg("graph")) .def("get_micro_graph", &MetaOperator_Op::getMicroGraph); m.def("meta_operator", &MetaOperator, diff --git a/python_binding/operator/pybind_Pop.cpp b/python_binding/operator/pybind_Pop.cpp new file mode 100644 index 000000000..91726fc1d --- /dev/null +++ b/python_binding/operator/pybind_Pop.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Pop.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Pop(py::module& m) { + py::class_<Pop_Op, std::shared_ptr<Pop_Op>, OperatorTensor, Attributes>(m, "PopOp", py::multiple_inheritance()) + .def("get_inputs_name", &Pop_Op::getInputsName) + .def("get_outputs_name", &Pop_Op::getOutputsName); + + m.def("Pop", &Pop, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index be0d357b7..736e7a1d6 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -39,6 +39,7 @@ void init_MetaOperatorDefs(py::module&); void init_Mul(py::module&); void init_Producer(py::module&); void init_Pad(py::module&); +void init_Pop(py::module&); void init_Pow(py::module&); void init_ReduceMean(py::module&); void init_ReLU(py::module&); @@ -95,6 +96,7 @@ void init_Aidge(py::module& m){ init_Mul(m); init_Pad(m); + init_Pop(m); init_Pow(m); init_ReduceMean(m); init_ReLU(m); diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index d963b81d5..b801898dc 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -23,7 +23,7 @@ void init_Scheduler(py::module& m){ .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) .def("resetScheduling", &SequentialScheduler::resetScheduling) .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) - .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling) + .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling, py::arg("step") = 0) ; } } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index d74d1980c..fcff5b8f4 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -322,7 +322,7 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa for (const auto& element : mScheduling) { std::fprintf(fp, "%s :%ld, %ld\n", - namePtrTable.find(element.node)->second.c_str(), + namePtrTable.at(element.node).c_str(), std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(), std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count()); } diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index cb75669f3..2932514ca 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -61,13 +61,12 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { scheduler.generateScheduling(true); const auto sch = scheduler.getStaticScheduling(); - std::map<std::shared_ptr<Node>, std::string> namePtrTable - = g1->getRankedNodesName("{0} ({1}#{3})"); + const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); std::vector<std::string> nodesName; std::transform(sch.begin(), sch.end(), std::back_inserter(nodesName), - [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }); + [&namePtrTable](auto val){ return namePtrTable.at(val).c_str(); }); fmt::print("schedule: {}\n", nodesName); REQUIRE(sch.size() == 10 + orderedInputs.size()); -- GitLab