Skip to content
Snippets Groups Projects
Commit 0ba1a128 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed binding issues

parent 54348a20
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!77Support for recurrent networks
......@@ -45,6 +45,9 @@ void init_GraphView(py::module& m) {
:rtype: list[Node]
)mydelimiter")
.def("set_ordered_inputs", &GraphView::setOrderedInputs, py::arg("inputs"))
.def("set_ordered_outputs", &GraphView::setOrderedOutputs, py::arg("outputs"))
.def("add", (void (GraphView::*)(std::shared_ptr<Node>, bool)) & GraphView::add,
py::arg("other_node"), py::arg("include_learnable_parameters") = true,
R"mydelimiter(
......@@ -118,5 +121,7 @@ void init_GraphView(py::module& m) {
// }
// })
;
m.def("get_connected_graph_view", &getConnectedGraphView);
}
} // namespace Aidge
......@@ -23,6 +23,7 @@ namespace py = pybind11;
namespace Aidge {
void init_Node(py::module& m) {
py::class_<Node, std::shared_ptr<Node>>(m, "Node")
.def(py::init<std::shared_ptr<Operator>, const std::string&>(), py::arg("op"), py::arg("name") = "")
.def("name", &Node::name,
R"mydelimiter(
Name of the Node.
......
......@@ -22,6 +22,9 @@ namespace Aidge {
template <DimSize_t DIM>
void declare_BatchNormOp(py::module& m) {
py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, OperatorTensor, Attributes>(m, ("BatchNormOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance())
.def(py::init<float, float>(),
py::arg("epsilon"),
py::arg("momentum"))
.def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName)
.def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName);
......
......@@ -139,6 +139,9 @@ void init_MetaOperatorDefs(py::module &m) {
declare_LSTMOp(m);
py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance())
.def(py::init<const char *, const std::shared_ptr<GraphView>&>(),
py::arg("type"),
py::arg("graph"))
.def("get_micro_graph", &MetaOperator_Op::getMicroGraph);
m.def("meta_operator", &MetaOperator,
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/Pop.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Pop(py::module& m) {
py::class_<Pop_Op, std::shared_ptr<Pop_Op>, OperatorTensor, Attributes>(m, "PopOp", py::multiple_inheritance())
.def("get_inputs_name", &Pop_Op::getInputsName)
.def("get_outputs_name", &Pop_Op::getOutputsName);
m.def("Pop", &Pop, py::arg("name") = "");
}
} // namespace Aidge
......@@ -39,6 +39,7 @@ void init_MetaOperatorDefs(py::module&);
void init_Mul(py::module&);
void init_Producer(py::module&);
void init_Pad(py::module&);
void init_Pop(py::module&);
void init_Pow(py::module&);
void init_ReduceMean(py::module&);
void init_ReLU(py::module&);
......@@ -95,6 +96,7 @@ void init_Aidge(py::module& m){
init_Mul(m);
init_Pad(m);
init_Pop(m);
init_Pow(m);
init_ReduceMean(m);
init_ReLU(m);
......
......@@ -23,7 +23,7 @@ void init_Scheduler(py::module& m){
.def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &SequentialScheduler::resetScheduling)
.def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false)
.def("get_static_scheduling", &SequentialScheduler::getStaticScheduling)
.def("get_static_scheduling", &SequentialScheduler::getStaticScheduling, py::arg("step") = 0)
;
}
}
......
......@@ -322,7 +322,7 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa
for (const auto& element : mScheduling) {
std::fprintf(fp, "%s :%ld, %ld\n",
namePtrTable.find(element.node)->second.c_str(),
namePtrTable.at(element.node).c_str(),
std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(),
std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
}
......
......@@ -61,13 +61,12 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
scheduler.generateScheduling(true);
const auto sch = scheduler.getStaticScheduling();
std::map<std::shared_ptr<Node>, std::string> namePtrTable
= g1->getRankedNodesName("{0} ({1}#{3})");
const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
std::vector<std::string> nodesName;
std::transform(sch.begin(), sch.end(),
std::back_inserter(nodesName),
[&namePtrTable](auto val){ return namePtrTable[val].c_str(); });
[&namePtrTable](auto val){ return namePtrTable.at(val).c_str(); });
fmt::print("schedule: {}\n", nodesName);
REQUIRE(sch.size() == 10 + orderedInputs.size());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment