Skip to content
Snippets Groups Projects
Commit 4d82363f authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'tiling' into 'main'

byding the first tiling prototype

See merge request eclipse/aidge/aidge_core!58
parents 6da29a40 92e7646b
No related branches found
No related tags found
No related merge requests found
......@@ -97,6 +97,7 @@ void init_GraphView(py::module& m) {
.def("get_nodes", &GraphView::getNodes)
.def("get_node", &GraphView::getNode, py::arg("node_name"))
.def("forward_dims", &GraphView::forwardDims)
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"))
.def("__call__", &GraphView::operator(), py::arg("connectors"))
.def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"))
......
......@@ -137,6 +137,8 @@ void init_Node(py::module& m) {
:rtype: int
)mydelimiter")
.def("get_parent", &Node::getParent, py::arg("in_id"))
.def("get_parents", &Node::getParents,
R"mydelimiter(
Get parents.
......
......@@ -20,6 +20,7 @@ namespace Aidge {
void init_Operator(py::module& m){
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator")
.def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
.def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
.def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx"))
.def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
.def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx"))
......
......@@ -12,9 +12,11 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cstddef>
#include <string>
#include "aidge/recipies/Recipies.hpp"
#include "aidge/utils/Types.h"
namespace py = pybind11;
......@@ -28,7 +30,7 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
......@@ -63,7 +65,10 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("get_conv_horizontal_tiling", static_cast<std::set<std::shared_ptr<Node>>(*)(const std::shared_ptr<Node>&, const DimIdx_t, const std::size_t)>(getConvHorizontalTiling),
py::arg("node"), py::arg("axis"), py::arg("nb_slices"));
// m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
// Recipie to remove a flatten operator.
......
......@@ -21,6 +21,7 @@ void init_Scheduler(py::module& m){
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false)
.def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &SequentialScheduler::resetScheduling)
.def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false)
.def("get_static_scheduling", &SequentialScheduler::getStaticScheduling)
;
......
......@@ -74,7 +74,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
res.insert(clonedInputs[i]);
}
for (; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis]) {
for (IOIndex_t i = 0; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis], ++i) {
const auto inputDims = op->computeReceptiveField(outTensor->getIdx(currentFirstDims), outputDims, 0);
auto newNode = node -> clone(); // no input associated to clones
newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis]));
......@@ -83,7 +83,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
// Slice for input and each parameter
auto slice = Slice(inputDims[0].first, inputDims[0].second, "Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, currentFirstDims[axis]);
newNode -> addChild(concat, 0, i);
res.insert(slice);
res.insert(newNode);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment