Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • eclipse/aidge/aidge_core
  • hrouis/aidge_core
  • mszczep/aidge_core
  • oantoni/aidge_core
  • cguillon/aidge_core
  • jeromeh/aidge_core
  • axelfarr/aidge_core
  • cmoineau/aidge_core
  • noamzerah/aidge_core
  • lrakotoarivony/aidge_core
  • silvanosky/aidge_core
  • maab05/aidge_core
  • mick94/aidge_core
  • lucaslopez/aidge_core_ll
  • wboussella/aidge_core
  • farnez/aidge_core
  • mnewson/aidge_core
17 results
Show changes
Commits on Source (12)
# Version 0.4.0 (February 2025)
# Version 0.4.0 (December 2024)
# Version 0.2.1 (May 14, 2024)
......
......@@ -28,7 +28,7 @@ def create_gview():
value = prod_op.get_output(0)
value.set_backend("cpu")
tuple_out = node.output(0)[0]
if (tuple_out[0].type() == "Conv" or tuple_out[0].type() == "PaddedConv") and tuple_out[1]==1:
# Conv weight
aidge_core.xavier_uniform_filler(value)
......@@ -45,13 +45,13 @@ def create_gview():
pass
# Compile model
gview.forward_dims([[1, 1, 28, 28]])
gview.forward_dims([[1, 1, 28, 28]])
gview.set_datatype(aidge_core.dtype.float32)
return gview
class test_show_gview(unittest.TestCase):
"""Test aidge functionality to show GraphView.
"""Test aidge functionality to show GraphView.
"""
def setUp(self):
......@@ -61,12 +61,14 @@ class test_show_gview(unittest.TestCase):
pass
def test_gview_to_json(self):
gview = create_gview()
# Create temporary file to store JSON model description
model_description_file = tempfile.NamedTemporaryFile(mode="w+", suffix='.json')
# Create temporary file to store JSON model description
model_description_file = tempfile.NamedTemporaryFile(suffix='.json', delete=False)
model_description_file.close() # Ensure the file is closed
# Pass the file path to gview_to_json
gview_to_json(gview, Path(model_description_file.name))
# Load JSON
......@@ -76,15 +78,15 @@ class test_show_gview(unittest.TestCase):
# Get list of nodes of Aidge graphview
gview_ordered_nodes = gview.get_ordered_nodes()
# Iterate over the list of ordered nodes and the corresponding JSON
# Iterate over the list of ordered nodes and the corresponding JSON
self.assertEqual(len(gview_ordered_nodes), len(model_json['graph']))
for node_gview, node_json in zip(gview_ordered_nodes, model_json['graph']):
for node_gview, node_json in zip(gview_ordered_nodes, model_json['graph']):
self.assertEqual(node_gview.get_operator().type(), node_json['optype'])
self.assertEqual(node_gview.get_operator().nb_inputs(), node_json['nb_inputs'])
self.assertEqual(node_gview.get_operator().nb_outputs(), node_json['nb_outputs'])
self.assertEqual(node_gview.get_operator().nb_inputs(), len(node_json['inputs']))
for input_idx in range(node_gview.get_operator().nb_inputs()):
self.assertEqual(node_gview.get_operator().get_input(input_idx).dims(), node_json['inputs'][input_idx]['dims'])
......@@ -97,7 +99,7 @@ class test_show_gview(unittest.TestCase):
self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dtype()), node_json['outputs'][output_idx]['data_type'])
self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dformat()), node_json['outputs'][output_idx]['data_format'])
self.assertEqual(len(node_gview.get_parents()), len(node_json['parents']))
self.assertEqual(len(node_gview.get_parents()), len(node_json['parents']))
self.assertEqual(len(node_gview.get_children()), len(node_json['children']))
if not hasattr(node_gview.get_operator(), 'get_micro_graph'):
......@@ -109,21 +111,21 @@ class test_show_gview(unittest.TestCase):
self.assertIsNone(node_gview.get_operator().attr) and self.assertFalse(node_json['attributes'])
elif hasattr(node_gview.get_operator(), 'get_micro_graph'):
self.assertEqual(len(node_gview.get_operator().get_micro_graph().get_nodes()), len(node_json['attributes']['micro_graph']))
for micro_node_gview in node_gview.get_operator().get_micro_graph().get_nodes():
for micro_node_json in node_json['attributes']['micro_graph']:
if micro_node_gview.get_operator().type() == micro_node_json['optype']:
for key, value in micro_node_gview.get_operator().attr.dict().items():
if not type(value).__name__ in dir(builtins):
# Replace original value by its name (str) because value is of a type that could not be written to the JSON
# Cannot update this dict inplace : micro_node_gview.get_operator().attr.dict().update({key : value.name})
# Cannot update this dict inplace : micro_node_gview.get_operator().attr.dict().update({key : value.name})
temp_mnode_dict = micro_node_gview.get_operator().attr.dict()
temp_mnode_dict.update({key : value.name})
self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes'])
self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes'])
if __name__ == '__main__':
unittest.main()
......@@ -34,6 +34,10 @@ public:
return std::make_unique<ProdConso>(op, true);
}
const Operator& getOperator() const noexcept {
return mOp;
}
/**
* @brief Minimum amount of data from a specific input required by the
* implementation to be run.
......
......@@ -187,6 +187,7 @@ public:
* @param fileName Name of the file to save the diagram (without extension).
*/
void saveStaticSchedulingDiagram(const std::string& fileName) const;
void saveFactorizedStaticSchedulingDiagram(const std::string& fileName, size_t minRepeat = 2) const;
/**
* @brief Save in a Mermaid file the order of layers execution.
......@@ -233,6 +234,18 @@ protected:
*/
void generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const;
/**
* @brief Get the factorized scheduling, by identifying repetitive sequences
* in the scheduling.
*
* @param schedule Vector of shared pointers to StaticSchedulingElements to be processed
* @param size_t Minimum number repetitions to factorize the sequence
* @return Vector containing the repetitive sequences, in order. The second
* element of the pair is the number of repetitions.
*/
std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>>
getFactorizedScheduling(const std::vector<StaticSchedulingElement*>& schedule, size_t minRepeat = 2) const;
private:
/**
* @brief Summarize the consumer state of a node for debugging purposes.
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm> // std::transform
#include <cctype> // std::tolower
#include <string> // std::string
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/operators.h>
#include "aidge/data/Elts.hpp"
namespace py = pybind11;
namespace Aidge {
template <class T>
void bindEnum(py::module& m, const std::string& name) {
// Define enumeration names for python as lowercase type name
// This defined enum names compatible with basic numpy type
// name such as: float32, flot64, [u]int32, [u]int64, ...
auto python_enum_name = [](const T& type) {
auto str_lower = [](std::string& str) {
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c){
return std::tolower(c);
});
};
auto type_name = std::string(Aidge::format_as(type));
str_lower(type_name);
return type_name;
};
// Auto generate enumeration names from lowercase type strings
std::vector<std::string> enum_names;
for (auto type_str : EnumStrings<T>::data) {
auto type = static_cast<T>(enum_names.size());
auto enum_name = python_enum_name(type);
enum_names.push_back(enum_name);
}
// Define python side enumeration aidge_core.type
auto e_type = py::enum_<T>(m, name.c_str());
// Add enum value for each enum name
for (std::size_t idx = 0; idx < enum_names.size(); idx++) {
e_type.value(enum_names[idx].c_str(), static_cast<T>(idx));
}
// Define str() to return the bare enum name value, it allows
// to compare directly for instance str(tensor.type())
// with str(nparray.type)
e_type.def("__str__", [enum_names](const T& type) {
return enum_names[static_cast<int>(type)];
}, py::prepend());
}
void init_Elts(py::module& m) {
bindEnum<Elts_t::EltType>(m, "EltType");
m.def("format_as", (const char* (*)(Elts_t::EltType)) &format_as, py::arg("elt"));
py::class_<Elts_t, std::shared_ptr<Elts_t>>(
m, "Elts_t", py::dynamic_attr())
.def_static("none_elts", &Elts_t::NoneElts)
.def_static("data_elts", &Elts_t::DataElts, py::arg("data"), py::arg("token") = 1)
.def_static("token_elts", &Elts_t::TokenElts, py::arg("token"))
.def_readwrite("data", &Elts_t::data)
.def_readwrite("token", &Elts_t::token)
.def_readwrite("type", &Elts_t::type)
.def(py::self + py::self)
.def(py::self += py::self)
.def(py::self < py::self)
.def(py::self > py::self);
}
} // namespace Aidge
......@@ -21,6 +21,7 @@ void init_Random(py::module&);
void init_Data(py::module&);
void init_DataFormat(py::module&);
void init_DataType(py::module&);
void init_Elts(py::module&);
void init_Database(py::module&);
void init_DataProvider(py::module&);
void init_Interpolation(py::module&);
......@@ -112,6 +113,7 @@ void init_Aidge(py::module& m) {
init_Data(m);
init_DataFormat(m);
init_DataType(m);
init_Elts(m);
init_Database(m);
init_DataProvider(m);
init_Interpolation(m);
......
......@@ -104,6 +104,7 @@ void init_ProdConso(py::module& m){
.def(py::init<const Operator&, bool>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>())
.def_static("default_model", &ProdConso::defaultModel)
.def_static("in_place_model", &ProdConso::inPlaceModel)
.def("get_operator", &ProdConso::getOperator)
.def("get_nb_required_data", &ProdConso::getNbRequiredData)
.def("get_nb_required_protected", &ProdConso::getNbRequiredProtected)
.def("get_required_memory", &ProdConso::getRequiredMemory)
......
......@@ -32,6 +32,7 @@ void init_Scheduler(py::module& m){
.def("graph_view", &Scheduler::graphView)
.def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name"))
.def("save_factorized_static_scheduling_diagram", &Scheduler::saveFactorizedStaticSchedulingDiagram, py::arg("file_name"), py::arg("min_repeat") = 2)
.def("resetScheduling", &Scheduler::resetScheduling)
.def("generate_scheduling", &Scheduler::generateScheduling)
.def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::EarlyLateSort::Default)
......
......@@ -36,6 +36,7 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
for (const auto& node : candidates) {
bool foldable = true;
auto replaceGraph = std::make_shared<GraphView>();
size_t i = 0;
for (const auto& input : node->inputs()) {
if (input.first) {
if (input.first->type() != Producer_Op::Type) {
......@@ -53,6 +54,13 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
replaceGraph->add(input.first, false);
}
else if (node->inputCategory(i) != InputCategory::OptionalData
&& node->inputCategory(i) != InputCategory::OptionalParam)
{
foldable = false;
break;
}
++i;
}
if (foldable) {
......
......@@ -427,6 +427,83 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE
}
}
std::vector<std::pair<std::vector<Aidge::Scheduler::StaticSchedulingElement*>, size_t>>
Aidge::Scheduler::getFactorizedScheduling(const std::vector<StaticSchedulingElement*>& schedule, size_t minRepeat) const
{
std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>> sequences;
size_t offset = 0;
for (size_t i = 0; i < schedule.size(); ) {
// Find all the possible repetitive sequences starting from this element
std::vector<StaticSchedulingElement*> seq;
std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>> longuestSeq;
std::vector<size_t> longuestSeqOffset;
for (size_t k = i; k < schedule.size(); ++k) {
// For each sequence length, starting from 1...
seq.push_back(new StaticSchedulingElement(
schedule[k]->node,
schedule[k]->early - offset,
schedule[k]->late - offset));
size_t start = k + 1;
size_t nbRepeats = 1;
bool repeat = true;
const auto seqOffset = (start < schedule.size()) ? schedule[start]->early - offset - seq[0]->early : 0;
do {
// Count the number of consecutive sequences (repetitions)
for (size_t r = 0; r < seq.size(); ++r) {
if (start + r >= schedule.size()
|| schedule[start + r]->node != seq[r]->node
|| schedule[start + r]->early - offset != seq[r]->early + seqOffset * nbRepeats
|| schedule[start + r]->late - offset != seq[r]->late + seqOffset * nbRepeats)
{
repeat = false;
break;
}
}
if (repeat) {
start += seq.size();
++nbRepeats;
}
}
while (repeat);
if (nbRepeats >= minRepeat) {
// If repetitions exist for this sequence length, add it to the list
longuestSeq.push_back(std::make_pair(seq, nbRepeats));
longuestSeqOffset.push_back(seqOffset);
}
else if (k == i) {
// Ensure that at least the current element is in the list if no
// repetition is found
longuestSeq.push_back(std::make_pair(seq, 1));
longuestSeqOffset.push_back(0);
}
}
// Select the one with the best factorization
// i.e. which maximize the product sequence length * number of sequences
size_t maxS = 0;
size_t maxFactorization = 0;
for (size_t s = 0; s < longuestSeq.size(); ++s) {
const auto factor = longuestSeq[s].first.size() * longuestSeq[s].second;
if (factor > maxFactorization) {
maxFactorization = factor;
maxS = s;
}
}
sequences.push_back(longuestSeq[maxS]);
i += longuestSeq[maxS].first.size() * longuestSeq[maxS].second;
offset += longuestSeqOffset[maxS] * (longuestSeq[maxS].second - 1);
}
return sequences;
}
void Aidge::Scheduler::resetScheduling() {
for (auto node : mGraphView->getNodes()) {
node->getOperator()->resetConsummerProducer();
......@@ -869,8 +946,65 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName)
// Mermaid does not allow : character in task title
std::replace(name.begin(), name.end(), ':', '_');
fmt::print(fp.get(), "{} :{}, {}\n",
name, element->early, element->late);
if (element->early == element->late) {
fmt::print(fp.get(), "{} :milestone, {}, {}\n",
name, element->early, element->late);
}
else {
fmt::print(fp.get(), "{} :{}, {}\n",
name, element->early, element->late);
}
}
}
}
fmt::print(fp.get(), "\n");
}
void Aidge::Scheduler::saveFactorizedStaticSchedulingDiagram(const std::string& fileName, size_t minRepeat) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);
if (!fp) {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Could not create scheduling diagram log file: {}", fileName + ".mmd");
}
fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q\n\n");
if (!mStaticSchedule.empty()) {
const std::map<std::shared_ptr<Node>, std::string> namePtrTable
= mGraphView->getRankedNodesName("{0} ({1}#{3})");
for (const auto& schedule : mStaticSchedule) {
const auto factorizedSchedule = getFactorizedScheduling(schedule, minRepeat);
size_t seq = 0;
for (const auto& sequence : factorizedSchedule) {
if (sequence.second > 1) {
fmt::print(fp.get(), "section seq#{} (x{})\n", seq, sequence.second);
}
else {
fmt::print(fp.get(), "section seq#{}\n", seq);
}
for (const auto& element : sequence.first) {
auto name = namePtrTable.at(element->node);
// Mermaid does not allow : character in task title
std::replace(name.begin(), name.end(), ':', '_');
std::string tag = ":";
if (element->early == element->late) {
tag += "milestone, ";
}
if (sequence.second > 1) {
tag += "active, ";
}
fmt::print(fp.get(), "{} {}{}, {}\n",
name, tag, element->early, element->late);
}
++seq;
}
}
}
......
find_package(Catch2 QUIET)
# Catch2 configuration
set(CATCH2_MIN_VERSION 3.3.0)
# Try to find system installed Catch2
find_package(Catch2 ${CATCH2_MIN_VERSION} QUIET)
if(NOT Catch2_FOUND)
message(STATUS "Catch2 not found in system, retrieving from git")
......@@ -9,62 +13,93 @@ if(NOT Catch2_FOUND)
GIT_REPOSITORY https://github.com/catchorg/Catch2.git
GIT_TAG devel # or a later release
)
FetchContent_MakeAvailable(Catch2)
message(STATUS "Fetched Catch2 version ${Catch2_VERSION}")
else()
message(STATUS "Found system Catch2 version ${Catch2_VERSION}")
message(STATUS "Using system Catch2 version ${Catch2_VERSION}")
endif()
# Get all source files
file(GLOB_RECURSE src_files "*.cpp")
# Create test executable
add_executable(tests${module_name} ${src_files})
# Set C++14 standard
target_compile_features(tests${module_name} PRIVATE cxx_std_14)
set(FORCE_CI TRUE)
if (NOT(FORCE_CI))
if (DOSANITIZE STREQUAL "ON")
set(SANITIZE_FLAGS -fsanitize=address,leak,undefined,float-divide-by-zero -fno-omit-frame-pointer)
#TODO sanitizer seems buggy in some situations with msvc, leading to linker errors, temporarily inactivating it
#set(SANITIZE_MSVC_FLAGS)
set(SANITIZE_MSVC_FLAGS /fsanitize=address)
target_compile_definitions(tests${module_name} PUBLIC _DISABLE_VECTOR_ANNOTATION)
else()
set(SANITIZE_FLAGS)
set(SANITIZE_MSVC_FLAGS)
endif()
# Compiler flags and options
if(NOT(FORCE_CI))
# Sanitization configuration
if(DOSANITIZE STREQUAL "ON")
set(SANITIZE_FLAGS
-fsanitize=address,leak,undefined,float-divide-by-zero
-fno-omit-frame-pointer
)
set(SANITIZE_MSVC_FLAGS /fsanitize=address)
set(STRICT_ALIASING_FLAGS -fstrict-aliasing -Wstrict-aliasing=2)
# -fvisibility=hidden required by pybind11
target_compile_options(tests${module_name} PUBLIC
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-fvisibility=hidden>)
target_compile_options(tests${module_name} PRIVATE
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-Wall -Wextra -Wold-style-cast -pedantic -Werror=narrowing -Wshadow $<$<BOOL:${WERROR}>:-Werror> ${SANITIZE_FLAGS}>)
target_compile_options(tests${module_name} PRIVATE
$<$<CXX_COMPILER_ID:GNU>:${STRICT_ALIASING_FLAGS}>)
target_compile_options(${module_name} PRIVATE
$<$<CXX_COMPILER_ID:MSVC>:
/W4 /DWIN32 /D_WINDOWS /GR /EHsc /MP /Zc:__cplusplus /Zc:preprocessor /permissive- ${SANITIZE_MSVC_FLAGS}>)
if (DOSANITIZE STREQUAL "ON")
target_compile_options(${module_name} PRIVATE $<$<CXX_COMPILER_ID:MSVC>:/MDd>)
endif()
# TODO FIXME: I'm not sure it's a good idea to propagate this option but, at this point, it was the only way that worked to silence C4477
target_compile_options(${module_name} PUBLIC $<$<CXX_COMPILER_ID:MSVC>: /wd4477>)
# Temporary workaround for MSVC sanitizer issues
target_compile_definitions(tests${module_name} PUBLIC _DISABLE_VECTOR_ANNOTATION)
else()
set(SANITIZE_FLAGS "")
set(SANITIZE_MSVC_FLAGS "")
endif()
target_link_options(tests${module_name} PUBLIC $<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:${SANITIZE_FLAGS}>)
#target_link_options(tests${module_name} PUBLIC $<$<CXX_COMPILER_ID:MSVC>:${SANITIZE_MSVC_FLAGS}>)
# Strict aliasing for better type safety
set(STRICT_ALIASING_FLAGS -fstrict-aliasing -Wstrict-aliasing=2)
endif()
# Compiler options for different toolchains
target_compile_options(tests${module_name} PUBLIC
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-fvisibility=hidden> # Hide symbols by default
) # -fvisibility=hidden required by Pybind11
target_link_libraries(tests${module_name} PRIVATE ${module_name})
# Common warning and error flags
target_compile_options(tests${module_name} PRIVATE
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-Wall -Wextra -Wold-style-cast -pedantic -Werror=narrowing -Wshadow
$<$<BOOL:${WERROR}>:-Werror>
${SANITIZE_FLAGS}
>
)
# Strict aliasing for GCC
target_compile_options(tests${module_name} PRIVATE
$<$<CXX_COMPILER_ID:GNU>:${STRICT_ALIASING_FLAGS}>
)
target_link_libraries(tests${module_name} PRIVATE Catch2::Catch2WithMain)
# MSVC-specific settings
target_compile_options(tests${module_name} PRIVATE
$<$<CXX_COMPILER_ID:MSVC>:
/W4 /DWIN32 /D_WINDOWS /GR /EHsc /MP
/Zc:__cplusplus /Zc:preprocessor /permissive-
${SANITIZE_MSVC_FLAGS}
>
)
# Workaround for C4477 warning in MSVC
if(DOSANITIZE STREQUAL "ON")
target_compile_options(tests${module_name} PRIVATE $<$<CXX_COMPILER_ID:MSVC>:/MDd>)
endif()
# TODO: Fix this once proper configuration is available
# only way to silence C4477
target_compile_options(tests${module_name} PUBLIC $<$<CXX_COMPILER_ID:MSVC>: /wd4477>)
endif()
# Link libraries
target_link_libraries(tests${module_name} PRIVATE
${module_name}
Catch2::Catch2WithMain
)
# Setup testing
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
include(CTest)
include(Catch)
# Discover and add tests
catch_discover_tests(tests${module_name})
# Set test configuration for CTest
set(CTEST_CONFIGURATION_TYPE ${CMAKE_BUILD_TYPE})
\ No newline at end of file
0.4.0
0.5.0