diff --git a/aidge_core/aidge_export_aidge/export.py b/aidge_core/aidge_export_aidge/export.py index 51468ed846dc7a731152a1ddb3f4374847631402..f5d8c8c7ca6f3fa7c7ef19b1aef3987c20acd1f7 100644 --- a/aidge_core/aidge_export_aidge/export.py +++ b/aidge_core/aidge_export_aidge/export.py @@ -53,7 +53,7 @@ def serialize_to_cpp(export_folder: str, ### Generating an export for each nodes and dnn file ### list_configs = [] # List of headers to include in dnn.cpp to access attribute and parameters list_actions = [] # List of string to construct graph - set_operator = set() + list_operators = [] # List of operator types used (to be made unique latter) # Queue of Aidge nodes to explore, guarantee a topological exploration of the graph open_nodes = list(graph_view.get_input_nodes()) # List of Aidge nodes already explored @@ -102,12 +102,13 @@ def serialize_to_cpp(export_folder: str, # Add forward kernel list_actions += op.forward() closed_nodes.append(node) + list_operators = list(dict.fromkeys(list_operators)) # make unique # Generate full dnn.cpp aidge_core.export_utils.generate_file( export_folder_path / "src/dnn.cpp", ROOT_EXPORT / "templates/dnn.jinja", headers=list_configs, - operators=set_operator, + operators=list_operators, actions=list_actions, ) diff --git a/aidge_core/unit_tests/test_export.py b/aidge_core/unit_tests/test_export.py index b8e1f0ba9d5f72c80f25f68884b797f138dd69d0..d98a6fdbc20dd7e99169422205a4e680350aed27 100644 --- a/aidge_core/unit_tests/test_export.py +++ b/aidge_core/unit_tests/test_export.py @@ -8,8 +8,6 @@ http://www.eclipse.org/legal/epl-2.0. SPDX-License-Identifier: EPL-2.0 """ -import aidge_core -from aidge_core.utils import run_command import unittest import os import pathlib @@ -18,6 +16,10 @@ import subprocess import sys +import aidge_core +from aidge_core.utils import run_command +from aidge_core.testing.utils import tree_update_from_cache, tree_move, tree_remove + def initFiller(model): # Initialize parameters (weights and biases) for node in model.get_nodes(): @@ -45,22 +47,6 @@ def initFiller(model): pass -def clean_dir(dir: pathlib.Path) -> None: - if not dir.is_dir(): - print(f"Error : directory {dir} doesn't exist. Exiting clean_dir().") - return - for filename in os.listdir(dir): - file_path = os.path.join(dir, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") - return - - class test_export(unittest.TestCase): """Test aidge export""" @@ -68,6 +54,10 @@ class test_export(unittest.TestCase): self.EXPORT_PATH: pathlib.Path = pathlib.Path("dummy_export") self.BUILD_DIR: pathlib.Path = self.EXPORT_PATH / "build" self.INSTALL_DIR: pathlib.Path = (self.EXPORT_PATH / "install").absolute() + self.TMP_BUILD_DIR: pathlib.Path = ( + self.EXPORT_PATH.parent / + f"__tmp_{self.EXPORT_PATH.name}_build" + ) def tearDown(self): pass @@ -92,15 +82,27 @@ class test_export(unittest.TestCase): initFiller(model) model.forward_dims([[1, 32*32*3]]) + # Preserve previously generated build if present + tree_move(self.BUILD_DIR, self.TMP_BUILD_DIR, ignore_missing=True, exist_ok=True) + # Clean install dir + tree_remove(self.INSTALL_DIR, ignore_missing=True) + # Export model aidge_core.serialize_to_cpp(self.EXPORT_PATH, model) - - self.assertTrue( - self.EXPORT_PATH.is_dir(), "Export folder has not been generated" + self.assertTrue(self.EXPORT_PATH.is_dir(), "Export folder has not been generated") + # Add other source files + shutil.copyfile(pathlib.Path(__file__).parent / "static/main.cpp", self.EXPORT_PATH / "main.cpp") + + # Use cache if any, put cache inside export dir + # such that cleaning export dir also cleans the cache + tree_update_from_cache( + self.EXPORT_PATH, + cache_path=self.EXPORT_PATH / "__cache_export" ) - os.makedirs(self.BUILD_DIR, exist_ok=True) - clean_dir(self.BUILD_DIR) # if build dir existed already ensure its emptyness - clean_dir(self.INSTALL_DIR) + + # Move back preserved build dir if any and ensure build dir exists + tree_move(self.TMP_BUILD_DIR, self.BUILD_DIR, ignore_missing=True) + self.BUILD_DIR.mkdir(exist_ok=True) # Test compilation of export search_path = ( @@ -109,11 +111,6 @@ class test_export(unittest.TestCase): else os.environ["AIDGE_INSTALL"] ) - shutil.copyfile( - pathlib.Path(__file__).parent / "static/main.cpp", - self.EXPORT_PATH / "main.cpp", - ) - ########################## # CMAKE EXPORT try: diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 34aea5ffd909f57e4834deeed6d0bdc2664d4644..7af3c62c5d0af33b01e596ecf4c91c35ab3e17b7 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -77,6 +77,9 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera // we should always consume available data first. This is ensured // by setting the consumers list to the output nodes and then recursively // find the dependencies. + // The initial list may contain producer nodes, in which case + // getPriorProducersConsumers() at step 2 will have moved it in + // the requiredProducers list. std::set<std::shared_ptr<Node>> consumers = mGraphView->outputNodes(); std::set<std::shared_ptr<Node>> producers; @@ -300,19 +303,23 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>& consumer, const std::string& nodeName) const { Log::debug("\t- consumer: {}", fmt::styled(nodeName, fg(fmt::color::orange))); std::string crLog = "\t\tC/R:\t"; - for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { - crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), - consumer->getOperator()->getNbRequiredData(inId)); + if (consumer->nbInputs() > 0) { + for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { + crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), + consumer->getOperator()->getNbRequiredData(inId)); + } + crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), + consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); } - crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), - consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); Log::debug("{}", crLog); std::string pLog = "\t\tP:\t"; - for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { - pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); + if (consumer->nbOutputs() > 0) { + for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { + pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); + } + pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); } - pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); Log::debug("{}", pLog); } @@ -733,7 +740,9 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) } prior.isPrior = true; - if (prior.priorConsumers.empty()) { + if (node->type() == Producer_Op::Type) { + prior.requiredProducers.insert(node); + } else if (prior.priorConsumers.empty()) { prior.priorConsumers.insert(node); } mPriorCache.insert(std::make_pair(node, prior)); diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 3c3026ff09222f9623d886f9c4574bf23667cd9a..ec850d28109a2682bb762c89e814622de6eec3d8 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -26,6 +26,8 @@ #include "aidge/graph/Testing.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/Identity.hpp" +#include "aidge/operator/GenericOperator.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" namespace Aidge { @@ -134,4 +136,58 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); } +TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { + + SECTION("Identity Graph") { + auto data1 = Producer({1}, "data1"); + auto identity = Identity("id"); + auto g = std::make_shared<GraphView>("TestGraph"); + data1->addChild(identity); + g->add({data1, identity}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(sch.size() == nodes.size()); + REQUIRE(sch[0] == data1); + REQUIRE(sch[1] == identity); + } + + SECTION("Producer Graph") { + auto data1 = Producer({1}, "data1"); + auto g = std::make_shared<GraphView>("TestGraph"); + g->add({data1}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(sch.size() == nodes.size()); + REQUIRE(sch[0] == data1); + } + + SECTION("Generic producer Graph") { + auto gen1 = GenericOperator("Prod", 0, 0, 1, "gen1"); + auto g = std::make_shared<GraphView>("TestGraph"); + g->add({gen1}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(sch.size() == nodes.size()); + REQUIRE(sch[0] == gen1); + } + + SECTION("No output Graph") { + auto dead1 = GenericOperator("Dead", 1, 0, 0, "dead"); + auto g = std::make_shared<GraphView>("TestGraph"); + g->add({dead1}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(nodes.size() == 1); + REQUIRE(sch.size() == 0); + } +} + } // namespace Aidge