Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • eclipse/aidge/aidge_core
  • hrouis/aidge_core
  • mszczep/aidge_core
  • oantoni/aidge_core
  • cguillon/aidge_core
  • jeromeh/aidge_core
  • axelfarr/aidge_core
  • cmoineau/aidge_core
  • noamzerah/aidge_core
  • lrakotoarivony/aidge_core
  • silvanosky/aidge_core
  • maab05/aidge_core
  • mick94/aidge_core
  • lucaslopez/aidge_core_ll
  • wboussella/aidge_core
  • farnez/aidge_core
  • mnewson/aidge_core
17 results
Show changes
Commits on Source (46)
Showing
with 306 additions and 47 deletions
......@@ -8,4 +8,5 @@ http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0
"""
from aidge_core.aidge_core import * # import so generated by PyBind
from aidge_core.export import ExportNode
from aidge_core.export import ExportNode, generate_file, generate_str
import aidge_core.utils
from .node_export import *
from .code_generation import *
import os
from jinja2 import Environment, FileSystemLoader
def generate_file(file_path: str, template_path: str, **kwargs) -> None:
"""Generate a file at `file_path` using the jinja template located at `file_path`.
kwargs are used to fill the template.
:param file_path: path where to generate the file
:type file_path: str
:param template_path: Path to the template to use for code generation
:type template_path: str
"""
# Get directory name of the file
dirname = os.path.dirname(file_path)
# If directory doesn't exist, create it
if not os.path.exists(dirname):
os.makedirs(dirname)
# Get directory name and name of the template
template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path)
# Select template
template = Environment(loader=FileSystemLoader(
template_dir)).get_template(template_name)
# Generate file
content = template.render(kwargs)
with open(file_path, mode="w", encoding="utf-8") as message:
message.write(content)
def generate_str(template_path:str, **kwargs) -> str:
"""Generate a string using the jinja template located at `file_path`.
kwargs are used to fill the template.
:param template_path: Path to the template to use for code generation
:type template_path: str
:return: A string of the interpreted template
:rtype: str
"""
dirname = os.path.dirname(template_path)
filename = os.path.basename(template_path)
template = Environment(loader=FileSystemLoader(dirname)).get_template(filename)
return template.render(kwargs)
def template_docstring(template_keyword, text_to_replace):
"""Method to template docstring
:param template: Template keyword to replace, in the documentation you template word must be between `{` `}`
:type template: str
:param text_to_replace: Text to replace your template with.
:type text_to_replace: str
"""
def dec(func):
if "{"+template_keyword+"}" not in func.__doc__:
raise RuntimeError(
f"The function {function.__name__} docstring does not contain the template keyword: {template_keyword}.")
func.__doc__ = func.__doc__.replace(
"{"+template_keyword+"}", text_to_replace)
return func
return dec
......@@ -160,7 +160,7 @@ public:
/**
* @brief List outside input connections of the GraphView. The vector
* size is garanteed to match the number of outside inputs of the GraphView. If there is
* size is guaranteed to match the number of outside inputs of the GraphView. If there is
* no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
......@@ -210,7 +210,7 @@ public:
* @brief Compute dimensions of input/output Tensors for each Operator of the
* GraphView object's Nodes.
*/
bool forwardDims(const std::vector<std::vector<DimSize_t>> dims = {}, bool allowDataDependency = false);
bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
......@@ -376,6 +376,12 @@ public:
addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
}
inline void updateNodeName(const std::string& oldName, const std::string& newName){
AIDGE_ASSERT(mNodeRegistry.find(oldName) != mNodeRegistry.end(), "No node named {} in graph {}, the graph may be corrupted !", oldName, name());
mNodeRegistry[newName] = mNodeRegistry[oldName];
mNodeRegistry.erase(oldName);
}
/**
* @brief Include a GraphView content in the current GraphView and link
* the two sets by linking one Node from each GraphView.
......@@ -480,6 +486,14 @@ public:
*/
IOIndex_t getNbFreeDataInputs() const;
/**
* @brief Force update of GraphView inputs/outputs.
* It may be necessary to force the update of GraphView inputs/outputs when
* connections are added or removed inside the GraphView **after** the nodes
* were added.
*/
void updateInputsOutputs();
private:
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
......
......@@ -235,8 +235,8 @@ public:
///////////////////////////////////////////////////////
/**
* @brief Vector of pointers to each GraphView containing the object
* @return std::vector<GraphView>
* @brief Set of pointers to each GraphView containing this Node
* @return std::set<GraphView>
*/
inline std::set<std::shared_ptr<GraphView>> views() const noexcept {
std::set<std::shared_ptr<GraphView>> res;
......@@ -460,10 +460,10 @@ private:
// OPERATOR FUNCTIONNAL but commented out to avoid iostream inclusion
// /**
// * @brief operator<< overload to ease print & debug of nodes
// * @param[inout] ostream to print to
// * @param[inout] ostream to print to
// * @param[in] n node to print
// */
// friend std::ostream& operator << (std::ostream& os, Node& n);
// friend std::ostream& operator << (std::ostream& os, Node& n);
};
} // namespace Aidge
......
......@@ -70,16 +70,9 @@ public:
return mScheduler;
}
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final {
AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type");
AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx);
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->associateInput(inputOp.second, data);
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
bool forwardDims(bool allowDataDependency = false) override final {
// Check first that all required inputs are available, otherwise
......
......@@ -56,8 +56,8 @@ public:
///////////////////////////////////////////////////
// Tensor access
// input management
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override;
const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final;
......
......@@ -27,9 +27,10 @@ enum class ScalingAttr {
scalingFactor, quantizedNbBits, isOutputUnsigned
};
class Scaling_Op : public OperatorTensor,
public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float, size_t, bool> {
class Scaling_Op
: public OperatorTensor,
public Registrable<Scaling_Op, std::string, std::shared_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float, size_t, bool> {
public:
static const std::string Type;
......@@ -84,7 +85,11 @@ inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::stri
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name);
}
*/
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") {
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f,
std::size_t quantizedNbBits=8,
bool isOutputUnsigned=true,
const std::string& name = "")
{
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name);
}
} // namespace Aidge
......
......@@ -37,7 +37,7 @@ public:
/**
* @brief Run the provided Computational Graph with a batch of data
*/
virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
virtual void forward(bool forwardDims = true, const std::vector<std::shared_ptr<Aidge::Tensor>>& data = {});
};
} // namespace Aidge
......
......@@ -114,7 +114,7 @@ public:
*
* @param data data input tensors
*/
void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data);
void connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data);
/**
* @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes.
......
......@@ -49,7 +49,7 @@ public:
/**
* @brief Run the provided Computational Graph with a batch of data
*/
virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
virtual void forward(bool forwardDims = true, const std::vector<std::shared_ptr<Aidge::Tensor>>& data = {});
/**
* @brief Run the provided Computational Graph with a batch of data
......
......@@ -21,6 +21,7 @@
#include "aidge/utils/future_std/any.hpp"
#include "aidge/utils/Attributes.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#ifdef PYBIND
#include <pybind11/pybind11.h>
......@@ -86,7 +87,7 @@ public:
template<class T> void addAttr(const std::string& name, const T& value)
{
const auto& res = mAttrs.emplace(std::make_pair(name, future_std::any(value)));
assert(res.second && "attribute already exists");
AIDGE_ASSERT(res.second, "attribute already exists");
#ifdef PYBIND
// We cannot handle Python object if the Python interpreter is not running
......@@ -129,10 +130,10 @@ public:
void addAttrPy(const std::string& name, py::object&& value)
{
auto it = mAttrs.find(name);
assert(it == mAttrs.end() && "attribute already exists");
AIDGE_ASSERT(it == mAttrs.end(), "attribute already exists");
const auto& res = mAttrsPy.emplace(std::make_pair(name, value));
assert(res.second && "attribute already exists");
AIDGE_ASSERT(res.second, "attribute already exists");
}
void setAttrPy(const std::string& name, py::object&& value) override final
......@@ -199,6 +200,8 @@ public:
};
#endif
virtual ~DynamicAttributes() {}
private:
#ifdef PYBIND
// Stores C++ attributes (copy) and Python-only attributes
......
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/data/Database.hpp"
#include "aidge/data/Tensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Database(py::module& m){
/**
* @brief Trampoline class for binding
*
*/
class pyDatabase : public Database {
public:
using Database::Database; // Inherit constructors
py::class_<Database, std::shared_ptr<Database>>(m,"Database");
std::vector<std::shared_ptr<Tensor>> getItem(
const std::size_t index) const override {
PYBIND11_OVERRIDE_PURE_NAME(std::vector<std::shared_ptr<Tensor>>, Database,
"get_item", getItem, index);
}
std::size_t getLen() const noexcept override {
PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "len", getLen);
}
std::size_t getNbModalities() const noexcept override {
PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "get_nb_modalities",
getNbModalities);
}
};
}
void init_Database(py::module& m) {
py::class_<Database, std::shared_ptr<Database>, pyDatabase>(
m, "Database", py::dynamic_attr())
.def(py::init<>())
.def("get_item", &Database::getItem)
.def("len", &Database::getLen)
.def("get_nb_modalities", &Database::getNbModalities);
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Scaling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Scaling(py::module& m)
{
py::class_<Scaling_Op, std::shared_ptr<Scaling_Op>, Attributes, OperatorTensor>(m, "ScalingOp", py::multiple_inheritance())
.def("get_inputs_name", &Scaling_Op::getInputsName)
.def("get_outputs_name", &Scaling_Op::getOutputsName)
.def("attributes_name", &Scaling_Op::staticGetAttrsName);
declare_registrable<Scaling_Op>(m, "ScalingOp");
m.def("Scaling", &Scaling, py::arg("scaling_factor") = 1.0f, py::arg("nb_bits") = 8, py::arg("is_output_unsigned") = true, py::arg("name") = "");
}
} // namespace Aidge
......@@ -51,6 +51,7 @@ void init_Pow(py::module&);
void init_ReduceMean(py::module&);
void init_ReLU(py::module&);
void init_Reshape(py::module&);
void init_Scaling(py::module&);
void init_Sigmoid(py::module&);
void init_Slice(py::module&);
void init_Softmax(py::module&);
......@@ -72,6 +73,7 @@ void init_Recipes(py::module&);
void init_GraphViewHelper(py::module&);
void init_Scheduler(py::module&);
void init_MemoryManager(py::module&);
void init_TensorUtils(py::module&);
void init_Filler(py::module&);
......@@ -117,6 +119,7 @@ void init_Aidge(py::module& m) {
init_ReduceMean(m);
init_ReLU(m);
init_Reshape(m);
init_Scaling(m);
init_Sigmoid(m);
init_Slice(m);
init_Softmax(m);
......@@ -134,6 +137,7 @@ void init_Aidge(py::module& m) {
init_Recipes(m);
init_GraphViewHelper(m);
init_Scheduler(m);
init_MemoryManager(m);
init_TensorUtils(m);
init_Filler(m);
}
......
......@@ -21,66 +21,70 @@
namespace py = pybind11;
namespace Aidge {
void init_Recipes(py::module &m) {
void init_Recipes(py::module &m)
{
m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param graph_view: Graph view on which we want to apply the recipie
:param graph_view: Graph view on which we want to apply the recipe
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// :param nodes: The MatMul and Add nodes to fuse.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("remove_dropout",static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a dropout operator.
Recipe to remove a dropout operator.
:param graph_view: Graph view on which we want to apply the recipie
:param graph_view: Graph view on which we want to apply the recipe
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
Recipe to remove a flatten operator.
:param graph_view: Graph view on which we want to apply the recipie
:param graph_view: Graph view on which we want to apply the recipe
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
// m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
// Recipie to remove a flatten operator.
// Recipe to remove a flatten operator.
// :param nodes: The flatten operator to remove.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// :param nodes: The MatMul and Add nodes to fuse.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
Recipe to remove a flatten operator.
:param graph_view: Graph view on which we want to apply the recipie
:param graph_view: Graph view on which we want to apply the recipe
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("get_conv_horizontal_tiling", static_cast<std::set<std::shared_ptr<Node>>(*)(const std::shared_ptr<Node>&, const DimIdx_t, const std::size_t)>(getConvHorizontalTiling),
m.def("get_conv_horizontal_tiling", static_cast<std::set<std::shared_ptr<Node>>(*)(const std::shared_ptr<Node>&, const DimIdx_t, const std::size_t)>(getConvHorizontalTiling),
py::arg("node"), py::arg("axis"), py::arg("nb_slices"));
// m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
// Recipie to remove a flatten operator.
// recipe to remove a flatten operator.
// :param nodes: The flatten operator to remove.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("expand_metaops", static_cast<void(*)(std::shared_ptr<GraphView>, bool)>(expandMetaOps), py::arg("graph_view"), py::arg("recursive") = false);
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/scheduler/MemoryManager.hpp"
namespace py = pybind11;
namespace Aidge {
void init_MemoryManager(py::module& m)
{
py::enum_<MemoryManager::OptimizeStrategy>(m, "OptimizeStrategy")
.value("None", MemoryManager::OptimizeStrategy::None)
.value("OptimizeMaxLifetimeMinSizeFirst", MemoryManager::OptimizeStrategy::OptimizeMaxLifetimeMinSizeFirst)
.value("OptimizeMaxLifetimeMaxSizeFirst", MemoryManager::OptimizeStrategy::OptimizeMaxLifetimeMaxSizeFirst)
.value("OptimizeMaxHoleMaxLifetimeFirst", MemoryManager::OptimizeStrategy::OptimizeMaxHoleMaxLifetimeFirst)
.export_values();
py::class_<MemoryManager::MemorySpace, std::shared_ptr<MemoryManager::MemorySpace>>(m, "MemorySpace")
.def(py::init<MemoryManager::Clock_T, unsigned int, unsigned int, std::set<std::shared_ptr<Node>> >(), py::arg("clock"), py::arg("offset"), py::arg("size"), py::arg("dependencies") = std::set<std::shared_ptr<Node>>())
.def_readwrite("offset", &MemoryManager::MemorySpace::offset)
.def_readwrite("size", &MemoryManager::MemorySpace::size)
.def_readwrite("dependencies", &MemoryManager::MemorySpace::dependencies)
.def_readwrite("allocated", &MemoryManager::MemorySpace::allocated)
.def_readwrite("released", &MemoryManager::MemorySpace::released);
py::class_<MemoryManager::MemoryPlane, std::shared_ptr<MemoryManager::MemoryPlane>>(m, "MemoryPlane")
.def(py::init<std::shared_ptr<MemoryManager::MemorySpace>,
MemoryManager::Clock_T, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int>(),
py::arg("mem_space"), py::arg("clock"), py::arg("offset"),
py::arg("size"), py::arg("stride"), py::arg("length"), py::arg("count"))
.def_readwrite("mem_space", &MemoryManager::MemoryPlane::memSpace)
.def_readwrite("allocated", &MemoryManager::MemoryPlane::allocated)
.def_readwrite("offset", &MemoryManager::MemoryPlane::offset)
.def_readwrite("size", &MemoryManager::MemoryPlane::size)
.def_readwrite("stride", &MemoryManager::MemoryPlane::stride)
.def_readwrite("length", &MemoryManager::MemoryPlane::length)
.def_readwrite("count", &MemoryManager::MemoryPlane::count)
.def("get_size", &MemoryManager::MemoryPlane::getSize)
.def("get_useful_size", &MemoryManager::MemoryPlane::getUsefulSize)
.def("get_contiguous_offset", &MemoryManager::MemoryPlane::getContiguousOffset)
.def("get_contiguous_size", &MemoryManager::MemoryPlane::getContiguousSize)
.def("get_wrapped_offset", &MemoryManager::MemoryPlane::getWrappedOffset)
.def("get_wrapped_size", &MemoryManager::MemoryPlane::getWrappedSize)
.def("get_final_offset", &MemoryManager::MemoryPlane::getFinalOffset)
.def("get_upper_offset", &MemoryManager::MemoryPlane::getUpperOffset)
.def("get_limit", &MemoryManager::MemoryPlane::getLimit);
py::class_<MemoryManager::MaxLifetimeMinSizeFirst>(m, "MaxLifetimeMinSizeFirst")
.def(py::init<unsigned int>(), py::arg("max_lifetime"))
.def_readonly("max_lifetime", &MemoryManager::MaxLifetimeMinSizeFirst::maxLifetime)
.def("__call__", &MemoryManager::MaxLifetimeMinSizeFirst::operator(), py::arg("p0"), py::arg("p1"));
py::class_<MemoryManager::MaxLifetimeMaxSizeFirst>(m, "MaxLifetimeMaxSizeFirst")
.def(py::init<unsigned int>(), py::arg("max_lifetime"))
.def_readonly("max_lifetime", &MemoryManager::MaxLifetimeMaxSizeFirst::maxLifetime)
.def("__call__", &MemoryManager::MaxLifetimeMaxSizeFirst::operator(), py::arg("p0"), py::arg("p1"));
py::class_<MemoryManager::MaxHoleMaxLifetimeFirst>(m, "MaxHoleMaxLifetimeFirst")
.def(py::init<unsigned int, MemoryManager*>(), py::arg("max_lifetime"), py::arg("inst"))
.def_readonly("max_lifetime", &MemoryManager::MaxHoleMaxLifetimeFirst::maxLifetime)
.def_readwrite("inst", &MemoryManager::MaxHoleMaxLifetimeFirst::inst)
.def("__call__", &MemoryManager::MaxHoleMaxLifetimeFirst::operator(), py::arg("p0"), py::arg("p1"));
py::class_<MemoryManager, std::shared_ptr<MemoryManager>>(m, "MemoryManager")
.def(py::init<>())
.def("reserve", (std::shared_ptr<MemoryManager::MemorySpace> (MemoryManager::*)(unsigned int, const std::set<std::shared_ptr<Node>>&)) &MemoryManager::reserve, py::arg("size"), py::arg("dependencies") = std::set<std::shared_ptr<Node>>())
.def("expand", &MemoryManager::expand, py::arg("mem_space"), py::arg("required_size"))
.def("allocate", (MemoryManager::MemoryPlane (MemoryManager::*)(unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::allocate, py::arg("size"), py::arg("dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("allocate", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::allocate, py::arg("node"), py::arg("size"), py::arg("dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("is_wrap_around", &MemoryManager::isWrapAround, py::arg("mem_space"), py::arg("offset"), py::arg("size"), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("reallocate", (MemoryManager::MemoryPlane (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>, unsigned int, unsigned int, bool, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::reallocate, py::arg("mem_space"), py::arg("offset"), py::arg("size"), py::arg("wrap_around"), py::arg("extra_size") = 0, py::arg("additional_dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("reallocate", (MemoryManager::MemoryPlane (MemoryManager::*)(const MemoryManager::MemoryPlane&, unsigned int, unsigned int, bool, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::reallocate, py::arg("memPlane"), py::arg("extra_offset"), py::arg("size"), py::arg("wrap_around"), py::arg("extra_size") = 0, py::arg("additional_dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("reallocate", (unsigned int (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>, const std::shared_ptr<Node>&, unsigned int, unsigned int, bool, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::reallocate, py::arg("mem_space"), py::arg("node"), py::arg("offset"), py::arg("size"), py::arg("wrap_around"), py::arg("extra_size") = 0, py::arg("additional_dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("reallocate", (unsigned int (MemoryManager::*)(const MemoryManager::MemoryPlane&, const std::shared_ptr<Node>&, unsigned int, unsigned int, bool, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::reallocate, py::arg("mem_plane"), py::arg("node"), py::arg("extra_offset"), py::arg("size"), py::arg("wrap_around"), py::arg("extra_size") = 0, py::arg("additional_dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("release", (unsigned int (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>)) &MemoryManager::release, py::arg("mem_space"))
.def("release", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&)) &MemoryManager::release, py::arg("node"))
.def("release_dependencies", &MemoryManager::releaseDependencies, py::arg("node"))
.def("optimize", &MemoryManager::optimize, py::arg("strategy"))
.def("get_offset", &MemoryManager::getOffset, py::arg("node"), py::arg("plane") = 0)
.def("get_size", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&, unsigned int) const) &MemoryManager::getSize, py::arg("node"), py::arg("plane"))
.def("get_size", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&) const) &MemoryManager::getSize, py::arg("node"))
.def("get_peak_usage", &MemoryManager::getPeakUsage)
.def("get_max_lifetime", &MemoryManager::getMaxLifetime)
.def("get_planes", (const std::vector<MemoryManager::MemoryPlane>& (MemoryManager::*)(const std::shared_ptr<Node>&) const) &MemoryManager::getPlanes, py::arg("node"))
.def("get_planes", (const MemoryManager::MemMap_T& (MemoryManager::*)() const) &MemoryManager::getPlanes)
.def("get_planes", (MemoryManager::MemMap_T (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>) const) &MemoryManager::getPlanes, py::arg("mem_space"))
.def("get_nb_planes", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&) const) &MemoryManager::getNbPlanes, py::arg("node"))
.def("get_nb_planes", (unsigned int (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>) const) &MemoryManager::getNbPlanes, py::arg("mem_space"))
.def("get_current_tick", &MemoryManager::getCurrentTick)
.def("tick", &MemoryManager::tick)
.def("log", &MemoryManager::log, py::arg("file_name"))
;
}
} // Aidge
......@@ -11,6 +11,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/scheduler/MemoryManager.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/ParallelScheduler.hpp"
......@@ -22,10 +23,12 @@ namespace Aidge {
void init_Scheduler(py::module& m){
py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("graph_view", &Scheduler::graphView)
.def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &Scheduler::resetScheduling)
.def("generate_scheduling", &Scheduler::generateScheduling)
.def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0)
.def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
;
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
......