Skip to content
Snippets Groups Projects
Commit 5805bb1c authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Use fmt library instead of custom functions, added GraphView::getRankedNodes()...

Use fmt library instead of custom functions, added GraphView::getRankedNodes() and GraphView::getRankedNodesName() methods
parent ec2c2e43
No related branches found
No related tags found
No related merge requests found
Showing with 158 additions and 174 deletions
......@@ -30,8 +30,15 @@ endif()
##############################################
# Find system dependencies
Include(FetchContent)
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1 # or a later release
)
FetchContent_MakeAvailable(fmt)
##############################################
# Create target and set properties
......@@ -64,6 +71,7 @@ if (PYBIND)
)
endif()
target_link_libraries(${module_name} PUBLIC fmt::fmt)
target_compile_features(${module_name} PRIVATE cxx_std_14)
if (DOSANITIZE STREQUAL "ON")
......
@PACKAGE_INIT@
Include(FetchContent)
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1 # or a later release
)
FetchContent_MakeAvailable(fmt)
include(${CMAKE_CURRENT_LIST_DIR}/aidge_core-config-version.cmake)
......
......@@ -141,7 +141,7 @@ public:
*/
virtual void setRawPtr(void* /*ptr*/, NbElts_t /*length*/)
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend {}", mBackend);
};
/**
......
......@@ -262,6 +262,25 @@ public:
*/
NodePtr getNode(const std::string& nodeName) const;
/**
* Get the ranked list of nodes in the GraphView.
* If the ranking cannot be garanteed to be unique, the second item indicates
* the rank from which unicity cannot be garanteed.
* @return std::pair<std::vector<NodePtr>, size_t> Pair with the list of ranked
* nodes and the size of the ranked sub-list where unicity is garanteed.
*/
std::pair<std::vector<NodePtr>, size_t> getRankedNodes() const;
/**
* Get the nodes name according to the GraphView nodes ranking.
* @param format The formatting string to be used with fmt::format().
* The usable positional arguments are the following:
* {0} node name, {1} node type, {2} rank, {3} type rank
* @param markNonUnicity If true, non unique ranking is prefixed with "?"
* @return std::map<NodePtr, std::string> A map with the corresponding names
*/
std::map<NodePtr, std::string> getRankedNodesName(const std::string& format, bool markNonUnicity = true) const;
/**
* @brief Remove a Node from the current GraphView scope without affecting its connections.
* @param nodePtr Node to remove
......
......@@ -111,7 +111,7 @@ public:
for (DimIdx_t i = 0; i < (DIM+2); ++i) {
if (((outputDims[i] + firstEltDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension {} ({} + {})", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
}
}
......
......@@ -133,7 +133,7 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> co
for (DimIdx_t i = 0; i < (DIM+2); ++i) {
if (((outputDims[i] + firstEltDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension {} ({} + {})", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
}
}
......
......@@ -128,7 +128,7 @@ public:
for (DimIdx_t i = 0; i < (DIM+2); ++i) {
if (((outputDims[i] + firstEltDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension {} ({} + {})", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
}
}
......
......@@ -79,27 +79,27 @@ public:
void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as outputs", type().c_str());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator only accepts Tensors as outputs", type().c_str());
}
if (outputIdx >= nbInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator has {} outputs", type().c_str(), nbInputs());
}
*mInputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
}
void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator only accepts Tensors as inputs", type().c_str());
}
if (outputIdx >= nbInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator has {} outputs", type().c_str(), nbInputs());
}
*mInputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
}
const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const override final {
if (outputIdx >= nbInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator has {} outputs", type().c_str(), nbInputs());
}
if (mInputs[outputIdx] == nullptr){
return mOutputs[outputIdx]; // Input is not initialized with empty tensor
......
......@@ -89,12 +89,6 @@ private:
std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;
/**
* Return a std::map with corresponding node's name.
* TODO: Mutualise with similar code in GraphView::save()?
*/
std::map<std::shared_ptr<Node>, std::string> getNodesName(bool verbose) const;
/** @brief Shared ptr to the scheduled graph view */
std::shared_ptr<GraphView> mGraphView;
/** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
......
......@@ -13,20 +13,17 @@
#ifndef AIDGE_ERRORHANDLING_H_
#define AIDGE_ERRORHANDLING_H_
#include <cstdio>
#include <memory>
#define AIDGE_STRINGIZE_DETAIL(x) #x
#define AIDGE_STRINGIZE(x) AIDGE_STRINGIZE_DETAIL(x)
#include <fmt/format.h>
#ifdef NO_EXCEPTION
#define AIDGE_THROW_OR_ABORT(ex, ...) \
do { std::printf(__VA_ARGS__); std::abort(); } while (false)
do { fmt::print(__VA_ARGS__); std::abort(); } while (false)
#else
#include <stdexcept>
#include "aidge/utils/Formatting.hpp"
#define AIDGE_THROW_OR_ABORT(ex, ...) \
throw ex(stringFormat(__VA_ARGS__))
throw ex(fmt::format(__VA_ARGS__))
#endif
/**
......@@ -35,7 +32,7 @@ throw ex(stringFormat(__VA_ARGS__))
* If it asserts, it means an user error.
*/
#define AIDGE_ASSERT(stm, ...) \
if (!(stm)) { printf("Assertion failed: " AIDGE_STRINGIZE(stm) " in " __FILE__ ":%d", __LINE__); \
if (!(stm)) { fmt::print("Assertion failed: " #stm " in {}:{}", __FILE__, __LINE__); \
AIDGE_THROW_OR_ABORT(std::runtime_error, __VA_ARGS__); }
/**
......@@ -44,6 +41,6 @@ if (!(stm)) { printf("Assertion failed: " AIDGE_STRINGIZE(stm) " in " __FILE__ "
* If it asserts, it means a bug.
*/
#define AIDGE_INTERNAL_ASSERT(stm) \
assert((stm) && "Internal assertion failed: " #stm " in " __FILE__ ":" AIDGE_STRINGIZE(__LINE__))
assert((stm) && "Internal assertion failed")
#endif //AIDGE_ERRORHANDLING_H_
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_FORMATTING_H_
#define AIDGE_FORMATTING_H_
#include <memory>
#include <string>
#include <vector>
namespace Aidge {
// The code snippet below is licensed under CC0 1.0.
template<typename ... Args>
std::string stringFormat(const std::string& format, Args... args) {
#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
// Disable security warning on GCC
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wformat-security"
#elif defined(_MSC_VER)
// Disable security warning on MSVC
#pragma warning(push)
#pragma warning(disable : 4774)
#endif
int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1; // Extra space for '\0'
if (size_s <= 0) {
std::printf("Error during formatting.");
std::abort();
}
auto size = static_cast<size_t>(size_s);
std::unique_ptr<char[]> buf(new char[size]);
std::snprintf(buf.get(), size, format.c_str(), args...);
return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside
#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
#pragma GCC diagnostic pop
#elif defined(_MSC_VER)
#pragma warning(pop)
#endif
}
/**
* Print any iterable object in a std::string.
*/
template <class T, typename F>
std::string print(const T& vec, const std::string& format, const F& func) {
std::string str = "{";
bool first = true;
for (const auto& val : vec) {
if (!first) {
str += ", ";
}
else {
first = false;
}
str += stringFormat(format, func(val));
}
str += "}";
return str;
}
template <class T>
std::string print(const T& vec, const std::string& format) {
return print(vec, format, [](auto val){ return val; });
}
}
#endif //AIDGE_FORMATTING_H_
......@@ -95,7 +95,7 @@ public:
}
}
AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"%s\" not found", name);
AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"{}\" not found", name);
}
template <typename R>
......@@ -106,7 +106,7 @@ public:
}
}
AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"%s\" not found", name);
AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"{}\" not found", name);
}
template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value>
......@@ -116,7 +116,7 @@ public:
return reinterpret_cast<R&>(std::get<SIZE-1>(mAttrs));
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "wrong type for attribute with index %lu", i);
AIDGE_THROW_OR_ABORT(std::runtime_error, "wrong type for attribute with index {}", i);
}
}
else {
......@@ -136,7 +136,7 @@ public:
return reinterpret_cast<const R&>(std::get<SIZE-1>(mAttrs));
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "wrong type for attribute with index %lu", i);
AIDGE_THROW_OR_ABORT(std::runtime_error, "wrong type for attribute with index {}", i);
}
}
else {
......@@ -190,7 +190,7 @@ public:
}
}
AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"%s\" not found", name.c_str());
AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"{}\" not found", name.c_str());
}
std::set<std::string> getAttrsName() const override final {
......@@ -211,7 +211,7 @@ public:
}
}
AIDGE_THROW_OR_ABORT(py::value_error, "attribute \"%s\" not found", name.c_str());
AIDGE_THROW_OR_ABORT(py::value_error, "attribute \"{}\" not found", name.c_str());
};
#endif
......
......@@ -15,8 +15,10 @@
#include <utility>
#include <numeric>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include "aidge/utils/Types.h"
#include "aidge/utils/Formatting.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
......@@ -56,37 +58,28 @@ std::string Aidge::GraphView::name() const { return mName; }
void Aidge::GraphView::setName(const std::string &name) { mName = name; }
void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) const {
FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
std::fprintf(fp,
"%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
"'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n");
std::map<const std::string, std::size_t> typeCounter;
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
// Start by creating every node
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
const std::string currentType = node_ptr->type();
if (typeCounter.find(currentType) == typeCounter.end())
typeCounter[currentType] = 0;
++typeCounter[currentType];
auto namePtrTable = getRankedNodesName("{3}");
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
std::string givenName =
(node_ptr->name().empty())
? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>"
: "\"" + node_ptr->name() + "\\n<sub><em>( " + currentType + "#" + std::to_string(typeCounter[currentType]) + " )</em></sub>\"";
namePtrTable[node_ptr] =
(currentType + "_" + std::to_string(typeCounter[currentType]));
? "<em>" + node_ptr->type() + "#" + namePtrTable[node_ptr] + "</em>"
: "\"" + node_ptr->name() + "\\n<sub><em>(" + node_ptr->type() + "#" + namePtrTable[node_ptr] + ")</em></sub>\"";
if (node_ptr == mRootNode) {
std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(),
std::fprintf(fp, "%s_%s(%s):::rootCls\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
else {
if ((currentType != "Producer") || showProducers) {
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
if ((node_ptr->type() != "Producer") || showProducers) {
std::fprintf(fp, "%s_%s(%s)\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
}
......@@ -108,15 +101,15 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
std::string dims = "";
const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator());
if (op && !op->getOutput(outputIdx)->dims().empty()) {
dims += " " + print(op->getOutput(outputIdx)->dims(), "%u");
dims += " " + fmt::format("{}", op->getOutput(outputIdx)->dims());
}
if (mNodes.find(child) != mNodes.end()) {
std::fprintf(fp, "%s-->|\"%u%s&rarr;%u\"|%s\n", namePtrTable[node_ptr].c_str(),
outputIdx, dims.c_str(), inputIdx, namePtrTable[child].c_str());
std::fprintf(fp, "%s_%s-->|\"%u%s&rarr;%u\"|%s_%s\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(),
outputIdx, dims.c_str(), inputIdx, child->type().c_str(), namePtrTable[child].c_str());
}
else if (verbose) {
std::fprintf(fp, "%s-->|\"%u%s&rarr;%u\"|%p:::externalCls\n", namePtrTable[node_ptr].c_str(),
std::fprintf(fp, "%s_%s-->|\"%u%s&rarr;%u\"|%p:::externalCls\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(),
outputIdx, dims.c_str(), inputIdx, static_cast<void*>(child.get()));
}
break;
......@@ -131,8 +124,8 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
size_t inputIdx = 0;
for (auto input : mInputNodes) {
std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|&rarr;%u|%s\n", inputIdx, inputIdx,
input.second, namePtrTable[input.first].c_str());
std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|&rarr;%u|%s_%s\n", inputIdx, inputIdx,
input.second, input.first->type().c_str(), namePtrTable[input.first].c_str());
++inputIdx;
}
......@@ -142,11 +135,11 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
std::string dims = "";
const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator());
if (op && !op->getOutput(output.second)->dims().empty()) {
dims += " " + print(op->getOutput(output.second)->dims(), "%u");
dims += " " + fmt::format("{}", op->getOutput(output.second)->dims());
}
std::fprintf(fp, "%s--->|\"%u%s&rarr;\"|output%lu((out#%lu)):::outputCls\n",
namePtrTable[output.first].c_str(), output.second,
std::fprintf(fp, "%s_%s--->|\"%u%s&rarr;\"|output%lu((out#%lu)):::outputCls\n",
output.first->type().c_str(), namePtrTable[output.first].c_str(), output.second,
dims.c_str(), outputIdx, outputIdx);
++outputIdx;
}
......@@ -155,13 +148,6 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
std::fprintf(fp, "classDef outputCls fill:#ffa\n");
std::fprintf(fp, "classDef externalCls fill:#ccc\n");
std::fprintf(fp, "classDef rootCls stroke:#f00\n");
if (verbose) {
for (const auto &c : typeCounter) {
std::printf("%s - %zu\n", c.first.c_str(), c.second);
}
}
std::fprintf(fp, "\n");
std::fclose(fp);
}
......@@ -436,6 +422,59 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara
}
}
std::pair<std::vector<Aidge::NodePtr>, size_t> Aidge::GraphView::getRankedNodes() const {
std::set<NodePtr> nodesToRank(mNodes);
nodesToRank.erase(mRootNode);
std::vector<NodePtr> rankedNodes;
rankedNodes.push_back(mRootNode);
for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) {
NodePtr curNode = rankedNodes[curNodeIdx];
for (auto childs : curNode->getOrderedChildren()) {
for (auto child : childs) {
if (nodesToRank.find(child) != nodesToRank.end()) {
rankedNodes.push_back(child);
nodesToRank.erase(child);
}
}
}
for (auto parent : curNode->getParents()) {
if (nodesToRank.find(parent) != nodesToRank.end()) {
rankedNodes.push_back(parent);
nodesToRank.erase(parent);
}
}
}
const size_t orderUnicityLimit = rankedNodes.size();
if (!nodesToRank.empty()) {
rankedNodes.insert(rankedNodes.end(), nodesToRank.begin(), nodesToRank.end());
}
return std::make_pair(rankedNodes, orderUnicityLimit);
}
std::map<Aidge::NodePtr, std::string> Aidge::GraphView::getRankedNodesName(const std::string& format, bool markNonUnicity) const {
const auto rankedNodes = getRankedNodes();
std::map<NodePtr, std::string> rankedNodesName;
size_t rank = 0;
std::map<std::string, size_t> typeRank;
for (const auto& rankedNode : rankedNodes.first) {
std::map<std::string, size_t>::iterator it;
std::tie(it, std::ignore) = typeRank.insert(std::make_pair(rankedNode->type(), 0));
const auto name = (markNonUnicity && rank < rankedNodes.second)
? fmt::format(format, rankedNode->name(), rankedNode->type(), rank, it->second)
: fmt::format(format, rankedNode->name(), rankedNode->type(), fmt::format("?{}", rank), fmt::format("?{}", it->second));
rankedNodesName.insert(std::make_pair(rankedNode, name));
++it->second;
++rank;
}
return rankedNodesName;
}
bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
if (otherNodes.empty()) {
return true;
......
......@@ -21,7 +21,7 @@
void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
if (inputIdx >= nbInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator has {} inputs", type().c_str(), nbInputs());
}
if (strcmp((data)->type(), Tensor::Type) != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input data must be of Tensor type");
......@@ -31,7 +31,7 @@ void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, cons
void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator only accepts Tensors as inputs", type().c_str());
}
if (getInput(inputIdx)) {
*mInputs[inputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
......@@ -44,7 +44,7 @@ Aidge::OperatorTensor::~OperatorTensor() = default;
void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Aidge::Data>&& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator only accepts Tensors as inputs", type().c_str());
}
if (getInput(inputIdx)) {
*mInputs[inputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
......@@ -55,34 +55,34 @@ void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, std::share
const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const {
if (inputIdx >= nbInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator has {} inputs", type().c_str(), nbInputs());
}
return mInputs[inputIdx];
}
void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator only accepts Tensors as inputs", type().c_str());
}
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator has {} outputs", type().c_str(), nbOutputs());
}
*mOutputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
}
void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator only accepts Tensors as inputs", type().c_str());
}
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator has {} outputs", type().c_str(), nbOutputs());
}
*mOutputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
}
const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{} Operator has {} outputs", type().c_str(), nbOutputs());
}
return mOutputs[outputIdx];
}
......@@ -105,7 +105,7 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_
}
for (DimIdx_t i = 0; i < outputDims.size(); ++i) {
if (((outputDims[i] + firstEltDims[i]) > getOutput(0)->dims()[i]) || (outputDims[i] == 0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension {} ({} + {})", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
}
}
// return the same Tensor description as given in function parameter for each data input
......
......@@ -16,6 +16,8 @@
#include <set>
#include <string>
#include <fmt/ranges.h>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
......@@ -63,7 +65,8 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// frozen state.
std::vector<std::set<std::shared_ptr<Node>>> frozenConsumers;
std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(verbose);
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
do {
// From the current consumers list, check if any prior nodes are needed.
......@@ -92,8 +95,17 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
if (prior.isPrior) {
if (verbose) {
printf("\t\trequired producers: %s\n", print(prior.requiredProducers, "%s", [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }).c_str());
printf("\t\tprior consumers: %s\n", print(prior.priorConsumers, "%s", [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }).c_str());
std::vector<std::string> requiredProducersName;
std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(),
std::back_inserter(requiredProducersName),
[&namePtrTable](auto val){ return namePtrTable[val].c_str(); });
fmt::print("\t\trequired producers: {}\n", requiredProducersName);
std::vector<std::string> priorConsumersName;
std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(),
std::back_inserter(priorConsumersName),
[&namePtrTable](auto val){ return namePtrTable[val].c_str(); });
fmt::print("\t\tprior consumers: {}\n", priorConsumersName);
}
requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend());
......@@ -261,7 +273,8 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
// Clear previous scheduling results
mScheduling.clear();
std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(verbose);
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
int cpt = 0;
for (const auto& runnable : mStaticSchedule) {
......@@ -286,12 +299,13 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa
std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q µs\n\n");
if (!mScheduling.empty()) {
std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(true);
const std::map<std::shared_ptr<Node>, std::string> namePtrTable
= mGraphView->getRankedNodesName("{0} ({1}#{3})");
const auto globalStart = mScheduling[0].start;
for (const auto& element : mScheduling) {
std::fprintf(fp, "%s :%ld, %ld\n",
namePtrTable[element.node].c_str(),
namePtrTable.find(element.node)->second.c_str(),
std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(),
std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
}
......@@ -353,25 +367,3 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::
}
return prior;
}
std::map<std::shared_ptr<Aidge::Node>, std::string> Aidge::SequentialScheduler::getNodesName(bool verbose) const {
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
if (verbose) {
std::map<const std::string, std::size_t> typeCounter;
for (const std::shared_ptr<Node> &node_ptr : mGraphView->getNodes()) {
const std::string currentType = node_ptr->type();
if (typeCounter.find(currentType) == typeCounter.end())
typeCounter[currentType] = 0;
++typeCounter[currentType];
namePtrTable[node_ptr] =
(node_ptr->name().empty())
? currentType + "#" + std::to_string(typeCounter[currentType])
: node_ptr->name() + " (" + currentType + "#" + std::to_string(typeCounter[currentType]) + ")";
}
}
return namePtrTable;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment