Skip to content
Snippets Groups Projects
Commit fffdfa2b authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge remote-tracking branch 'origin/dev' into allowNoInputProducer

parents 5f50b624 c57db282
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
Pipeline #57542 failed
...@@ -53,7 +53,7 @@ def serialize_to_cpp(export_folder: str, ...@@ -53,7 +53,7 @@ def serialize_to_cpp(export_folder: str,
### Generating an export for each nodes and dnn file ### ### Generating an export for each nodes and dnn file ###
list_configs = [] # List of headers to include in dnn.cpp to access attribute and parameters list_configs = [] # List of headers to include in dnn.cpp to access attribute and parameters
list_actions = [] # List of string to construct graph list_actions = [] # List of string to construct graph
set_operator = set() list_operators = [] # List of operator types used (to be made unique latter)
# Queue of Aidge nodes to explore, guarantee a topological exploration of the graph # Queue of Aidge nodes to explore, guarantee a topological exploration of the graph
open_nodes = list(graph_view.get_input_nodes()) open_nodes = list(graph_view.get_input_nodes())
# List of Aidge nodes already explored # List of Aidge nodes already explored
...@@ -102,12 +102,13 @@ def serialize_to_cpp(export_folder: str, ...@@ -102,12 +102,13 @@ def serialize_to_cpp(export_folder: str,
# Add forward kernel # Add forward kernel
list_actions += op.forward() list_actions += op.forward()
closed_nodes.append(node) closed_nodes.append(node)
list_operators = list(dict.fromkeys(list_operators)) # make unique
# Generate full dnn.cpp # Generate full dnn.cpp
aidge_core.export_utils.generate_file( aidge_core.export_utils.generate_file(
export_folder_path / "src/dnn.cpp", export_folder_path / "src/dnn.cpp",
ROOT_EXPORT / "templates/dnn.jinja", ROOT_EXPORT / "templates/dnn.jinja",
headers=list_configs, headers=list_configs,
operators=set_operator, operators=list_operators,
actions=list_actions, actions=list_actions,
) )
...@@ -8,8 +8,6 @@ http://www.eclipse.org/legal/epl-2.0. ...@@ -8,8 +8,6 @@ http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0 SPDX-License-Identifier: EPL-2.0
""" """
import aidge_core
from aidge_core.utils import run_command
import unittest import unittest
import os import os
import pathlib import pathlib
...@@ -18,6 +16,10 @@ import subprocess ...@@ -18,6 +16,10 @@ import subprocess
import sys import sys
import aidge_core
from aidge_core.utils import run_command
from aidge_core.testing.utils import tree_update_from_cache, tree_move, tree_remove
def initFiller(model): def initFiller(model):
# Initialize parameters (weights and biases) # Initialize parameters (weights and biases)
for node in model.get_nodes(): for node in model.get_nodes():
...@@ -45,22 +47,6 @@ def initFiller(model): ...@@ -45,22 +47,6 @@ def initFiller(model):
pass pass
def clean_dir(dir: pathlib.Path) -> None:
if not dir.is_dir():
print(f"Error : directory {dir} doesn't exist. Exiting clean_dir().")
return
for filename in os.listdir(dir):
file_path = os.path.join(dir, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
return
class test_export(unittest.TestCase): class test_export(unittest.TestCase):
"""Test aidge export""" """Test aidge export"""
...@@ -68,6 +54,10 @@ class test_export(unittest.TestCase): ...@@ -68,6 +54,10 @@ class test_export(unittest.TestCase):
self.EXPORT_PATH: pathlib.Path = pathlib.Path("dummy_export") self.EXPORT_PATH: pathlib.Path = pathlib.Path("dummy_export")
self.BUILD_DIR: pathlib.Path = self.EXPORT_PATH / "build" self.BUILD_DIR: pathlib.Path = self.EXPORT_PATH / "build"
self.INSTALL_DIR: pathlib.Path = (self.EXPORT_PATH / "install").absolute() self.INSTALL_DIR: pathlib.Path = (self.EXPORT_PATH / "install").absolute()
self.TMP_BUILD_DIR: pathlib.Path = (
self.EXPORT_PATH.parent /
f"__tmp_{self.EXPORT_PATH.name}_build"
)
def tearDown(self): def tearDown(self):
pass pass
...@@ -92,15 +82,27 @@ class test_export(unittest.TestCase): ...@@ -92,15 +82,27 @@ class test_export(unittest.TestCase):
initFiller(model) initFiller(model)
model.forward_dims([[1, 32*32*3]]) model.forward_dims([[1, 32*32*3]])
# Preserve previously generated build if present
tree_move(self.BUILD_DIR, self.TMP_BUILD_DIR, ignore_missing=True, exist_ok=True)
# Clean install dir
tree_remove(self.INSTALL_DIR, ignore_missing=True)
# Export model # Export model
aidge_core.serialize_to_cpp(self.EXPORT_PATH, model) aidge_core.serialize_to_cpp(self.EXPORT_PATH, model)
self.assertTrue(self.EXPORT_PATH.is_dir(), "Export folder has not been generated")
self.assertTrue( # Add other source files
self.EXPORT_PATH.is_dir(), "Export folder has not been generated" shutil.copyfile(pathlib.Path(__file__).parent / "static/main.cpp", self.EXPORT_PATH / "main.cpp")
# Use cache if any, put cache inside export dir
# such that cleaning export dir also cleans the cache
tree_update_from_cache(
self.EXPORT_PATH,
cache_path=self.EXPORT_PATH / "__cache_export"
) )
os.makedirs(self.BUILD_DIR, exist_ok=True)
clean_dir(self.BUILD_DIR) # if build dir existed already ensure its emptyness # Move back preserved build dir if any and ensure build dir exists
clean_dir(self.INSTALL_DIR) tree_move(self.TMP_BUILD_DIR, self.BUILD_DIR, ignore_missing=True)
self.BUILD_DIR.mkdir(exist_ok=True)
# Test compilation of export # Test compilation of export
search_path = ( search_path = (
...@@ -109,11 +111,6 @@ class test_export(unittest.TestCase): ...@@ -109,11 +111,6 @@ class test_export(unittest.TestCase):
else os.environ["AIDGE_INSTALL"] else os.environ["AIDGE_INSTALL"]
) )
shutil.copyfile(
pathlib.Path(__file__).parent / "static/main.cpp",
self.EXPORT_PATH / "main.cpp",
)
########################## ##########################
# CMAKE EXPORT # CMAKE EXPORT
try: try:
......
...@@ -77,6 +77,9 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera ...@@ -77,6 +77,9 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera
// we should always consume available data first. This is ensured // we should always consume available data first. This is ensured
// by setting the consumers list to the output nodes and then recursively // by setting the consumers list to the output nodes and then recursively
// find the dependencies. // find the dependencies.
// The initial list may contain producer nodes, in which case
// getPriorProducersConsumers() at step 2 will have moved it in
// the requiredProducers list.
std::set<std::shared_ptr<Node>> consumers = mGraphView->outputNodes(); std::set<std::shared_ptr<Node>> consumers = mGraphView->outputNodes();
std::set<std::shared_ptr<Node>> producers; std::set<std::shared_ptr<Node>> producers;
...@@ -300,19 +303,23 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera ...@@ -300,19 +303,23 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera
void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>& consumer, const std::string& nodeName) const { void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>& consumer, const std::string& nodeName) const {
Log::debug("\t- consumer: {}", fmt::styled(nodeName, fg(fmt::color::orange))); Log::debug("\t- consumer: {}", fmt::styled(nodeName, fg(fmt::color::orange)));
std::string crLog = "\t\tC/R:\t"; std::string crLog = "\t\tC/R:\t";
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { if (consumer->nbInputs() > 0) {
crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
consumer->getOperator()->getNbRequiredData(inId)); crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
consumer->getOperator()->getNbRequiredData(inId));
}
crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
} }
crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
Log::debug("{}", crLog); Log::debug("{}", crLog);
std::string pLog = "\t\tP:\t"; std::string pLog = "\t\tP:\t";
for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { if (consumer->nbOutputs() > 0) {
pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
}
pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
} }
pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
Log::debug("{}", pLog); Log::debug("{}", pLog);
} }
...@@ -733,7 +740,9 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) ...@@ -733,7 +740,9 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node)
} }
prior.isPrior = true; prior.isPrior = true;
if (prior.priorConsumers.empty()) { if (node->type() == Producer_Op::Type) {
prior.requiredProducers.insert(node);
} else if (prior.priorConsumers.empty()) {
prior.priorConsumers.insert(node); prior.priorConsumers.insert(node);
} }
mPriorCache.insert(std::make_pair(node, prior)); mPriorCache.insert(std::make_pair(node, prior));
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include "aidge/graph/Testing.hpp" #include "aidge/graph/Testing.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/Identity.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/SequentialScheduler.hpp"
namespace Aidge { namespace Aidge {
...@@ -134,4 +136,58 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { ...@@ -134,4 +136,58 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
} }
TEST_CASE("someScheduling", "[Scheduler][someUseCases]") {
SECTION("Identity Graph") {
auto data1 = Producer({1}, "data1");
auto identity = Identity("id");
auto g = std::make_shared<GraphView>("TestGraph");
data1->addChild(identity);
g->add({data1, identity});
auto scheduler = SequentialScheduler(g);
scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling();
const auto nodes = g->getNodes();
REQUIRE(sch.size() == nodes.size());
REQUIRE(sch[0] == data1);
REQUIRE(sch[1] == identity);
}
SECTION("Producer Graph") {
auto data1 = Producer({1}, "data1");
auto g = std::make_shared<GraphView>("TestGraph");
g->add({data1});
auto scheduler = SequentialScheduler(g);
scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling();
const auto nodes = g->getNodes();
REQUIRE(sch.size() == nodes.size());
REQUIRE(sch[0] == data1);
}
SECTION("Generic producer Graph") {
auto gen1 = GenericOperator("Prod", 0, 0, 1, "gen1");
auto g = std::make_shared<GraphView>("TestGraph");
g->add({gen1});
auto scheduler = SequentialScheduler(g);
scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling();
const auto nodes = g->getNodes();
REQUIRE(sch.size() == nodes.size());
REQUIRE(sch[0] == gen1);
}
SECTION("No output Graph") {
auto dead1 = GenericOperator("Dead", 1, 0, 0, "dead");
auto g = std::make_shared<GraphView>("TestGraph");
g->add({dead1});
auto scheduler = SequentialScheduler(g);
scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling();
const auto nodes = g->getNodes();
REQUIRE(nodes.size() == 1);
REQUIRE(sch.size() == 0);
}
}
} // namespace Aidge } // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment