Skip to content
Snippets Groups Projects
Commit 501323a0 authored by Vincent Templier's avatar Vincent Templier
Browse files

Merge branch 'dev' into pybind_memorymanager

parents 3cee38e2 f29bc69f
No related branches found
No related tags found
2 merge requests!1190.2.1,!112Add python binding for MemoryManager
Pipeline #44822 passed
......@@ -8,4 +8,5 @@ http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0
"""
from aidge_core.aidge_core import * # import so generated by PyBind
from aidge_core.export import ExportNode
from aidge_core.export import ExportNode, generate_file, generate_str
import aidge_core.utils
from .node_export import *
from .code_generation import *
import os
from jinja2 import Environment, FileSystemLoader
def generate_file(file_path: str, template_path: str, **kwargs) -> None:
"""Generate a file at `file_path` using the jinja template located at `file_path`.
kwargs are used to fill the template.
:param file_path: path where to generate the file
:type file_path: str
:param template_path: Path to the template to use for code generation
:type template_path: str
"""
# Get directory name of the file
dirname = os.path.dirname(file_path)
# If directory doesn't exist, create it
if not os.path.exists(dirname):
os.makedirs(dirname)
# Get directory name and name of the template
template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path)
# Select template
template = Environment(loader=FileSystemLoader(
template_dir)).get_template(template_name)
# Generate file
content = template.render(kwargs)
with open(file_path, mode="w", encoding="utf-8") as message:
message.write(content)
def generate_str(template_path:str, **kwargs) -> str:
"""Generate a string using the jinja template located at `file_path`.
kwargs are used to fill the template.
:param template_path: Path to the template to use for code generation
:type template_path: str
:return: A string of the interpreted template
:rtype: str
"""
dirname = os.path.dirname(template_path)
filename = os.path.basename(template_path)
template = Environment(loader=FileSystemLoader(dirname)).get_template(filename)
return template.render(kwargs)
def template_docstring(template_keyword, text_to_replace):
"""Method to template docstring
:param template: Template keyword to replace, in the documentation you template word must be between `{` `}`
:type template: str
:param text_to_replace: Text to replace your template with.
:type text_to_replace: str
"""
def dec(func):
if "{"+template_keyword+"}" not in func.__doc__:
raise RuntimeError(
f"The function {function.__name__} docstring does not contain the template keyword: {template_keyword}.")
func.__doc__ = func.__doc__.replace(
"{"+template_keyword+"}", text_to_replace)
return func
return dec
......@@ -160,7 +160,7 @@ public:
/**
* @brief List outside input connections of the GraphView. The vector
* size is garanteed to match the number of outside inputs of the GraphView. If there is
* size is guaranteed to match the number of outside inputs of the GraphView. If there is
* no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
......@@ -376,6 +376,12 @@ public:
addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
}
inline void updateNodeName(const std::string& oldName, const std::string& newName){
AIDGE_ASSERT(mNodeRegistry.find(oldName) != mNodeRegistry.end(), "No node named {} in graph {}, the graph may be corrupted !", oldName, name());
mNodeRegistry[newName] = mNodeRegistry[oldName];
mNodeRegistry.erase(oldName);
}
/**
* @brief Include a GraphView content in the current GraphView and link
* the two sets by linking one Node from each GraphView.
......
......@@ -235,8 +235,8 @@ public:
///////////////////////////////////////////////////////
/**
* @brief Vector of pointers to each GraphView containing the object
* @return std::vector<GraphView>
* @brief Set of pointers to each GraphView containing this Node
* @return std::set<GraphView>
*/
inline std::set<std::shared_ptr<GraphView>> views() const noexcept {
std::set<std::shared_ptr<GraphView>> res;
......@@ -460,10 +460,10 @@ private:
// OPERATOR FUNCTIONNAL but commented out to avoid iostream inclusion
// /**
// * @brief operator<< overload to ease print & debug of nodes
// * @param[inout] ostream to print to
// * @param[inout] ostream to print to
// * @param[in] n node to print
// */
// friend std::ostream& operator << (std::ostream& os, Node& n);
// friend std::ostream& operator << (std::ostream& os, Node& n);
};
} // namespace Aidge
......
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/data/Database.hpp"
#include "aidge/data/Tensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Database(py::module& m){
/**
* @brief Trampoline class for binding
*
*/
class pyDatabase : public Database {
public:
using Database::Database; // Inherit constructors
py::class_<Database, std::shared_ptr<Database>>(m,"Database");
std::vector<std::shared_ptr<Tensor>> getItem(
const std::size_t index) const override {
PYBIND11_OVERRIDE_PURE_NAME(std::vector<std::shared_ptr<Tensor>>, Database,
"get_item", getItem, index);
}
std::size_t getLen() const noexcept override {
PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "len", getLen);
}
std::size_t getNbModalities() const noexcept override {
PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "get_nb_modalities",
getNbModalities);
}
};
}
void init_Database(py::module& m) {
py::class_<Database, std::shared_ptr<Database>, pyDatabase>(
m, "Database", py::dynamic_attr())
.def(py::init<>())
.def("get_item", &Database::getItem)
.def("len", &Database::getLen)
.def("get_nb_modalities", &Database::getNbModalities);
}
} // namespace Aidge
This diff is collapsed.
......@@ -57,7 +57,10 @@ Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) {
// INNER
///////////////////////////////////////////////////////
void Aidge::Node::setName(const std::string& name) { mName = name; }
void Aidge::Node::setName(const std::string& name) {
for (auto graphView : views()) graphView->updateNodeName(mName, name);
mName = name;
}
///////////////////////////////////////////////////////
// OPERATORS
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment