diff --git a/aidge_core/export_utils/__init__.py b/aidge_core/export_utils/__init__.py index a97e978749d1f5480ef8ef1e7e9c5f00d9c3d7df..b17ff90d61131daff069fe201589ba578f221a2b 100644 --- a/aidge_core/export_utils/__init__.py +++ b/aidge_core/export_utils/__init__.py @@ -3,4 +3,4 @@ from .code_generation import generate_file, generate_str, copy_file from .export_registry import ExportLib from .scheduler_export import scheduler_export from .tensor_export import tensor_to_c, generate_input_file - +from .generate_main import generate_main_cpp diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py index fd24008a6de6c58c1e78f088e086817e2a769373..70c3e5fa47cb35cbf6611a5359e6a37e0f17620d 100644 --- a/aidge_core/export_utils/export_registry.py +++ b/aidge_core/export_utils/export_registry.py @@ -28,10 +28,8 @@ class ExportLib(aidge_core.OperatorImpl): # key: Path where static file is # Value: Path where to copy the file relative to the export root static_files: Dict[str, str] = {} - # Custom main generation jinja file - main_jinja_path = None # Main memory section - memory_section = None + mem_section = None # PRIVATE # Registry of exportNode, class level dictionary, shared across all ExportLib _cls_export_node_registry = {} diff --git a/aidge_core/export_utils/generate_main.py b/aidge_core/export_utils/generate_main.py new file mode 100644 index 0000000000000000000000000000000000000000..b7eee930669cc5f15f88714acf4884f5e30333b1 --- /dev/null +++ b/aidge_core/export_utils/generate_main.py @@ -0,0 +1,51 @@ +import aidge_core +from pathlib import Path +from aidge_core.export_utils import generate_file, data_conversion + +def generate_main_cpp(export_folder: str, graph_view: aidge_core.GraphView) -> None: + """ + Generate a C++ file to manage the forward pass of a model using the given graph structure. + + This function extracts details from the :py:class:`aidge_core.graph_view` object, including input and output node names, data types, + and tensor sizes. It uses this data to populate a C++ file template (`main.jinja`), creating a file (`main.cpp`) + that call the `model_forward` function, which handles data flow and processing for the exported model. + + :param export_folder: Path to the folder where the generated C++ file (`main.cpp`) will be saved. + :type export_folder: str + :param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and + ordered input/output data within the computational graph. + :type graph_view: aidge_core.graph_view + :raises RuntimeError: If there is an inconsistency in the output arguments (names, data types, sizes), + indicating an internal bug in the graph representation. + """ + outputs_name: list[str] = [] + outputs_dtype: list[str] = [] + outputs_size: list[int] = [] + inputs_name: list[str] = [] + gv_inputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_inputs() + gv_outputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_outputs() + + for in_node, in_idx in gv_inputs: + in_node_input, in_node_input_idx = in_node.input(in_idx) + inputs_name.append(f"{in_node.name()}_input_{in_idx}" if in_node_input is None else f"{in_node_input.name()}_output_{in_node_input_idx}") + for out_node, out_id in gv_outputs: + outputs_name.append(f"{out_node.name()}_output_{out_id}") + out_tensor = out_node.get_operator().get_output(out_id) + outputs_dtype.append(data_conversion.aidge2c(out_tensor.dtype())) + outputs_size.append(out_tensor.size()) + print(out_tensor.size()) + + + if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size): + raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.") + + ROOT = Path(__file__).resolve().parents[0] + generate_file( + str(Path(export_folder) / "main.cpp"), + str(ROOT / "templates" / "main.jinja"), + func_name="model_forward", + inputs_name=inputs_name, + outputs_name=outputs_name, + outputs_dtype=outputs_dtype, + outputs_size=outputs_size + ) diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index df0b4a385327e4bdccd6fe4de46043d151658dbd..f1de6f823356b0d217dc7fd98db006a8c97e450c 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -98,14 +98,6 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = outputs_size.append(op.attributes["out_size"][idx]) func_name = "model_forward" - - - args = ", ".join([f"const {dtype}* {name}" for name, - dtype in zip(inputs_name, inputs_dtype)]) - args += ", " +", ".join([f"{dtype}** {name}" for name, - dtype in zip(outputs_name, outputs_dtype)]) - forward_func = f"void {func_name}({args})" - ROOT = Path(__file__).resolve().parents[0] generate_file( str(dnn_folder / "src" / "forward.cpp"), @@ -114,7 +106,7 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = headers=set(list_configs), actions=list_actions, mem_ctype=inputs_dtype[0], # Legacy behavior ... - mem_section=export_lib.mem_section, + mem_section=export_lib.mem_section, peak_mem=peak_mem, inputs_name=inputs_name, inputs_dtype=inputs_dtype, @@ -137,22 +129,6 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size): raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.") - if export_lib is not None and export_lib.main_jinja_path is not None: - main_jinja_path = export_lib.main_jinja_path - else : - main_jinja_path = str(ROOT / "templates" / "main.jinja") - - generate_file( - str(export_folder / "main.cpp"), - main_jinja_path, - func_name=func_name, - inputs_name=inputs_name, - outputs_name=outputs_name, - outputs_dtype=outputs_dtype, - outputs_size=outputs_size, - labels=labels - ) - if export_lib is not None: # Copy all static files in the export for source, destination in export_lib.static_files.items(): diff --git a/aidge_core/show_graphview.py b/aidge_core/show_graphview.py index 633298f10dbfdafe40022f88f741f82d2d35c681..4f6a2960348c44dc7b8a0b957f777ddac5a8562a 100644 --- a/aidge_core/show_graphview.py +++ b/aidge_core/show_graphview.py @@ -79,29 +79,32 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e if parents[0] is None: parents.append(parents.pop(0)) else: pass - + parents_inputs = [] - for parent in parents: + input_idx = 0 + for parent in node.get_parents(): if parent is not None: - for output_idx in range(parent.get_operator().nb_outputs()): - for input_idx in range(node.get_operator().nb_inputs()): - if parent.get_operator().get_output(output_idx).dims() == node.get_operator().get_input(input_idx).dims(): + for children in parent.outputs(): + for child in children: + if child[0] == node and child[1] == input_idx: parents_inputs.append((parent.name(), input_idx)) - + elif parent is None: - for input_idx in list(range(node.get_operator().nb_inputs())): - if input_idx not in [item[1] for item in parents_inputs]: - parents_inputs.append((None, input_idx)) - - parents_inputs.sort(key=lambda x: x[1]) + if input_idx not in [item[1] for item in parents_inputs]: + parents_inputs.append((None, input_idx)) + + input_idx += 1 node_dict['parents'] = parents_inputs children_outputs = [] - for child in node.get_children(): - for input_idx in range(child.get_operator().nb_inputs()): - for output_idx in range(node.get_operator().nb_outputs()): - if child.get_operator().get_input(input_idx).dims() == node.get_operator().get_output(output_idx).dims(): - children_outputs.append((child.name(), output_idx)) + output_idx = 0 + for children in node.get_ordered_children(): + for child in children: + if child is not None: + for parent in child.inputs(): + if parent[0] == node and parent[1] == output_idx: + children_outputs.append((child.name(), output_idx)) + output_idx += 1 node_dict['children'] = children_outputs # Check if my node is a metaop @@ -129,7 +132,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e if params_file_format=='npz': np.savez_compressed(Path(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)}) - node_dict['tensor_data'] = Path(path_trainable_params, node.name() + '.npz') + node_dict['tensor_data'] = str(Path(path_trainable_params, node.name() + '.npz')) elif params_file_format=='json': tensor = np.array(node.get_operator().get_output(0)) @@ -145,13 +148,13 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e with open(Path(path_trainable_params, node.name() + '.json'), 'w') as fp: json.dump(tensor_dict, fp, indent=4) - node_dict['tensor_data'] = Path(path_trainable_params, node.name() + '.json') + node_dict['tensor_data'] = str(Path(path_trainable_params, node.name() + '.json')) else: raise Exception("File format to write trainable parameters not recognized.") - elif write_trainable_params_embed: + if write_trainable_params_embed: node_dict['tensor_data'] = np.array(node.get_operator().get_output(0)).tolist() else: @@ -195,17 +198,21 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : Path, write_trainabl :type params_file_format: str, optional """ - if json_path.is_dir(): - json_path = (json_path.parent).joinpath('model.json') + if not json_path.suffix: + if not json_path.is_dir(): + json_path.mkdir(parents=True, exist_ok=True) + json_path = json_path.joinpath('model.json') - elif not json_path.is_dir(): - if json_path.suffix == '.json': - pass - else: - raise Exception('If ``json_path`` contains a filename it must be of JSON format.') + else: + if json_path.suffix != '.json': + raise Exception('If ``json_path`` contains a filename, it must be of JSON format.') + if not json_path.parent.is_dir(): + json_path.parent.mkdir(parents=True, exist_ok=True) if write_trainable_params_ext: path_trainable_params = (json_path.parent).joinpath(json_path.stem + '_trainable_params/') + path_trainable_params.mkdir(parents=True, exist_ok=True) + else: path_trainable_params = Path() diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index fd2a0b3f42d5888a68edb18caf046cea71dec0f3..9390fe5860b5d3523886856d9b2a40752d338af5 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -50,10 +50,10 @@ public: void zeros() override final; void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override final { + AIDGE_ASSERT(offset + length <= mNbElts, "TensorImpl_cpu<{}>::copy(): copy offset ({}) + length ({}) is above capacity ({})", typeid(T).name(), offset, length, mNbElts); const T* srcT = static_cast<const T *>(src); T* dstT = static_cast<T *>(rawPtr(offset)); - AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "TensorImpl_cpu<{}>::copy(): copy length ({}) is above capacity ({})", typeid(T).name(), length, mNbElts); AIDGE_ASSERT(dstT < srcT || dstT >= srcT + length, "TensorImpl_cpu<{}>::copy(): overlapping copy is not supported", typeid(T).name()); std::copy(srcT, srcT + length, dstT); } @@ -72,7 +72,7 @@ public: void copyToHost(void *dst, NbElts_t length, NbElts_t offset = 0) const override final { const T* src = static_cast<const T*>(rawPtr(offset)); - AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "TensorImpl_cpu<{}>::copyToHost(): copy length ({}) is above capacity ({})", typeid(T).name(), length, mNbElts); + AIDGE_ASSERT(offset + length <= mData.size(), "TensorImpl_cpu<{}>::copy(): copy offset ({}) + length ({}) is above capacity ({})", typeid(T).name(), offset, length, mData.size()); std::copy(src, src + length, static_cast<T *>(dst)); } diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index a9b9213e914811ccff7d1e6d8efe4fdd8a505b87..82ecc7d28b723d2b3e268f4fb6fbf20d595240ff 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -134,6 +134,23 @@ void explicitTranspose(std::shared_ptr<GraphView> graphView); */ void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false); +/** + * @brief Tile any :cpp:function:`Aidge::MatMul` operator to several fixed size matrix multiplications. + * For instance, for a MatMul of size 80x80 and a tiling of 16x16, this will tile + * the MatMul operator to 25 (5 by 5) MatMul operators of size 16x16, with Slice + * operators inserted at the inputs and Concat operators inserted at the outputs. + * + * This is especially useful when matrix multiplication must be mapped to fixed + * maximum size hardware TPU (Tensor Processing Unit) or MMA (Matrix Multiplication + * Accelerator). This recipe can be combined with the :cpp:function:`Aidge::convToMatMul` recipe in + * order to convert convolutions to matrix multiplication beforehand, and + * :cpp:function:`Aidge::constantFolding` recipe to fold sliced constant tensors. + * + * @param matMul MatMul operator to be tiled. + * @param maxDims Maximum output dimensions of the tiled MatMul operators. + */ +void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims); + /** * Fuse each sub-graph matching a query in a Meta Operator. * @param graph Graph to manipulate diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index 35f6327444a874d8f5c2e94da6520244e095263a..69a28960b57e6ba2ac8a699bf45ff09961fa4135 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -176,6 +176,11 @@ void init_Node(py::module& m) { Get children. )mydelimiter") + .def("get_ordered_children", &Node::getOrderedChildren, + R"mydelimiter( + Get ordered children. + )mydelimiter") + .def("__call__", [](Node &self, pybind11::args args) { std::vector<Connector> connectors; diff --git a/src/backend/cpu/data/TensorImpl.cpp b/src/backend/cpu/data/TensorImpl.cpp index ed3c96f80c1b8bafd70425451d6618428d1888f0..506287a0c520915e6426f1f0b64d9c562c754d33 100644 --- a/src/backend/cpu/data/TensorImpl.cpp +++ b/src/backend/cpu/data/TensorImpl.cpp @@ -47,8 +47,8 @@ void Aidge::TensorImpl_cpu<T>::copyCast(const void *src, const Aidge::DataType s return; } + AIDGE_ASSERT(offset + length <= mNbElts, "TensorImpl_cpu<{}>::copyCast(): copy offset ({}) + length ({}) is above capacity ({})", typeid(T).name(), offset, length, mNbElts); T* dstT = static_cast<T *>(rawPtr(offset)); - AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "TensorImpl_cpu<{}>::copyCast(): copy length ({}) is above capacity ({})", typeid(T).name(), length, mNbElts); switch (srcDt) { case DataType::Float64: diff --git a/src/operator/Concat.cpp b/src/operator/Concat.cpp index 55efdd51d56f7db4f64880b967def661e5354af5..27b9d1cf151c1d12aa4395a3b24673a2f2a4ad3c 100644 --- a/src/operator/Concat.cpp +++ b/src/operator/Concat.cpp @@ -49,7 +49,9 @@ std::shared_ptr<Aidge::Operator> Aidge::Concat_Op::clone() const { void Aidge::Concat_OpImpl::forward() { const Concat_Op& op = dynamic_cast<const Concat_Op&>(mOp); - const DimSize_t axis = op.axis(); + auto axis = op.axis(); + const auto nbDimsInput0 = op.getInput(0)->nbDims(); + axis = (axis < 0) ? axis + static_cast<std::int32_t>(nbDimsInput0) : axis; assert(op.getInput(0) && "missing input in Concat operator"); for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) { diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp index 668ffd04b7acb0e72b4a3313805fa89ca3466f32..8fd2aa068c91dfebd6d1a3a47900c3aa9b0f9585 100644 --- a/src/operator/MatMul.cpp +++ b/src/operator/MatMul.cpp @@ -71,7 +71,7 @@ bool Aidge::MatMul_Op::forwardDims(bool /*allowDataDependency*/) { std::vector<std::size_t> outDims = std::vector<std::size_t>(dims_size-2, 1); for (std::size_t i = 0; i < dims_size-2; ++i) { - AIDGE_ASSERT((dims0[i] == dims1[i]) || (dims0[i] == 1) || (dims1[i] == 1), "Bad vector dimension."); + AIDGE_ASSERT((dims0[i] == dims1[i]) || (dims0[i] == 1) || (dims1[i] == 1), "Bad dimension {}: {} != {} for input #0 {} and #1 {}.", i, dims0[i], dims1[i], dims0, dims1); outDims[i] = std::max(dims0[i], dims1[i]); } diff --git a/src/recipes/MatMulTiling.cpp b/src/recipes/MatMulTiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cfc0b191a2f8b47aec92d1dec5ca8a44c95db5db --- /dev/null +++ b/src/recipes/MatMulTiling.cpp @@ -0,0 +1,131 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#include <cassert> +#include <memory> +#include <set> +#include <string> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Slice.hpp" +#include "aidge/operator/Identity.hpp" +#include "aidge/operator/Concat.hpp" +#include "aidge/recipes/Recipes.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" + +// see https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm +void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) { + if (matMul->getOperator()->type() != "MatMul") { + AIDGE_INTERNAL_ASSERT("Operator should be a MatMul."); + } + AIDGE_ASSERT(matMul->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); + const auto& op = std::static_pointer_cast<OperatorTensor>(matMul->getOperator()); + if (!op->dimsForwarded()) { + AIDGE_INTERNAL_ASSERT("Dimensions must be forwarded before any tiling"); + } + + const auto& in0Tensor = op->getInput(0); + const auto& in1Tensor = op->getInput(1); + const auto& outTensor = op->getOutput(0); + const auto& input0Dims = in0Tensor->dims(); + const auto& input1Dims = in1Tensor->dims(); + const auto& outputDims = outTensor->dims(); + const auto& outputMatDims = std::vector<std::size_t>(outputDims.end() - 2, outputDims.end());; + + if (outputMatDims[0] > maxDims[0] || outputMatDims[1] > maxDims[1]) { + const auto sliceDims = (outputMatDims[0] > maxDims[0]) ? input0Dims : input1Dims; + std::int32_t axis; + std::int64_t splitIndex0_end = static_cast<std::int64_t>(sliceDims.end()[-2]); + std::int64_t splitIndex0_start = 0; + std::int64_t splitIndex1_end = static_cast<std::int64_t>(sliceDims.end()[-1]); + std::int64_t splitIndex1_start = 0; + + if (outputMatDims[0] > maxDims[0]) { + splitIndex0_end = maxDims[0]; + splitIndex0_start = maxDims[0]; + axis = -2; + } + else { + splitIndex1_end = maxDims[1]; + splitIndex1_start = maxDims[1]; + axis = -1; + } + + auto identity0 = Identity(); + auto sliceX0 = Slice(); + auto sliceX0_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{0, 0}}), "", true); + sliceX0_starts->addChild(sliceX0, 0, 1); + auto sliceX0_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex0_end, splitIndex1_end}}), "", true); + sliceX0_ends->addChild(sliceX0, 0, 2); + auto sliceX0_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true); + sliceX0_axes->addChild(sliceX0, 0, 3); + auto sliceX0_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true); + sliceX0_steps->addChild(sliceX0, 0, 4); + auto matMulX0 = MatMul(); + auto identity1 = Identity(); + auto sliceX1 = Slice(); + auto sliceX1_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex0_start, splitIndex1_start}}), "", true); + sliceX1_starts->addChild(sliceX1, 0, 1); + auto sliceX1_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{static_cast<std::int64_t>(sliceDims.end()[-2]), static_cast<std::int64_t>(sliceDims.end()[-1])}}), "", true); + sliceX1_ends->addChild(sliceX1, 0, 2); + auto sliceX1_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true); + sliceX1_axes->addChild(sliceX1, 0, 3); + auto sliceX1_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true); + sliceX1_steps->addChild(sliceX1, 0, 4); + auto matMulX1 = MatMul(); + auto concat = Concat(2, axis); + + if (outputMatDims[0] > maxDims[0]) { + identity0->addChild(sliceX0, 0, 0); + identity0->addChild(sliceX1, 0, 0); + identity1->addChild(matMulX0, 0, 1); + identity1->addChild(matMulX1, 0, 1); + sliceX0->addChild(matMulX0, 0, 0); + sliceX1->addChild(matMulX1, 0, 0); + } + else { + identity0->addChild(matMulX0, 0, 0); + identity0->addChild(matMulX1, 0, 0); + identity1->addChild(sliceX0, 0, 0); + identity1->addChild(sliceX1, 0, 0); + sliceX0->addChild(matMulX0, 0, 1); + sliceX1->addChild(matMulX1, 0, 1); + } + + matMulX0->addChild(concat, 0, 0); + matMulX1->addChild(concat, 0, 1); + + auto gMatMul = std::make_shared<GraphView>(); + gMatMul->add({matMul}); + + auto g = std::make_shared<GraphView>(); + g->add({identity0}); + g->add({identity1}); + g->add({sliceX0, sliceX0_starts, sliceX0_ends, sliceX0_axes, sliceX0_steps, matMulX0, matMulX1, sliceX1, sliceX1_starts, sliceX1_ends, sliceX1_axes, sliceX1_steps, concat}); + + auto replaced = GraphView::replace(gMatMul, g); + + if (replaced) { + g->forwardDims({}, true); + + // Recursive tiling + matMulTiling(matMulX1, maxDims); + matMulTiling(matMulX0, maxDims); + } + else { + Log::warn("Unable to split MatMul {}", matMul->name()); + } + } +} diff --git a/src/recipes/RemoveNode.cpp b/src/recipes/RemoveNode.cpp index a09c67991409dfe491d46b4ad739f9ddf5b72aef..3a1bac588ee8a1bb38f74fee441c9eff07b4ef6e 100644 --- a/src/recipes/RemoveNode.cpp +++ b/src/recipes/RemoveNode.cpp @@ -13,24 +13,15 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Matching.hpp" #include "aidge/recipes/Recipes.hpp" - -//Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" - size_t Aidge::removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers) { - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setNodeKey(type, "getType($) =='" + type + "'"); - regex->addQuery(type + "#"); - - const auto matches = regex->match(graphView); - for (const auto& solution : matches) { - assert(solution->at(type).size() == 1 && "Wrong number of nodes to replace\n"); - - std::set<NodePtr> nodesToRemove = solution->at(type); + auto matches = SinglePassGraphMatching(graphView).match(type); + for (const auto& match : matches) { + std::set<NodePtr> nodesToRemove = {match.graph->rootNode()}; if (incProducers) { - for (const auto& nodePtr: (*solution->at(type).begin())->getParents()) { + for (const auto& nodePtr: match.graph->rootNode()->getParents()) { if (nodePtr != nullptr && nodePtr->type() == "Producer") { nodesToRemove.insert(nodePtr); } diff --git a/unit_tests/backend/Test_TensorImpl.cpp b/unit_tests/backend/Test_TensorImpl.cpp index 43e25092a0f502698bbff7b0142969154f2cb0b0..ceb6772d01d4ee84524896fead96abcb445f84ff 100644 --- a/unit_tests/backend/Test_TensorImpl.cpp +++ b/unit_tests/backend/Test_TensorImpl.cpp @@ -47,6 +47,7 @@ TEST_CASE("Tensor fill", "[TensorImpl][fill]") { concatenatedTensor->getImpl()->copy(myTensor1->getImpl()->rawPtr(), 5, 0); concatenatedTensor->getImpl()->copy(myTensor2->getImpl()->rawPtr(), 5, 5); concatenatedTensor->getImpl()->copy(myTensor3->getImpl()->rawPtr(), 5, 10); + REQUIRE_THROWS(concatenatedTensor->getImpl()->copy(myTensor3->getImpl()->rawPtr(), 5, 11)); // concatenatedTensor->print(); std::shared_ptr<Tensor> expectedTensor= std::make_shared<Tensor>(Array2D<int, 3, 5>{