diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index eb26538a5db1eb40fdcb8a2e409067483d4a7d68..9d7229994c8a9e1e5eb1d2bab694641ce3981c4b 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -45,6 +45,9 @@ void init_GraphView(py::module& m) { :rtype: list[Node] )mydelimiter") + .def("set_ordered_inputs", &GraphView::setOrderedInputs, py::arg("inputs")) + .def("set_ordered_outputs", &GraphView::setOrderedOutputs, py::arg("outputs")) + .def("add", (void (GraphView::*)(std::shared_ptr<Node>, bool)) & GraphView::add, py::arg("other_node"), py::arg("include_learnable_parameters") = true, R"mydelimiter( @@ -118,5 +121,7 @@ void init_GraphView(py::module& m) { // } // }) ; + + m.def("get_connected_graph_view", &getConnectedGraphView); } } // namespace Aidge diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index 83f5688fa3d9e459a364ee3e74975a23d09c236c..71f5b368bcd358231439dead96ade266b313cf6c 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -23,6 +23,7 @@ namespace py = pybind11; namespace Aidge { void init_Node(py::module& m) { py::class_<Node, std::shared_ptr<Node>>(m, "Node") + .def(py::init<std::shared_ptr<Operator>, const std::string&>(), py::arg("op"), py::arg("name") = "") .def("name", &Node::name, R"mydelimiter( Name of the Node. diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index 411a2e1b6ae78065a79b92f25c23dac13e341997..1da4808568df6d5a8eab559c67cd1a95555233b5 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -22,6 +22,9 @@ namespace Aidge { template <DimSize_t DIM> void declare_BatchNormOp(py::module& m) { py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, OperatorTensor, Attributes>(m, ("BatchNormOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) + .def(py::init<float, float>(), + py::arg("epsilon"), + py::arg("momentum")) .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index 11c3db681ea15f5413b286e2dc1eeef68ecafd31..98db4652a50329f02e4f7cace6072ffb46c1147d 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -139,6 +139,9 @@ void init_MetaOperatorDefs(py::module &m) { declare_LSTMOp(m); py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance()) + .def(py::init<const char *, const std::shared_ptr<GraphView>&>(), + py::arg("type"), + py::arg("graph")) .def("get_micro_graph", &MetaOperator_Op::getMicroGraph); m.def("meta_operator", &MetaOperator, diff --git a/python_binding/operator/pybind_Pop.cpp b/python_binding/operator/pybind_Pop.cpp new file mode 100644 index 0000000000000000000000000000000000000000..91726fc1d4721df1be712a26721d09b1a98fd9a2 --- /dev/null +++ b/python_binding/operator/pybind_Pop.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Pop.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Pop(py::module& m) { + py::class_<Pop_Op, std::shared_ptr<Pop_Op>, OperatorTensor, Attributes>(m, "PopOp", py::multiple_inheritance()) + .def("get_inputs_name", &Pop_Op::getInputsName) + .def("get_outputs_name", &Pop_Op::getOutputsName); + + m.def("Pop", &Pop, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index be0d357b7f73e26aad44994f407696f70617ad71..736e7a1d62164bacb13ed12edaab760ff24e30f6 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -39,6 +39,7 @@ void init_MetaOperatorDefs(py::module&); void init_Mul(py::module&); void init_Producer(py::module&); void init_Pad(py::module&); +void init_Pop(py::module&); void init_Pow(py::module&); void init_ReduceMean(py::module&); void init_ReLU(py::module&); @@ -95,6 +96,7 @@ void init_Aidge(py::module& m){ init_Mul(m); init_Pad(m); + init_Pop(m); init_Pow(m); init_ReduceMean(m); init_ReLU(m); diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index d963b81d501f5cd2faf4f69810c897bb4b4da86d..b801898dcd251bdf1976eba1941407965c7153b6 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -23,7 +23,7 @@ void init_Scheduler(py::module& m){ .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) .def("resetScheduling", &SequentialScheduler::resetScheduling) .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) - .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling) + .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling, py::arg("step") = 0) ; } } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index d74d1980cfe3872afb2d245f6720ad5ea4a39438..fcff5b8f43440229636bc65be8100d706a74d177 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -322,7 +322,7 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa for (const auto& element : mScheduling) { std::fprintf(fp, "%s :%ld, %ld\n", - namePtrTable.find(element.node)->second.c_str(), + namePtrTable.at(element.node).c_str(), std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(), std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count()); } diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index cb75669f382b4352492ccbf22f9c918bcbe18033..2932514ca3a51c1c74eb583133b9a0d3557e8b3a 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -61,13 +61,12 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { scheduler.generateScheduling(true); const auto sch = scheduler.getStaticScheduling(); - std::map<std::shared_ptr<Node>, std::string> namePtrTable - = g1->getRankedNodesName("{0} ({1}#{3})"); + const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); std::vector<std::string> nodesName; std::transform(sch.begin(), sch.end(), std::back_inserter(nodesName), - [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }); + [&namePtrTable](auto val){ return namePtrTable.at(val).c_str(); }); fmt::print("schedule: {}\n", nodesName); REQUIRE(sch.size() == 10 + orderedInputs.size());