Skip to content
Snippets Groups Projects
Commit 4d82363f authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'tiling' into 'main'

byding the first tiling prototype

See merge request !58
parents 6da29a40 92e7646b
No related branches found
No related tags found
2 merge requests!58byding the first tiling prototype,!47Vit operators
Pipeline #35698 passed
...@@ -97,6 +97,7 @@ void init_GraphView(py::module& m) { ...@@ -97,6 +97,7 @@ void init_GraphView(py::module& m) {
.def("get_nodes", &GraphView::getNodes) .def("get_nodes", &GraphView::getNodes)
.def("get_node", &GraphView::getNode, py::arg("node_name")) .def("get_node", &GraphView::getNode, py::arg("node_name"))
.def("forward_dims", &GraphView::forwardDims) .def("forward_dims", &GraphView::forwardDims)
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"))
.def("__call__", &GraphView::operator(), py::arg("connectors")) .def("__call__", &GraphView::operator(), py::arg("connectors"))
.def("set_datatype", &GraphView::setDataType, py::arg("datatype")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
.def("set_backend", &GraphView::setBackend, py::arg("backend")) .def("set_backend", &GraphView::setBackend, py::arg("backend"))
......
...@@ -137,6 +137,8 @@ void init_Node(py::module& m) { ...@@ -137,6 +137,8 @@ void init_Node(py::module& m) {
:rtype: int :rtype: int
)mydelimiter") )mydelimiter")
.def("get_parent", &Node::getParent, py::arg("in_id"))
.def("get_parents", &Node::getParents, .def("get_parents", &Node::getParents,
R"mydelimiter( R"mydelimiter(
Get parents. Get parents.
......
...@@ -20,6 +20,7 @@ namespace Aidge { ...@@ -20,6 +20,7 @@ namespace Aidge {
void init_Operator(py::module& m){ void init_Operator(py::module& m){
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator")
.def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data")) .def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
.def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
.def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx")) .def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx"))
.def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
.def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx")) .def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx"))
......
...@@ -12,9 +12,11 @@ ...@@ -12,9 +12,11 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <cstddef>
#include <string> #include <string>
#include "aidge/recipies/Recipies.hpp" #include "aidge/recipies/Recipies.hpp"
#include "aidge/utils/Types.h"
namespace py = pybind11; namespace py = pybind11;
...@@ -28,7 +30,7 @@ void init_Recipies(py::module &m) { ...@@ -28,7 +30,7 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie :param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView` :type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter"); )mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
...@@ -63,7 +65,10 @@ void init_Recipies(py::module &m) { ...@@ -63,7 +65,10 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie :param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView` :type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter"); )mydelimiter");
m.def("get_conv_horizontal_tiling", static_cast<std::set<std::shared_ptr<Node>>(*)(const std::shared_ptr<Node>&, const DimIdx_t, const std::size_t)>(getConvHorizontalTiling),
py::arg("node"), py::arg("axis"), py::arg("nb_slices"));
// m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
// Recipie to remove a flatten operator. // Recipie to remove a flatten operator.
......
...@@ -21,6 +21,7 @@ void init_Scheduler(py::module& m){ ...@@ -21,6 +21,7 @@ void init_Scheduler(py::module& m){
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false)
.def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &SequentialScheduler::resetScheduling)
.def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false)
.def("get_static_scheduling", &SequentialScheduler::getStaticScheduling) .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling)
; ;
......
...@@ -74,7 +74,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -74,7 +74,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
res.insert(clonedInputs[i]); res.insert(clonedInputs[i]);
} }
for (; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis]) { for (IOIndex_t i = 0; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis], ++i) {
const auto inputDims = op->computeReceptiveField(outTensor->getIdx(currentFirstDims), outputDims, 0); const auto inputDims = op->computeReceptiveField(outTensor->getIdx(currentFirstDims), outputDims, 0);
auto newNode = node -> clone(); // no input associated to clones auto newNode = node -> clone(); // no input associated to clones
newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis])); newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis]));
...@@ -83,7 +83,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -83,7 +83,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
// Slice for input and each parameter // Slice for input and each parameter
auto slice = Slice(inputDims[0].first, inputDims[0].second, "Slice_" + std::to_string(currentFirstDims[axis])); auto slice = Slice(inputDims[0].first, inputDims[0].second, "Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0); slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, currentFirstDims[axis]); newNode -> addChild(concat, 0, i);
res.insert(slice); res.insert(slice);
res.insert(newNode); res.insert(newNode);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment