Compare revisions
Commits on Source (7)
@@ -4,8 +4,9 @@ import builtins
 import aidge_core
 import numpy as np
 from pathlib import Path
+from typing import Any, Dict, List, Optional

-def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]:
+def _retrieve_operator_attrs(node : aidge_core.Node) -> Dict[str, Optional[Any]]:
     """
     Returns the dictionary containing the attributes of a given Node.
@@ -13,7 +14,7 @@ def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bo
     :type graph: aidge_core.Node
     :return: A dictionary with the Node's attributes.
-    :rtype: dict[str, int, float, bool, None]
+    :rtype: Dict[str, Optional[Any]]
     """
     if node.get_operator().attr is not None:
@@ -27,7 +28,7 @@ def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bo
     return node_attr_dict

-def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_embed : bool, write_trainable_params_ext : bool, path_trainable_params : Path, params_file_format : str) -> dict[str, int, float, bool, None]:
+def _create_dict(ordered_nodes : List[aidge_core.Node], write_trainable_params_embed : bool, write_trainable_params_ext : bool, path_trainable_params : Path, params_file_format : str) -> Dict[str, Optional[Any]]:
     """
     Creates a dictionary to store the information of a given ordered GraphView.
@@ -43,7 +44,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
     :type params_file_format: str
     :return: A dictionary with the GraphView description.
-    :rtype: dict[str, int, float, bool, None]
+    :rtype: Dict[str, Optional[Any]]
     """

     graphview_dict = {'graph': []}
@@ -79,7 +80,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
                 if parents[0] is None: parents.append(parents.pop(0))
             else:
                 pass

             parents_inputs = []
             input_idx = 0
             for parent in node.get_parents():
@@ -88,11 +89,11 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
                     for child in children:
                         if child[0] == node and child[1] == input_idx:
                             parents_inputs.append((parent.name(), input_idx))

                 elif parent is None:
                     if input_idx not in [item[1] for item in parents_inputs]:
                         parents_inputs.append((None, input_idx))

                 input_idx += 1
             node_dict['parents'] = parents_inputs
@@ -167,7 +168,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
     return graphview_dict

-def _write_dict_json(graphview_dict : dict[str, int, float, bool, None], json_path : str) -> None:
+def _write_dict_json(graphview_dict : Dict[str, Optional[Any]], json_path : str) -> None:
    """
    Writes dictionary containing GraphView description to a JSON file.
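Note on the typing changes above: `dict[str, int, float, bool, None]` was doubly broken as an annotation. Type checkers reject it because `dict` takes exactly two type parameters (key and value types), and subscripting the built-in `dict` is only supported from Python 3.9 (PEP 585), so the old annotation made the module raise at import time on Python 3.8. The `typing` generics are the portable spelling. A minimal sketch, with illustrative attribute values:

```python
from typing import Any, Dict, Optional

def example_attrs() -> Dict[str, Optional[Any]]:
    # Keys are attribute names; values may be of any type, including None.
    return {"kernel_dims": [3, 3], "no_bias": False, "padding": None}

# By contrast, on Python 3.8 any `dict[...]` subscript raises
# "TypeError: 'type' object is not subscriptable" when the annotation
# is evaluated, i.e. as soon as the module is imported.
```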
@@ -43,6 +43,7 @@ For more evolved scenarii, specialize the provided FileTreeCache class.
 from pathlib import Path
 import shutil
+import sys
 import filecmp

 from typing import Optional, Union, List
@@ -54,6 +55,21 @@ __all__ = [
     "tree_update_from_cache",
 ]

+def is_relative_to(path: Path, other: Path) -> bool:
+    """
+    Dynamically choose implementation based on Python version
+    """
+    # Python 3.9+
+    if sys.version_info >= (3, 9):
+        return path.is_relative_to(other)
+    # Python 3.8 and earlier
+    try:
+        path.relative_to(other)
+        return True
+    except ValueError:
+        return False
+

 class FileTreeCache():
     """
@@ -66,8 +82,8 @@ class FileTreeCache():
     default_tmp_prefix = "__tmp_"

     def __init__(self,
-                 src_path: Union[str|Path],
-                 cache_path: Optional[Union[str|Path]] = None
+                 src_path: Union[str, Path],
+                 cache_path: Optional[Union[str, Path]] = None
                  ) -> None:
         self.src_path = Path(src_path).absolute()
         self.cache_path = (
@@ -78,7 +94,7 @@
         )
         ctx_msg = f"tree_cache: {src_path = }, {cache_path = }"
         assert self.src_path != self.cache_path, f"src_path and cache_path must differ on {ctx_msg}"
-        assert not self.src_path.is_relative_to(self.cache_path), f"src_path must not be relative to cache_path on {ctx_msg}"
+        assert not is_relative_to(self.src_path, self.cache_path), f"src_path must not be relative to cache_path on {ctx_msg}"
         self._tmp_path = (
             self.src_path.parent /
             f"{self.default_tmp_prefix}{self.src_path.name}")
@@ -92,7 +108,7 @@
         assert not dst_cache_dir.exists()
         assert src_dir.is_dir()
         assert not cache_dir.exists() or cache_dir.is_dir()
-        assert not cache_dir.is_relative_to(src_dir)
+        assert not is_relative_to(cache_dir, src_dir)

         def copy_or_cache(src, dst):
             base_src = Path(src).relative_to(src_dir)
@@ -132,8 +148,8 @@

 def tree_update_from_cache(
-        src_path: Union[str|Path],
-        cache_path: Optional[Union[str|Path]] = None) -> None:
+        src_path: Union[str, Path],
+        cache_path: Optional[Union[str, Path]] = None) -> None:
     """
     Update from cache the current generation of a tree from the
     older generations, preserving file stamps when files contents are identical.
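The `Union[str|Path]` to `Union[str, Path]` rewrites above (and the matching ones in the tree utilities that follow) are compatibility fixes of the same kind: `str | Path` is PEP 604 syntax that gets evaluated at import time inside `Union[...]`, so `Union[str|Path]` raises `TypeError: unsupported operand type(s) for |` on Python 3.8 and 3.9. A minimal sketch, with a hypothetical signature:

```python
from pathlib import Path
from typing import Optional, Union

# Portable on Python 3.8+; `Union[str | Path]` would only import on 3.10+.
def tree_op(src_path: Union[str, Path],
            cache_path: Optional[Union[str, Path]] = None) -> None:
    ...
```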
@@ -5,8 +5,9 @@ Provide utility function for file trees manipulations.
 """
 import shutil
+import sys
 from pathlib import Path
-from typing import Union, Optional
+from typing import Union

 __all__ = [
@@ -15,8 +16,24 @@ __all__ = [
 ]

+def is_relative_to(path: Path, other: Path) -> bool:
+    """
+    Dynamically choose implementation based on Python version
+    """
+    # Python 3.9+
+    if sys.version_info >= (3, 9):
+        return path.is_relative_to(other)
+    # Python 3.8 and earlier
+    try:
+        path.relative_to(other)
+        return True
+    except ValueError:
+        return False
+

 def tree_remove(
-    path: Union[str|Path],
+    path: Union[str, Path],
     ignore_missing: bool = False,
 ) -> None:
     """
@@ -35,8 +52,8 @@

 def tree_move(
-    src_path: Union[str|Path],
-    dst_path: Union[str|Path],
+    src_path: Union[str, Path],
+    dst_path: Union[str, Path],
     ignore_missing: bool = False,
     exist_ok: bool = False,
 ) -> None:
@@ -56,8 +73,8 @@
     assert ignore_missing or src_path.exists(), f"src_path must exists when ignore_missing is False on {ctx_msg}"
     assert exist_ok or not dst_path.exists(), f"dst_path must not exists when exist_ok is False on {ctx_msg}"
     assert src_path != dst_path, f"paths must not be identical on {ctx_msg}"
-    assert not dst_path.is_relative_to(src_path), f"dst_path must not be relative to src_path on {ctx_msg}"
-    assert not src_path.is_relative_to(dst_path), f"src_path must not be relative to dst_path on {ctx_msg}"
+    assert not is_relative_to(dst_path, src_path), f"dst_path must not be relative to src_path on {ctx_msg}"
+    assert not is_relative_to(src_path, dst_path), f"src_path must not be relative to dst_path on {ctx_msg}"
     if ignore_missing and not src_path.exists():
         return
     if exist_ok and dst_path.exists():
@@ -36,7 +36,10 @@ public:
     std::shared_ptr<SequentialScheduler> mScheduler;
     std::weak_ptr<Node> mUpperNode;

-public:
+private:
+    const std::shared_ptr<DynamicAttributes> mAttributes = std::make_shared<DynamicAttributes>();
+
+public:
     MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory = {});

     /**
@@ -92,7 +95,7 @@ public:
         mGraph->setDataType(datatype);
     }

-    std::shared_ptr<Attributes> attributes() const override;
+    inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }

     Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
     Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override;
@@ -95,7 +95,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs)
     Log::debug("getBestMatch() for requirements: {}", requiredSpecs);

     const auto availableSpecsSet = getAvailableImplSpecs();
-    AIDGE_ASSERT(availableSpecsSet.size() > 0 ,
+    AIDGE_ASSERT(availableSpecsSet.size() > 0 ,
         "OperatorImpl::getBestMatch(): No available specs found by"
         "getAvailableSpecs(). "
         "Cannot find best implementation for required specs, aborting.");
@@ -139,7 +139,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs)
         if (mandatory) {
             // Required attribute:
             if (!spec.attrs.hasAttr(name)) {
-                Log::debug("Could not find mandatory attribute {} value {}.", name);
+                Log::debug("Could not find mandatory attribute '{}'.", name);
                 // Missing attribute
                 match = false;
                 break;
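The log fix above resolves a placeholder/argument mismatch: the old message contained two `{}` placeholders but only `name` was supplied, so the formatting call itself failed instead of emitting the intended debug line. The analogous bug in Python, for illustration:

```python
msg = "Could not find mandatory attribute {} value {}."
try:
    msg.format("spec_name")  # two placeholders, one argument
except IndexError:
    print("format itself fails, nothing gets logged")

# One placeholder, one argument: formats as intended.
print("Could not find mandatory attribute '{}'.".format("spec_name"))
```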
@@ -157,15 +157,15 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd
         const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator());
         if (op && !op->getOutput(outputIdx)->undefined()) {
           dims += " " + fmt::format("{}", op->getOutput(outputIdx)->dims());
-          dtype += "\n" + fmt::format("{}", op->getOutput(outputIdx)->dataType());
+          dtype += " " + fmt::format("{}", op->getOutput(outputIdx)->dataType());
         }

         if (mNodes.find(child) != mNodes.end()) {
-          fmt::print(fp.get(), "{}_{}-->|\"{}{}{}&rarr;{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr),
+          fmt::print(fp.get(), "{}_{}-->|\"{}{}{}<br/>&darr;<br/>{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr),
                      outputIdx, dims, dtype, inputIdx, child->type(), namePtrTable.at(child));
         }
         else if (verbose) {
-          fmt::print(fp.get(), "{}_{}-->|\"{}{}{}&rarr;{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr),
+          fmt::print(fp.get(), "{}_{}-->|\"{}{}{}<br/>&darr;<br/>{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr),
                      outputIdx, dims, dtype, inputIdx, static_cast<void*>(child.get()));
         }
         // Do no break here because the same child can be connected to several inputs
@@ -182,11 +182,13 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd
     for (const auto& input : mInputNodes) {
       if (input.first != nullptr) {
         const auto& op_ = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator());
-        if (op_->getInput(input.second) && (!op_->getInput(input.second)->empty())) {
-          fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"&rarr;{}{}\"|{}_{}\n", inputIdx, inputIdx,
-                     input.second, op_->getInput(input.second)->dims(), input.first->type(), namePtrTable.at(input.first));
+        if (op_->getInput(input.second) && (!op_->getInput(input.second)->undefined())) {
+          std::string dims = " " + fmt::format("{}", op_->getInput(input.second)->dims());
+          std::string dtype = " " + fmt::format("{}", op_->getInput(input.second)->dataType());
+          fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"{}{}<br/>&darr;<br/>{}\"|{}_{}\n", inputIdx, inputIdx,
+                     dims, dtype, input.second, input.first->type(), namePtrTable.at(input.first));
         } else {
-          fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"&rarr;{}\"|{}_{}\n", inputIdx, inputIdx,
+          fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"&darr;<br/>{}\"|{}_{}\n", inputIdx, inputIdx,
                      input.second, input.first->type(), namePtrTable.at(input.first));
         }
       }
@@ -201,14 +203,16 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd
       if (output.first != nullptr) {
         // Add-on to display the operator's output dimensions
         std::string dims = "";
+        std::string dtype = "";
         const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator());
         if (op && op->getOutput(output.second) && !op->getOutput(output.second)->undefined()) {
           dims += " " + fmt::format("{}", op->getOutput(output.second)->dims());
+          dtype += " " + fmt::format("{}", op->getOutput(output.second)->dataType());
         }
-        fmt::print(fp.get(), "{}_{}--->|\"{}{}&rarr;\"|output{}((out#{})):::outputCls\n",
+        fmt::print(fp.get(), "{}_{}--->|\"{}{}{}<br/>&darr;\"|output{}((out#{})):::outputCls\n",
                    output.first->type(), namePtrTable.at(output.first), output.second,
-                   dims, outputIdx, outputIdx);
+                   dims, dtype, outputIdx, outputIdx);
       }
       else {
         fmt::print(fp.get(), "output{}((out#{})):::outputCls\n", outputIdx, outputIdx);
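With the edge-label changes above, each tensor's dims and data type are stacked over a downward arrow (`<br/>&darr;<br/>`) instead of sitting on one line around `&rarr;`, and the dtype now prints beside the dims rather than after a literal `"\n"`, which Mermaid edge labels do not render as a line break. For a hypothetical node `Conv_0` whose output 0 has dims `[1, 64, 56, 56]` and dtype `Float32`, feeding input 0 of `ReLU_1`, the emitted edge would look roughly like:

```
Conv_0-->|"0 [1, 64, 56, 56] Float32<br/>&darr;<br/>0"|ReLU_1
```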
@@ -1236,7 +1240,6 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
         if (removeFromGraphs) {
             for (const auto& g : commonGraphViews) {
                 g -> remove(nodePtr, false);
-                g -> updateInputsOutputsDelete(nodePtr);
             }
             nodePtr -> resetConnections(true);
         }
@@ -48,6 +48,10 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shar
             mOutputs[outputIdx] = std::dynamic_pointer_cast<Tensor>(outputOp.first->getOperator()->getRawOutput(outputOp.second));
         }
     }
+
+    for (const auto& node : mGraph->getRankedNodesName("{1}_{3}")) {
+        mAttributes->addAttr(node.second, node.first->getOperator()->attributes());
+    }
 }

 std::shared_ptr<Aidge::Operator> Aidge::MetaOperator_Op::clone() const {
@@ -119,22 +123,6 @@ std::set<std::string> Aidge::MetaOperator_Op::getAvailableBackends() const {
     return backendsList;
 }

-std::shared_ptr<Aidge::Attributes> Aidge::MetaOperator_Op::attributes() const {
-    auto attrs = std::make_shared<DynamicAttributes>();
-
-    for (const auto& node : mGraph->getRankedNodesName("{3}")) {
-        const auto attributes = node.first->getOperator()->attributes();
-        if (attributes) {
-            const auto nodeAttrs = DynamicAttributes(attributes->getAttrs());
-            attrs->addAttr(node.first->type() + "#" + node.second, nodeAttrs);
-            if (node.second == "0") {
-                attrs->addAttr(node.first->type(), nodeAttrs);
-            }
-        }
-    }
-
-    return attrs;
-}
-
 Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
     if (mImpl) {
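Taken together with the header change earlier (the new `mAttributes` member and the inlined `attributes()` accessor), this moves the aggregation of inner-operator attributes from every `attributes()` call into a single pass in the constructor, keyed by the ranked node names returned by `getRankedNodesName("{1}_{3}")`. A minimal Python sketch of the same compute-once pattern (illustrative names, not the aidge API):

```python
class MetaOpSketch:
    def __init__(self, inner_ops):
        # inner_ops maps a unique ranked name (e.g. "Conv#0") to an operator;
        # aggregate their attributes once, at construction time.
        self._attributes = {name: op.attributes for name, op in inner_ops.items()}

    @property
    def attributes(self):
        # Cheap accessor: no per-call rebuild of the aggregate.
        return self._attributes
```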
@@ -88,7 +88,7 @@ std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawOutput(const Aidge::IO
 }

 const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
-    AIDGE_ASSERT(outputIdx < nbOutputs(), "{} Operator has {} outputs", type(), nbOutputs());
+    AIDGE_ASSERT(outputIdx < nbOutputs(), "{} Operator has {} outputs, asked for output#{}", type(), nbOutputs(), outputIdx);
     return mOutputs[outputIdx];
 }