diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index d71d095e7f2c0c9bda4781f3efda3fb7954a2ed6..9b16f76d52e1a8d19a225d5ead2d1d47e465fd30 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -213,10 +213,7 @@ public: * @param inID * @return std::pair<std::shared_ptr<Node>, IOIndex_t> */ - inline std::pair<NodePtr, IOIndex_t> input(const IOIndex_t inID) const { - AIDGE_ASSERT((inID != gk_IODefaultIndex) && (inID < nbInputs()), "Input index out of bound."); - return std::pair<NodePtr, IOIndex_t>(mParents[inID], mIdOutParents[inID]); - } + std::pair<std::shared_ptr<Node>, IOIndex_t> input(const IOIndex_t inID) const; /** @@ -328,7 +325,7 @@ public: * Default to 0. * @param otherInId ID of the other Node input to connect to the current Node. * Default to the first available data input. - * + * * @note otherNode shared_ptr is passed by refenrece in order to be able to detect * possible dangling connection situations in debug using ref counting. */ @@ -509,7 +506,7 @@ private: * @param otherNode * @param outId * @param otherInId - * + * * @note otherNode shared_ptr is passed by refenrece in order to be able to detect * possible dangling connection situations in debug using ref counting. */ diff --git a/include/aidge/utils/Directories.hpp b/include/aidge/utils/Directories.hpp index 783783946ff5bdae5214cc41f6a1f029fae6e09c..c42280a6d67cfc86c64013b236690bf84f985f66 100644 --- a/include/aidge/utils/Directories.hpp +++ b/include/aidge/utils/Directories.hpp @@ -11,11 +11,10 @@ #ifndef AIDGE_DIRECTORIES_H_ #define AIDGE_DIRECTORIES_H_ -#include <algorithm> +#include <algorithm> // std::replace_if #include <errno.h> #include <string> // #include <string_view> available in c++-17 -#include <vector> #include <fmt/core.h> #include <fmt/format.h> diff --git a/include/aidge/utils/FileManagement.hpp b/include/aidge/utils/FileManagement.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8158fbf19f9e6c62ee5cc967c4a2bb03ba09d0a2 --- /dev/null +++ b/include/aidge/utils/FileManagement.hpp @@ -0,0 +1,28 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cstdio> // std::fclose, std::fopen +#include <memory> +#include <string> + +namespace Aidge { +struct FileDeleter { + void operator()(FILE* fp) const { + if (fp) { + std::fclose(fp); + } + } +}; + +inline std::unique_ptr<FILE, FileDeleter> createFile(const std::string& fileName, const char* accessibility = "w") { + return std::unique_ptr<FILE, FileDeleter>(std::fopen(fileName.c_str(), accessibility)); +} +} diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 315844858103cbce91049ec2195ff0a3bd7a9d81..d0b539182fc308c87f0fac11ab8fcdf15793c1f2 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -33,6 +33,7 @@ #include "aidge/operator/Producer.hpp" #include "aidge/operator/Memorize.hpp" #include "aidge/utils/Directories.hpp" +#include "aidge/utils/FileManagement.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" @@ -85,7 +86,7 @@ bool Aidge::GraphView::inView(const std::string& nodeName) const { } void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProducers) const { - auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((path + ".mmd").c_str(), "w"), &std::fclose); + auto fp = createFile(path + ".mmd", "w"); if (!fp) { AIDGE_THROW_OR_ABORT(std::runtime_error, @@ -261,7 +262,7 @@ void Aidge::GraphView::logOutputs(const std::string& dirName) const { for (IOIndex_t outIdx = 0; outIdx < nodePtr->nbOutputs(); ++outIdx) { const std::string& inputPath = nodePath +"output_" + std::to_string(outIdx) + ".log"; - auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen(inputPath.c_str(), "w"), &std::fclose); + auto fp = createFile(inputPath, "w"); if (!fp) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Could not create graph view log file: {}", inputPath); @@ -1122,7 +1123,7 @@ void Aidge::GraphView::insertParent(NodePtr childNode, * | >1 node, 1 input | trivial | trivial | broadcast | broadcast | * | 1 node, >1 inputs | (take first) | (take first) | same order | X | * | >1 node, >1 inputs | X | X | X | X | - * + * * Outputs conditions: * | old \ new | 1 node, 1 output | >1 node, 1 output | 1 node, >1 outputs | >1 node, >1 outputs | * | ------------------- | ---------------- | ----------------- | ------------------ | ------------------- | @@ -1130,7 +1131,7 @@ void Aidge::GraphView::insertParent(NodePtr childNode, * | >1 node, 1 output | trivial | trivial | take first | X | * | 1 node, >1 outputs | (take first) | (take first) | same order | X | * | >1 node, >1 outputs | X | X | X | X | - * + * * Only the X cases cannot possibly be resolved deterministically with sets of node. * These cases are therefore forbidden for the set-based `replace()` interface. * The remaining cases are handled by the GraphView-based `replace()` interface. diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 1c8585d1d1f26341724486a16d0678d92f759146..0dec30c2f2f2ffcb0f83740c863d46d7169d2f06 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -173,6 +173,12 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No return res; } +std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t> Aidge::Node::input(const Aidge::IOIndex_t inID) const { + // nbInputs already < gk_IODefaultIndex + AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound."); + return std::pair<NodePtr, IOIndex_t>(mParents[inID], mIdOutParents[inID]); +} + // void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> // tensor) { // assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); diff --git a/src/scheduler/MemoryManager.cpp b/src/scheduler/MemoryManager.cpp index 05f461b82f16b6af4ed412b7336aa2328bcafbe1..8e35913f4832f0e54e26f9be286943eb25f498ba 100644 --- a/src/scheduler/MemoryManager.cpp +++ b/src/scheduler/MemoryManager.cpp @@ -572,7 +572,7 @@ Aidge::MemoryManager::getPlanes(const std::shared_ptr<Node>& node) const const std::map<std::shared_ptr<Node>, std::vector<MemoryPlane> > ::const_iterator it = mMemPlanes.find(node); - if (it == mMemPlanes.end()) { + if (it == mMemPlanes.cend()) { AIDGE_THROW_OR_ABORT(std::runtime_error, "getSize(): no memory allocated for node name {}", node->name()); } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index fabdc7ad2a897708297f6fac25f036b45bd3b2b2..fa57f76db63fb345259d8bf585fef708b1d44a31 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -34,6 +34,7 @@ #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/Concat.hpp" +#include "aidge/utils/FileManagement.hpp" #include "aidge/utils/Log.hpp" #include "aidge/utils/Types.h" @@ -590,7 +591,8 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr node->name(), node->type()); const bool isWrappable = (requiredProtected.data < requiredData.data); - const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second]; + const auto& memPlanes = memManager.getPlanes(parent.first); + const MemoryManager::MemoryPlane& memPlane = memPlanes.at(memPlanes.size() - parent.first->nbOutputs() + parent.second); // use at() to avoid dangling reference pointer if (isWrappable || !memManager.isWrapAround( memPlane.memSpace, @@ -672,7 +674,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer std::shared_ptr<Node> concat = nullptr; // If the only child is a concatenation, check if we can allocate - // the concatenation result directly and skip allocation for this + // the concatenation result directly and skip allocation for this // node output: if (childs.size() == 1 && (*childs.begin())->type() == Concat_Op::Type) { concat = *childs.begin(); @@ -758,10 +760,11 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer node->name(), node->type()); const bool isWrappable = (requiredProtected.data < requiredData.data); + const auto& memPlanes = memManager.getPlanes(parent.first); const MemoryManager::MemoryPlane& memPlane = (concat && itConcat != concatMemPlane.end()) ? itConcat->second - : memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second]; + : memPlanes.at(memPlanes.size()-parent.first->nbOutputs()+parent.second); // use at() to avoid dangling reference pointer if (isWrappable || !memManager.isWrapAround( memPlane.memSpace, @@ -901,7 +904,7 @@ void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Te } void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const { - auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); + auto fp = createFile(fileName + ".mmd", "w"); if (!fp) { AIDGE_THROW_OR_ABORT(std::runtime_error, @@ -931,7 +934,7 @@ void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const } void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) const { - auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); + auto fp = createFile(fileName + ".mmd", "w"); if (!fp) { AIDGE_THROW_OR_ABORT(std::runtime_error, @@ -966,7 +969,7 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) } void Aidge::Scheduler::saveFactorizedStaticSchedulingDiagram(const std::string& fileName, size_t minRepeat) const { - auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); + auto fp = createFile(fileName + ".mmd", "w"); if (!fp) { AIDGE_THROW_OR_ABORT(std::runtime_error,