Commit 6c5b59f2 authored by Olivier BICHLER

Merge branch 'dev' into fixGraphRegexUnique

Parents: ab13734e, e9be8063
Merge request !106 (Draft): Fix graph regex unique
Pipeline #45645 canceled
Showing 306 additions and 47 deletions
@@ -8,4 +8,5 @@ http://www.eclipse.org/legal/epl-2.0.
 SPDX-License-Identifier: EPL-2.0
 """
 from aidge_core.aidge_core import * # import so generated by PyBind
-from aidge_core.export import ExportNode
+from aidge_core.export import ExportNode, generate_file, generate_str
+import aidge_core.utils

 from .node_export import *
+from .code_generation import *
import os
from jinja2 import Environment, FileSystemLoader


def generate_file(file_path: str, template_path: str, **kwargs) -> None:
    """Generate a file at `file_path` using the jinja template located at `template_path`.

    kwargs are used to fill the template.

    :param file_path: path where to generate the file
    :type file_path: str
    :param template_path: Path to the template to use for code generation
    :type template_path: str
    """
    # Get directory name of the file
    dirname = os.path.dirname(file_path)
    # If directory doesn't exist, create it
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    # Get directory name and name of the template
    template_dir = os.path.dirname(template_path)
    template_name = os.path.basename(template_path)
    # Select template
    template = Environment(loader=FileSystemLoader(
        template_dir)).get_template(template_name)
    # Generate file
    content = template.render(kwargs)
    with open(file_path, mode="w", encoding="utf-8") as message:
        message.write(content)


def generate_str(template_path: str, **kwargs) -> str:
    """Generate a string using the jinja template located at `template_path`.

    kwargs are used to fill the template.

    :param template_path: Path to the template to use for code generation
    :type template_path: str
    :return: A string of the interpreted template
    :rtype: str
    """
    dirname = os.path.dirname(template_path)
    filename = os.path.basename(template_path)
    template = Environment(loader=FileSystemLoader(dirname)).get_template(filename)
    return template.render(kwargs)
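
A minimal usage sketch of the two helpers above (the template path, its contents, and the `name` variable are hypothetical):

    from aidge_core.export import generate_file, generate_str

    # Assuming "templates/hello.jinja" contains: Hello {{ name }}!
    generate_file("build/hello.txt", "templates/hello.jinja", name="Aidge")
    print(generate_str("templates/hello.jinja", name="Aidge"))  # -> Hello Aidge!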
def template_docstring(template_keyword, text_to_replace):
    """Decorator to template a docstring.

    :param template_keyword: Template keyword to replace; in the docstring, the keyword must appear between `{` and `}`
    :type template_keyword: str
    :param text_to_replace: Text to replace the template keyword with.
    :type text_to_replace: str
    """
    def dec(func):
        if "{" + template_keyword + "}" not in func.__doc__:
            raise RuntimeError(
                f"The function {func.__name__} docstring does not contain the template keyword: {template_keyword}.")
        func.__doc__ = func.__doc__.replace(
            "{" + template_keyword + "}", text_to_replace)
        return func
    return dec
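
A short sketch of the decorator in use (the decorated function is illustrative):

    @template_docstring("default_value", "8")
    def quantize(nb_bits):
        """Quantize to `nb_bits` bits (default: {default_value})."""

    print(quantize.__doc__)  # -> Quantize to `nb_bits` bits (default: 8).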
@@ -160,7 +160,7 @@ public:
     /**
      * @brief List outside input connections of the GraphView. The vector
-     * size is garanteed to match the number of outside inputs of the GraphView. If there is
+     * size is guaranteed to match the number of outside inputs of the GraphView. If there is
      * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned.
      * @return std::vector<std::pair<NodePtr, IOIndex_t>>
      */
@@ -210,7 +210,7 @@ public:
      * @brief Compute dimensions of input/output Tensors for each Operator of the
      * GraphView object's Nodes.
      */
-    bool forwardDims(const std::vector<std::vector<DimSize_t>> dims = {}, bool allowDataDependency = false);
+    bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);

     /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
     void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
@@ -376,6 +376,12 @@ public:
         addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
     }

+    inline void updateNodeName(const std::string& oldName, const std::string& newName){
+        AIDGE_ASSERT(mNodeRegistry.find(oldName) != mNodeRegistry.end(), "No node named {} in graph {}, the graph may be corrupted!", oldName, name());
+        mNodeRegistry[newName] = mNodeRegistry[oldName];
+        mNodeRegistry.erase(oldName);
+    }
+
     /**
      * @brief Include a GraphView content in the current GraphView and link
      * the two sets by linking one Node from each GraphView.

@@ -480,6 +486,14 @@ public:
      */
     IOIndex_t getNbFreeDataInputs() const;

+    /**
+     * @brief Force update of GraphView inputs/outputs.
+     * It may be necessary to force the update of GraphView inputs/outputs when
+     * connections are added or removed inside the GraphView **after** the nodes
+     * were added.
+     */
+    void updateInputsOutputs();
+
 private:
     ///////////////////////////////////////////////////////
     //        TENSOR MANAGEMENT
......
@@ -235,8 +235,8 @@ public:
     ///////////////////////////////////////////////////////

     /**
-     * @brief Vector of pointers to each GraphView containing the object
-     * @return std::vector<GraphView>
+     * @brief Set of pointers to each GraphView containing this Node
+     * @return std::set<GraphView>
      */
     inline std::set<std::shared_ptr<GraphView>> views() const noexcept {
         std::set<std::shared_ptr<GraphView>> res;

@@ -460,10 +460,10 @@ private:
     // OPERATOR FUNCTIONNAL but commented out to avoid iostream inclusion
     // /**
     //  * @brief operator<< overload to ease print & debug of nodes
     //  * @param[inout] ostream to print to
     //  * @param[in] n node to print
     //  */
     // friend std::ostream& operator << (std::ostream& os, Node& n);
 };
 } // namespace Aidge
......
@@ -70,16 +70,9 @@ public:
         return mScheduler;
     }

-    void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final {
-        AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type");
-        AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx);
-
-        const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
-        inputOp.first->getOperator()->associateInput(inputOp.second, data);
-
-        // Associate inputs for custom implementation
-        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
-    }
+    void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
+    void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
+    void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;

     bool forwardDims(bool allowDataDependency = false) override final {
         // Check first that all required inputs are available, otherwise
......
@@ -56,8 +56,8 @@ public:
     ///////////////////////////////////////////////////
     // Tensor access
     // input management
-    void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
-    void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
+    void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override;
+    void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override;
     const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
     std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final;
......
@@ -27,9 +27,10 @@ enum class ScalingAttr {
     scalingFactor, quantizedNbBits, isOutputUnsigned
 };

-class Scaling_Op : public OperatorTensor,
-    public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>,
-    public StaticAttributes<ScalingAttr, float, size_t, bool> {
+class Scaling_Op
+    : public OperatorTensor,
+      public Registrable<Scaling_Op, std::string, std::shared_ptr<OperatorImpl>(const Scaling_Op&)>,
+      public StaticAttributes<ScalingAttr, float, size_t, bool> {
 public:
     static const std::string Type;

@@ -84,7 +85,11 @@
     return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name);
 }
 */
-inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") {
+inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f,
+                                     std::size_t quantizedNbBits=8,
+                                     bool isOutputUnsigned=true,
+                                     const std::string& name = "")
+{
     return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor, quantizedNbBits, isOutputUnsigned), name);
 }
 } // namespace Aidge
......
@@ -37,7 +37,7 @@ public:
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
-    virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
+    virtual void forward(bool forwardDims = true, const std::vector<std::shared_ptr<Aidge::Tensor>>& data = {});
 };
 } // namespace Aidge
......
@@ -114,7 +114,7 @@ public:
      *
      * @param data data input tensors
      */
-    void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data);
+    void connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data);

     /**
      * @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes.
......
@@ -49,7 +49,7 @@ public:
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
-    virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
+    virtual void forward(bool forwardDims = true, const std::vector<std::shared_ptr<Aidge::Tensor>>& data = {});

     /**
      * @brief Run the provided Computational Graph with a batch of data
......
@@ -21,6 +21,7 @@
 #include "aidge/utils/future_std/any.hpp"
 #include "aidge/utils/Attributes.hpp"
+#include "aidge/utils/ErrorHandling.hpp"

 #ifdef PYBIND
 #include <pybind11/pybind11.h>

@@ -86,7 +87,7 @@ public:
     template<class T> void addAttr(const std::string& name, const T& value)
     {
         const auto& res = mAttrs.emplace(std::make_pair(name, future_std::any(value)));
-        assert(res.second && "attribute already exists");
+        AIDGE_ASSERT(res.second, "attribute already exists");

 #ifdef PYBIND
         // We cannot handle Python object if the Python interpreter is not running

@@ -129,10 +130,10 @@ public:
     void addAttrPy(const std::string& name, py::object&& value)
     {
         auto it = mAttrs.find(name);
-        assert(it == mAttrs.end() && "attribute already exists");
+        AIDGE_ASSERT(it == mAttrs.end(), "attribute already exists");

         const auto& res = mAttrsPy.emplace(std::make_pair(name, value));
-        assert(res.second && "attribute already exists");
+        AIDGE_ASSERT(res.second, "attribute already exists");
     }

     void setAttrPy(const std::string& name, py::object&& value) override final

@@ -199,6 +200,8 @@ public:
     };
 #endif

+    virtual ~DynamicAttributes() {}
+
 private:
 #ifdef PYBIND
     // Stores C++ attributes (copy) and Python-only attributes
......
 #include <pybind11/pybind11.h>
+#include <pybind11/stl.h>

 #include "aidge/data/Database.hpp"
+#include "aidge/data/Tensor.hpp"

 namespace py = pybind11;

 namespace Aidge {

-void init_Database(py::module& m){
-    py::class_<Database, std::shared_ptr<Database>>(m,"Database");
-}
+/**
+ * @brief Trampoline class for binding
+ *
+ */
+class pyDatabase : public Database {
+public:
+    using Database::Database;  // Inherit constructors
+
+    std::vector<std::shared_ptr<Tensor>> getItem(
+        const std::size_t index) const override {
+        PYBIND11_OVERRIDE_PURE_NAME(std::vector<std::shared_ptr<Tensor>>, Database,
+                                    "get_item", getItem, index);
+    }
+    std::size_t getLen() const noexcept override {
+        PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "len", getLen);
+    }
+    std::size_t getNbModalities() const noexcept override {
+        PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "get_nb_modalities",
+                                    getNbModalities);
+    }
+};
+
+void init_Database(py::module& m) {
+    py::class_<Database, std::shared_ptr<Database>, pyDatabase>(
+        m, "Database", py::dynamic_attr())
+        .def(py::init<>())
+        .def("get_item", &Database::getItem)
+        .def("len", &Database::getLen)
+        .def("get_nb_modalities", &Database::getNbModalities);
+}
 } // namespace Aidge
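
The trampoline above lets Python classes derive from `Database`. A minimal sketch (the dataset is hypothetical, and it assumes `aidge_core.Tensor` can be constructed from a NumPy array):

    import numpy as np
    import aidge_core

    class InMemoryDataset(aidge_core.Database):
        """Toy dataset: each item holds one input tensor and one label tensor."""
        def __init__(self, nb_items):
            super().__init__()
            self.data = [
                [aidge_core.Tensor(np.random.rand(3, 32, 32).astype(np.float32)),
                 aidge_core.Tensor(np.array([i], dtype=np.int32))]
                for i in range(nb_items)
            ]

        def get_item(self, index):    # overrides Database::getItem
            return self.data[index]

        def len(self):                # overrides Database::getLen
            return len(self.data)

        def get_nb_modalities(self):  # overrides Database::getNbModalities
            return 2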
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Scaling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Scaling(py::module& m)
{
py::class_<Scaling_Op, std::shared_ptr<Scaling_Op>, Attributes, OperatorTensor>(m, "ScalingOp", py::multiple_inheritance())
.def("get_inputs_name", &Scaling_Op::getInputsName)
.def("get_outputs_name", &Scaling_Op::getOutputsName)
.def("attributes_name", &Scaling_Op::staticGetAttrsName);
declare_registrable<Scaling_Op>(m, "ScalingOp");
m.def("Scaling", &Scaling, py::arg("scaling_factor") = 1.0f, py::arg("nb_bits") = 8, py::arg("is_output_unsigned") = true, py::arg("name") = "");
}
} // namespace Aidge
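
With the binding above, the operator becomes available from Python; the keyword arguments follow the `py::arg` declarations:

    import aidge_core

    scaling = aidge_core.Scaling(scaling_factor=0.5,
                                 nb_bits=8,
                                 is_output_unsigned=True,
                                 name="rescale")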
@@ -51,6 +51,7 @@ void init_Pow(py::module&);
 void init_ReduceMean(py::module&);
 void init_ReLU(py::module&);
 void init_Reshape(py::module&);
+void init_Scaling(py::module&);
 void init_Sigmoid(py::module&);
 void init_Slice(py::module&);
 void init_Softmax(py::module&);

@@ -72,6 +73,7 @@ void init_Recipes(py::module&);
 void init_GraphViewHelper(py::module&);
 void init_Scheduler(py::module&);
+void init_MemoryManager(py::module&);
 void init_TensorUtils(py::module&);
 void init_Filler(py::module&);

@@ -117,6 +119,7 @@ void init_Aidge(py::module& m) {
     init_ReduceMean(m);
     init_ReLU(m);
     init_Reshape(m);
+    init_Scaling(m);
     init_Sigmoid(m);
     init_Slice(m);
     init_Softmax(m);

@@ -134,6 +137,7 @@ void init_Aidge(py::module& m) {
     init_Recipes(m);
     init_GraphViewHelper(m);
     init_Scheduler(m);
+    init_MemoryManager(m);
     init_TensorUtils(m);
     init_Filler(m);
 }
......
@@ -21,66 +21,70 @@
 namespace py = pybind11;

 namespace Aidge {
-void init_Recipes(py::module &m) {
+void init_Recipes(py::module &m)
+{
   m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter(
-    Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+    Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.

-    :param graph_view: Graph view on which we want to apply the recipie
+    :param graph_view: Graph view on which we want to apply the recipe
     :type graph_view: :py:class:`aidge_core.GraphView`
     )mydelimiter");

   // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
-  //   Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+  //   recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
   //   :param nodes: The MatMul and Add nodes to fuse.
   //   :type nodes: list of :py:class:`aidge_core.Node`
   //   )mydelimiter");

   m.def("remove_dropout", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter(
-    Recipie to remove a dropout operator.
+    Recipe to remove a dropout operator.

-    :param graph_view: Graph view on which we want to apply the recipie
+    :param graph_view: Graph view on which we want to apply the recipe
     :type graph_view: :py:class:`aidge_core.GraphView`
     )mydelimiter");

   m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
-    Recipie to remove a flatten operator.
+    Recipe to remove a flatten operator.

-    :param graph_view: Graph view on which we want to apply the recipie
+    :param graph_view: Graph view on which we want to apply the recipe
     :type graph_view: :py:class:`aidge_core.GraphView`
     )mydelimiter");

   // m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
-  //   Recipie to remove a flatten operator.
+  //   Recipe to remove a flatten operator.
   //   :param nodes: The flatten operator to remove.
   //   :type nodes: list of :py:class:`aidge_core.Node`
   //   )mydelimiter");

   // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
-  //   Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+  //   Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
   //   :param nodes: The MatMul and Add nodes to fuse.
   //   :type nodes: list of :py:class:`aidge_core.Node`
   //   )mydelimiter");

   m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
-    Recipie to remove a flatten operator.
+    Recipe to remove a flatten operator.

-    :param graph_view: Graph view on which we want to apply the recipie
+    :param graph_view: Graph view on which we want to apply the recipe
     :type graph_view: :py:class:`aidge_core.GraphView`
     )mydelimiter");

   m.def("get_conv_horizontal_tiling", static_cast<std::set<std::shared_ptr<Node>>(*)(const std::shared_ptr<Node>&, const DimIdx_t, const std::size_t)>(getConvHorizontalTiling),
         py::arg("node"), py::arg("axis"), py::arg("nb_slices"));

   // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
-  //   Recipie to remove a flatten operator.
+  //   recipe to remove a flatten operator.
   //   :param nodes: The flatten operator to remove.
   //   :type nodes: list of :py:class:`aidge_core.Node`
   //   )mydelimiter");

+  m.def("expand_metaops", static_cast<void(*)(std::shared_ptr<GraphView>, bool)>(expandMetaOps), py::arg("graph_view"), py::arg("recursive") = false);
 }
 } // namespace Aidge
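
From Python, the recipes registered above apply to a whole graph view; a sketch (assuming `graph_view` is an existing `aidge_core.GraphView`):

    import aidge_core

    aidge_core.fuse_mul_add(graph_view)    # MatMul + Add -> FC
    aidge_core.fuse_batchnorm(graph_view)
    aidge_core.remove_flatten(graph_view)
    aidge_core.remove_dropout(graph_view)
    # New in this commit: replace MetaOperator nodes by their inner graph.
    aidge_core.expand_metaops(graph_view, recursive=False)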
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/scheduler/MemoryManager.hpp"
namespace py = pybind11;
namespace Aidge {
void init_MemoryManager(py::module& m)
{
py::enum_<MemoryManager::OptimizeStrategy>(m, "OptimizeStrategy")
.value("None", MemoryManager::OptimizeStrategy::None)
.value("OptimizeMaxLifetimeMinSizeFirst", MemoryManager::OptimizeStrategy::OptimizeMaxLifetimeMinSizeFirst)
.value("OptimizeMaxLifetimeMaxSizeFirst", MemoryManager::OptimizeStrategy::OptimizeMaxLifetimeMaxSizeFirst)
.value("OptimizeMaxHoleMaxLifetimeFirst", MemoryManager::OptimizeStrategy::OptimizeMaxHoleMaxLifetimeFirst)
.export_values();
py::class_<MemoryManager::MemorySpace, std::shared_ptr<MemoryManager::MemorySpace>>(m, "MemorySpace")
.def(py::init<MemoryManager::Clock_T, unsigned int, unsigned int, std::set<std::shared_ptr<Node>> >(), py::arg("clock"), py::arg("offset"), py::arg("size"), py::arg("dependencies") = std::set<std::shared_ptr<Node>>())
.def_readwrite("offset", &MemoryManager::MemorySpace::offset)
.def_readwrite("size", &MemoryManager::MemorySpace::size)
.def_readwrite("dependencies", &MemoryManager::MemorySpace::dependencies)
.def_readwrite("allocated", &MemoryManager::MemorySpace::allocated)
.def_readwrite("released", &MemoryManager::MemorySpace::released);
py::class_<MemoryManager::MemoryPlane, std::shared_ptr<MemoryManager::MemoryPlane>>(m, "MemoryPlane")
.def(py::init<std::shared_ptr<MemoryManager::MemorySpace>,
MemoryManager::Clock_T, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int>(),
py::arg("mem_space"), py::arg("clock"), py::arg("offset"),
py::arg("size"), py::arg("stride"), py::arg("length"), py::arg("count"))
.def_readwrite("mem_space", &MemoryManager::MemoryPlane::memSpace)
.def_readwrite("allocated", &MemoryManager::MemoryPlane::allocated)
.def_readwrite("offset", &MemoryManager::MemoryPlane::offset)
.def_readwrite("size", &MemoryManager::MemoryPlane::size)
.def_readwrite("stride", &MemoryManager::MemoryPlane::stride)
.def_readwrite("length", &MemoryManager::MemoryPlane::length)
.def_readwrite("count", &MemoryManager::MemoryPlane::count)
.def("get_size", &MemoryManager::MemoryPlane::getSize)
.def("get_useful_size", &MemoryManager::MemoryPlane::getUsefulSize)
.def("get_contiguous_offset", &MemoryManager::MemoryPlane::getContiguousOffset)
.def("get_contiguous_size", &MemoryManager::MemoryPlane::getContiguousSize)
.def("get_wrapped_offset", &MemoryManager::MemoryPlane::getWrappedOffset)
.def("get_wrapped_size", &MemoryManager::MemoryPlane::getWrappedSize)
.def("get_final_offset", &MemoryManager::MemoryPlane::getFinalOffset)
.def("get_upper_offset", &MemoryManager::MemoryPlane::getUpperOffset)
.def("get_limit", &MemoryManager::MemoryPlane::getLimit);
py::class_<MemoryManager::MaxLifetimeMinSizeFirst>(m, "MaxLifetimeMinSizeFirst")
.def(py::init<unsigned int>(), py::arg("max_lifetime"))
.def_readonly("max_lifetime", &MemoryManager::MaxLifetimeMinSizeFirst::maxLifetime)
.def("__call__", &MemoryManager::MaxLifetimeMinSizeFirst::operator(), py::arg("p0"), py::arg("p1"));
py::class_<MemoryManager::MaxLifetimeMaxSizeFirst>(m, "MaxLifetimeMaxSizeFirst")
.def(py::init<unsigned int>(), py::arg("max_lifetime"))
.def_readonly("max_lifetime", &MemoryManager::MaxLifetimeMaxSizeFirst::maxLifetime)
.def("__call__", &MemoryManager::MaxLifetimeMaxSizeFirst::operator(), py::arg("p0"), py::arg("p1"));
py::class_<MemoryManager::MaxHoleMaxLifetimeFirst>(m, "MaxHoleMaxLifetimeFirst")
.def(py::init<unsigned int, MemoryManager*>(), py::arg("max_lifetime"), py::arg("inst"))
.def_readonly("max_lifetime", &MemoryManager::MaxHoleMaxLifetimeFirst::maxLifetime)
.def_readwrite("inst", &MemoryManager::MaxHoleMaxLifetimeFirst::inst)
.def("__call__", &MemoryManager::MaxHoleMaxLifetimeFirst::operator(), py::arg("p0"), py::arg("p1"));
py::class_<MemoryManager, std::shared_ptr<MemoryManager>>(m, "MemoryManager")
.def(py::init<>())
.def("reserve", (std::shared_ptr<MemoryManager::MemorySpace> (MemoryManager::*)(unsigned int, const std::set<std::shared_ptr<Node>>&)) &MemoryManager::reserve, py::arg("size"), py::arg("dependencies") = std::set<std::shared_ptr<Node>>())
.def("expand", &MemoryManager::expand, py::arg("mem_space"), py::arg("required_size"))
.def("allocate", (MemoryManager::MemoryPlane (MemoryManager::*)(unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::allocate, py::arg("size"), py::arg("dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("allocate", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::allocate, py::arg("node"), py::arg("size"), py::arg("dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("is_wrap_around", &MemoryManager::isWrapAround, py::arg("mem_space"), py::arg("offset"), py::arg("size"), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("reallocate", (MemoryManager::MemoryPlane (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>, unsigned int, unsigned int, bool, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::reallocate, py::arg("mem_space"), py::arg("offset"), py::arg("size"), py::arg("wrap_around"), py::arg("extra_size") = 0, py::arg("additional_dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("reallocate", (MemoryManager::MemoryPlane (MemoryManager::*)(const MemoryManager::MemoryPlane&, unsigned int, unsigned int, bool, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::reallocate, py::arg("memPlane"), py::arg("extra_offset"), py::arg("size"), py::arg("wrap_around"), py::arg("extra_size") = 0, py::arg("additional_dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("reallocate", (unsigned int (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>, const std::shared_ptr<Node>&, unsigned int, unsigned int, bool, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::reallocate, py::arg("mem_space"), py::arg("node"), py::arg("offset"), py::arg("size"), py::arg("wrap_around"), py::arg("extra_size") = 0, py::arg("additional_dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("reallocate", (unsigned int (MemoryManager::*)(const MemoryManager::MemoryPlane&, const std::shared_ptr<Node>&, unsigned int, unsigned int, bool, unsigned int, const std::set<std::shared_ptr<Node>>&, unsigned int, unsigned int, unsigned int)) &MemoryManager::reallocate, py::arg("mem_plane"), py::arg("node"), py::arg("extra_offset"), py::arg("size"), py::arg("wrap_around"), py::arg("extra_size") = 0, py::arg("additional_dependencies") = std::set<std::shared_ptr<Node>>(), py::arg("stride") = 0, py::arg("length") = 1, py::arg("count") = 1)
.def("release", (unsigned int (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>)) &MemoryManager::release, py::arg("mem_space"))
.def("release", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&)) &MemoryManager::release, py::arg("node"))
.def("release_dependencies", &MemoryManager::releaseDependencies, py::arg("node"))
.def("optimize", &MemoryManager::optimize, py::arg("strategy"))
.def("get_offset", &MemoryManager::getOffset, py::arg("node"), py::arg("plane") = 0)
.def("get_size", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&, unsigned int) const) &MemoryManager::getSize, py::arg("node"), py::arg("plane"))
.def("get_size", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&) const) &MemoryManager::getSize, py::arg("node"))
.def("get_peak_usage", &MemoryManager::getPeakUsage)
.def("get_max_lifetime", &MemoryManager::getMaxLifetime)
.def("get_planes", (const std::vector<MemoryManager::MemoryPlane>& (MemoryManager::*)(const std::shared_ptr<Node>&) const) &MemoryManager::getPlanes, py::arg("node"))
.def("get_planes", (const MemoryManager::MemMap_T& (MemoryManager::*)() const) &MemoryManager::getPlanes)
.def("get_planes", (MemoryManager::MemMap_T (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>) const) &MemoryManager::getPlanes, py::arg("mem_space"))
.def("get_nb_planes", (unsigned int (MemoryManager::*)(const std::shared_ptr<Node>&) const) &MemoryManager::getNbPlanes, py::arg("node"))
.def("get_nb_planes", (unsigned int (MemoryManager::*)(std::shared_ptr<MemoryManager::MemorySpace>) const) &MemoryManager::getNbPlanes, py::arg("mem_space"))
.def("get_current_tick", &MemoryManager::getCurrentTick)
.def("tick", &MemoryManager::tick)
.def("log", &MemoryManager::log, py::arg("file_name"))
;
}
} // Aidge
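
A minimal sketch of the manual allocation API exposed above (sizes are arbitrary; a real `dependencies` set would contain `aidge_core.Node` objects):

    import aidge_core

    mem = aidge_core.MemoryManager()
    space = mem.reserve(1024)        # reserve a 1024-byte MemorySpace
    mem.tick()                       # advance the allocation clock by one step
    mem.expand(space, 2048)          # grow the space to 2048 bytes
    mem.release(space)               # mark the space releasable at the current clock
    print(mem.get_peak_usage())      # peak memory footprint
    mem.log("memory_layout.log")     # dump the memory planes for inspection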
@@ -11,6 +11,7 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

+#include "aidge/scheduler/MemoryManager.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
 #include "aidge/scheduler/ParallelScheduler.hpp"

@@ -22,10 +23,12 @@ namespace Aidge {
 void init_Scheduler(py::module& m){
     py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler")
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
+    .def("graph_view", &Scheduler::graphView)
     .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name"))
     .def("resetScheduling", &Scheduler::resetScheduling)
     .def("generate_scheduling", &Scheduler::generateScheduling)
     .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0)
+    .def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
     ;

     py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
......
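
Put together, the new bindings enable memory planning from Python; a sketch (assuming `graph_view` is a ready `aidge_core.GraphView` and that `generate_memory` returns the resulting `MemoryManager`):

    import aidge_core

    scheduler = aidge_core.SequentialScheduler(graph_view)
    scheduler.generate_scheduling()
    mem_manager = scheduler.generate_memory(inc_producers=False,
                                            wrap_around_buffer=True)
    print(mem_manager.get_peak_usage())
    mem_manager.log("schedule_memory.log")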