diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 89d7a3a7b0c4d164473869a9d6372c3bf48cd308..c7b712be460a748df12447b15883eff58abbf690 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -312,6 +312,18 @@ class Tensor : public Data, */ Tensor sqrt() const; + /** + * @brief Element-wise abs operation for Tensor. + * @return Tensor + */ + Tensor abs() const; + + /** + * @brief Mean operation for Tensor. + * @return Tensor + */ + Tensor mean() const; + ~Tensor() noexcept; public: diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index fc8bfb3353352186b23459e1ca82505827c28345..951aa6b29d73d9055cf9f13c8ddc6313cb506879 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -43,6 +43,7 @@ public: bool singleOutput = true; IOIndex_t edgeLeftIdx = 0; IOIndex_t edgeRightIdx = 0; + NodePtr startNode; // For check & debug purpose: size_t depth = 0; @@ -134,10 +135,20 @@ public: * * @param query The query to search. * @param disjoint If true, only keep the longuest disjoint (non-overlapping) matches. - * @return Set of matches, each stored in a MatchingResult struct. + * @return std::set<MatchingResult> Set of matches, each stored in a MatchingResult struct. */ std::set<MatchingResult> match(const std::string& query, bool disjoint = false); + /** + * @brief Same as match() but with a mandatory start node. + * + * @param startNode Mandatory start node for the query. + * @param query The query to search. + * @return MatchingResult MatchingResult struct, with empty graph if query + * is not found, or the graph corresponding to the query. + */ + MatchingResult matchFrom(NodePtr startNode, const std::string& query); + /** * Filter to keep only the longuest disjoint (non-overlapping) matches. 
*/ @@ -158,7 +169,7 @@ private: bool matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches); /** - * BLOCK = '(' SEQ | PAR | BLOCK | ALT | NODE ')' + * BLOCK = '(' SEQ | PAR | ALT | BLOCK | NODE ')' */ bool matchBlock(Context& ctx, std::set<MatchingResult>& matches); @@ -190,7 +201,7 @@ private: * TYPE = [A-Za-z0-9_]+ * ANCHOR = [A-Za-z0-9_]+ * LAMBDA = [A-Za-z0-9_]+ - * NODE = (TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')? + * NODE = ((TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')?) | '$' */ bool matchNode(Context& ctx, std::set<MatchingResult>& matches); diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index f694a1234b6037a0ae75a89380af9747765e290c..3be17d6d21d18d63e75e384f2c6e037452db3a82 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -17,6 +17,7 @@ #include <set> #include <string> #include <vector> +#include <deque> #include <utility> #ifdef PYBIND @@ -63,6 +64,9 @@ private: std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */ std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. 
*/ + std::deque<std::function<bool()>> mForward; + std::deque<std::function<bool()>> mBackward; + public: Node() = delete; @@ -79,6 +83,22 @@ public: return lhs.shared_from_this() == rhs.shared_from_this(); } + void addBeforeForward(std::function<bool()> func) { + mForward.push_front(func); + } + + void addAfterForward(std::function<bool()> func) { + mForward.push_back(func); + } + + void addBeforeBackward(std::function<bool()> func) { + mBackward.push_front(func); + } + + void addAfterBackward(std::function<bool()> func) { + mBackward.push_back(func); + } + public: /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION diff --git a/include/aidge/operator/Abs.hpp b/include/aidge/operator/Abs.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3c2f1bb388cf064be379f476f1d2df4491b57637 --- /dev/null +++ b/include/aidge/operator/Abs.hpp @@ -0,0 +1,71 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_ABS_H_ +#define AIDGE_CORE_OPERATOR_ABS_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Abs_Op : public OperatorTensor, + public Registrable<Abs_Op, std::string, std::shared_ptr<OperatorImpl>(const Abs_Op&)> { +public: + static const std::string Type; + + Abs_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {} + + /** + * @brief Copy-constructor. 
Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Abs_Op(const Abs_Op& op) + : OperatorTensor(op) + { + if (op.mImpl) { + SET_IMPL_MACRO(Abs_Op, *this, op.backend()); + } else { + mImpl = nullptr; + } + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Abs_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Abs_Op>(*this); + } + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Abs(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Abs_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_ABS_H_ */ diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index f1996fbae025838e2e6f6c21c70018c7cc9746f5..31378532e28c05971e4e3eb5778d4821ce2b6fde 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -61,6 +61,13 @@ public: void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + DimSize_t inChannels() const { + if (!getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Fully Connected (FC) operator has no weight Tensor associated so no specific number of input channel imposed."); + } + return getInput(1)->template dims<2>()[1]; + } + DimSize_t outChannels() const { if (!getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Fully Connected (FC) operator has no weight Tensor associated so no specific number of output channel imposed."); diff --git a/include/aidge/scheduler/MemoryManager.hpp b/include/aidge/scheduler/MemoryManager.hpp index 94add56e8afdebb8e42f7ae49a32da2aeed9e9cb..2e397d1dbaa1cc8d8f586d15363cbd2245963152 100644 --- 
a/include/aidge/scheduler/MemoryManager.hpp +++ b/include/aidge/scheduler/MemoryManager.hpp @@ -19,6 +19,25 @@ #include "aidge/graph/Node.hpp" namespace Aidge { +/** + * @brief The MemoryManager can be used to generate an optimized static memory + * layout for a computing graph in a global memory space. + * There are some assumptions: + * - A MemoryManager represents a single global memory space, filled with + * contiguous, non-overlapping MemorySpace chunks. + * - A MemorySpace contains one or multiple MemoryPlane, each MemoryPlane + * corresponding to the allocation of a specific Tensor. When a Tensor can re- + * use the memory of the preceding one (for in-place or partially in-place + * operators), multiple overlapping MemoryPlane can be created in the same + * MemorySpace (remember, MemorySpace **cannot** be overlapping!). + * - A MemoryPlane is tailored for handling (N)HWC data with two properties: + * - Possibility of wrapping: on the H axis (each W*C block is contiguous). + * - Possibility of concatenation: on the C axis (C1+C2+...+Cn). + * - All the sizes and offsets specified in a MemoryManager are expressed in + * number of data elements, or **words**, meaning currently a uniform data + * precision is expected in a MemoryManager (for instance, if the precision is + * 16-bits, each data element will be 2 bytes, which will be the size of a word). + */ class MemoryManager { public: typedef int Clock_T; @@ -45,18 +64,45 @@ public: allocated(clock_), released(-1) {} + /// Offset of the MemorySpace in the MemoryManager global memory space (in words) unsigned int offset; + /// Size of the MemorySpace (in words) unsigned int size; std::set<std::shared_ptr<Node> > dependencies; Clock_T allocated; Clock_T released; }; - // MemoryPlane belongs to a MemorySpace. Any number of potentially - // overlapping planes can be associated to a MemorySpace. - // MemoryPlane can be non-contiguous (in case of stride, or wrapping, when - // offset + size > memSpace.size). 
- // MemoryPlane cannot be re-arranged inside a MemorySpace. + /** + * @brief MemoryPlane belongs to a MemorySpace. Any number of potentially + * overlapping planes can be associated to a MemorySpace. + * MemoryPlane can be non-contiguous (in case of stride, or wrapping, when + * offset + size > memSpace.size). + * MemoryPlane cannot be re-arranged inside a MemorySpace. + * + * A MemoryPlane is tailored for handling (N)HWC data with two properties: + * - Possibility of wrapping: on the H axis (each W*C block is contiguous). + * - Possibility of concatenation: on the C axis (C1+C2+...+Cn). + * + * Detail of (N)HWC data handling: + * - \p length is the size of a contiguous and non-breakable memory line (W in HWC); + * - \p count is the number of memory lines of size \p length constituting a memory block (H in HWC); + * - \p stride is the number of channels, or memory blocks, *in total*, + * of \p count lines of size \p length (C in NHWC); + * - \p size is the number of channels, or memory blocks, *in this MemoryPlane*, + * of \p count lines of size \p length. + * In the case of concatenation, there can be multiple overlapping MemoryPlane + * with different size, like NHWC = NHW(C1+C2): + * - MemoryPlane#1: \p size = C1 and \p stride = C=C1+C2 + * - MemoryPlane#2: \p size = C2 and \p stride = C=C1+C2 + * (with an additional relative offset of +C1) + * In this mode, wrapping can only occur on the H (\p count) axis. W*C chunks + * are guaranteed to be contiguous (\p length * \p stride). + * + * By default, \p stride = \p size, \p count = 1 and \p length = 1, meaning + * there is no NHWC layout and the MemoryPlane can be wrapped **anywhere**. + * In this case, \p size is the total size of the MemoryPlane (H*W*C, in words). + */ struct MemoryPlane { MemoryPlane(std::shared_ptr<MemorySpace> memSpace_, Clock_T clock_, @@ -92,36 +138,91 @@ public: <= memSpace->offset + memSpace->size); } + /** + * @brief Get the total size of the MemoryPlane, including the stride. 
+ * + * @return unsigned int Total size in words + */ inline unsigned int getSize() const { return stride * length * count; } + /** + * @brief Get the useful size of the MemoryPlane, as if its memory blocks + * were contiguous, without stride. + * + * @return unsigned int Useful size in words + */ inline unsigned int getUsefulSize() const { return size * length * count; } + /** + * @brief Get the absolute offset of the beginning of the memory plane. + * + * @return unsigned int Contiguous offset in words + */ inline unsigned int getContiguousOffset() const { return memSpace->offset + offset; } + /** + * @brief Get the size of the contiguous part of the memory plane, from + * its beginning to the limit of the MemorySpace size. + * If the MemoryPlane fills the MemorySpace without wrapping, the contiguous + * size will be the same as the total size of the MemoryPlane. + * + * @return unsigned int Contiguous size in words + */ inline unsigned int getContiguousSize() const { return std::min(getSize(), getLimit()); } + /** + * @brief Get the absolute offset of the wrapped part of the memory plane. + * Since the wrapped part of the memory plane begins at the beginning of + * the MemorySpace, the returned offset is always the same as the MemorySpace + * offset. + * + * @return unsigned int Wrapped offset in words + */ inline unsigned int getWrappedOffset() const { return memSpace->offset; } + /** + * @brief Get the size of the wrapped part of the memory plane, from + * the beginning of the MemorySpace to the total size of the MemoryPlane, + * including the stride. + * If the MemoryPlane fills the MemorySpace without wrapping, the wrapped + * size will be 0. + * + * @return unsigned int Wrapped size in words + */ inline unsigned int getWrappedSize() const { return getSize() - getContiguousSize(); } + /** + * @brief Get the absolute offset after the end of the memory plane (if it + * is wrapped, the offset will correspond to the end of the wrapped part). 
+ * The word at the final offset is not included in the MemoryPlane. + * + * @return unsigned int Final offset in words + */ inline unsigned int getFinalOffset() const { return (getWrappedSize() > 0) ? getWrappedOffset() + getWrappedSize() : getContiguousOffset() + getContiguousSize(); } + /** + * @brief Get the absolute offset after the end of the contiguous part + * of the memory plane. + * The word at the upper offset is not included in the MemoryPlane. + * + * @return unsigned int Upper offset in words + */ inline unsigned int getUpperOffset() const { return (getContiguousOffset() + getContiguousSize()); } @@ -146,10 +247,29 @@ public: std::shared_ptr<MemorySpace> memSpace; Clock_T allocated; + /// Relative offset of the MemoryPlane in the MemorySpace (in words) unsigned int offset; + /// Number of channels, or memory blocks, *in this MemoryPlane*, + /// of \p count lines of size \p length. + /// In the case of concatenation, there can be multiple overlapping MemoryPlane + /// with different size, like NHWC = NHW(C1+C2): + /// - MemoryPlane#1: \p size = C1 and \p stride = C=C1+C2 + /// - MemoryPlane#2: \p size = C2 and \p stride = C=C1+C2 + /// (with an additional relative offset of +C1) + /// By default, \p stride = \p size, \p count = 1 and \p length = 1, meaning + /// there is no NHWC layout and the MemoryPlane can be wrapped **anywhere**. + /// In this case, \p size is the total size of the MemoryPlane (H*W*C, in words). unsigned int size; + /// Number of channels, or memory blocks *in total*, + /// of \p count lines of size \p length (the C in NHWC). + /// There should be C blocks of H*W size. unsigned int stride; + /// Size of an elementary, contiguous and non-breakable, memory line + /// (the W in NHWC), in words. A MemoryPlane wrapping cannot occur in + /// the middle of a memory line. unsigned int length; + /// Number of memory lines of size \p length constituting a memory block + /// (the H in NHWC). The size of a memory block is H*W. 
unsigned int count; }; diff --git a/include/aidge/utils/StaticAttributes.hpp b/include/aidge/utils/StaticAttributes.hpp index 3bb41b5bb0d9c2727d95a2656a1a2d5b96ff950b..18e75b7cef5a2e9e9568a900f826a31c87012318 100644 --- a/include/aidge/utils/StaticAttributes.hpp +++ b/include/aidge/utils/StaticAttributes.hpp @@ -158,7 +158,11 @@ public: std::enable_if_t<(SIZE > 0), bool> = true> constexpr const std::type_info& getAttrType(std::size_t i) const { if (i == SIZE-1) { - return typeid(typename std::tuple_element<SIZE-1,std::tuple<T...>>::type); + // Workaround for NVCC from 12.2.1 to 12.4.1 + // error: no suitable constructor exists to convert from "const char *" to "std::type_info" + typename std::tuple_element<SIZE-1,std::tuple<T...>>::type dummy{}; + return typeid(dummy); + //return typeid(typename std::tuple_element<SIZE-1,std::tuple<T...>>::type); } else { return getAttrType<SIZE-1>(i); diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 20bf3fb78d0f14ca1496ef92425ec4cd155f86d5..e382fe2aca4d6e27a00e4e96233e08b50a92418d 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -16,9 +16,11 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Abs.hpp" #include "aidge/operator/Add.hpp" #include "aidge/operator/Div.hpp" #include "aidge/operator/Mul.hpp" +#include "aidge/operator/ReduceMean.hpp" #include "aidge/operator/Sub.hpp" #include "aidge/operator/Sqrt.hpp" #include "aidge/operator/Transpose.hpp" @@ -106,6 +108,32 @@ Aidge::Tensor Aidge::Tensor::sqrt() const { return sqrt_.getOutput(0)->clone(); } +Aidge::Tensor Aidge::Tensor::abs() const { + AIDGE_ASSERT(hasImpl(), "Tensor has no implementation."); + auto abs_ = Abs_Op(); + abs_.associateInput(0, std::make_shared<Tensor>(*this)); + abs_.setDataType(dataType()); + abs_.setDataFormat(dataFormat()); + abs_.setBackend(mImpl->backend()); + abs_.forward(); + return abs_.getOutput(0)->clone(); +} + +Aidge::Tensor Aidge::Tensor::mean() const { + 
AIDGE_ASSERT(hasImpl(), "Tensor has no implementation."); + // TODO: should be the default behavior of ReduceMean_Op + // No need to specify the list of all axes! + std::vector<std::int32_t> axes(nbDims()); + std::iota(std::begin(axes), std::end(axes), 0); + auto mean_ = ReduceMean_Op(axes, 0); + mean_.associateInput(0, std::make_shared<Tensor>(*this)); + mean_.setDataType(dataType()); + mean_.setDataFormat(dataFormat()); + mean_.setBackend(mImpl->backend()); + mean_.forward(); + return mean_.getOutput(0)->clone(); +} + Aidge::Tensor& Aidge::Tensor::operator=(const Aidge::Tensor& other) { if (this == &other) { return *this; diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp index b93ac16a9384d9b6ec8b62124136cb5085268d58..22be1347aa7ef108f593d3aabe3ff6d75c9312b1 100644 --- a/src/graph/Matching.cpp +++ b/src/graph/Matching.cpp @@ -56,6 +56,31 @@ std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphM return matches; } +Aidge::SinglePassGraphMatching::MatchingResult Aidge::SinglePassGraphMatching::matchFrom(NodePtr startNode, const std::string& query) { + Context ctx; + ctx.query = query; + ctx.startNode = startNode; + std::set<MatchingResult> matches; + + while (matchSequence(ctx, matches) || matchNodeOrBlock(ctx, matches)) { + removeWhiteSpace(ctx.query); + if (!ctx.query.empty() && ctx.query[0] == ';') { + ctx.query.erase(0, 1); + } + else { + break; + } + } + + removeWhiteSpace(ctx.query); + if (!ctx.query.empty()) { + Log::warn("Syntax error, unable to parse remaining query: {}", ctx.query); + } + + AIDGE_INTERNAL_ASSERT(matches.size() <= 1); + return (!matches.empty()) ? 
*matches.begin() : MatchingResult(); +} + std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::filterLonguestDisjoint(const std::set<MatchingResult>& matches) { // Sort matches by highest number of nodes first, thanks to the CompareMatchingResultSize function std::set<MatchingResult, CompareMatchingResultSize> sortedMatches(matches.begin(), matches.end()); @@ -218,8 +243,8 @@ bool Aidge::SinglePassGraphMatching::matchBlock(Context& ctx, std::set<MatchingR // SEQ | PAR | BLOCK | ALT | NODE if (!matchSequence(newCtx, newMatches) && !matchParallel(newCtx, newMatches) - && !matchBlock(newCtx, newMatches) && !matchAlternative(newCtx, newMatches) + && !matchBlock(newCtx, newMatches) && !matchNode(newCtx, newMatches)) { Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); @@ -368,6 +393,9 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::set<Mat return false; } newCtx.query = altCtx.query; + newCtx.anchors.insert(altCtx.anchors.begin(), altCtx.anchors.end()); + bool firstSequence = altCtx.firstSequence; + bool firstNode = altCtx.firstNode; newMatches.insert(altMatches.begin(), altMatches.end()); bool found = false; @@ -391,6 +419,11 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::set<Mat return false; } newCtx.query = altCtx.query; + newCtx.anchors.insert(altCtx.anchors.begin(), altCtx.anchors.end()); + AIDGE_ASSERT(firstSequence == altCtx.firstSequence, + "Ill-formed query; inconsistency between alternatives regarding first sequence in query at: {}", ctx.query); + AIDGE_ASSERT(firstNode == altCtx.firstNode, + "Ill-formed query; inconsistency between alternatives regarding first node in query at: {}", ctx.query); newMatches.insert(altMatches.begin(), altMatches.end()); } @@ -399,6 +432,9 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::set<Mat return false; } + newCtx.firstSequence = firstSequence; + 
newCtx.firstNode = firstNode; + --newCtx.depth; ctx = newCtx; matches = newMatches; @@ -513,7 +549,7 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe Log::debug("{}node", std::string(2*newCtx.depth, ' ')); auto newMatches = matches; - // (TYPE | '.') + // (TYPE | '.' | '$') removeWhiteSpace(newCtx.query); if (newCtx.query.empty()) { Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); @@ -521,10 +557,16 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe } std::string type; + bool unconnected = false; if (newCtx.query[0] == '.') { // '.' newCtx.query.erase(0, 1); // drop '.' } + else if (newCtx.query[0] == '$') { + // '$' + newCtx.query.erase(0, 1); // drop '$' + unconnected = true; + } else { // TYPE const auto endIdentifier = std::find_if(newCtx.query.begin(), newCtx.query.end(), @@ -542,6 +584,9 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe // ('#' ANCHOR)? std::string anchor = ""; if (!newCtx.query.empty() && newCtx.query[0] == '#') { + AIDGE_ASSERT(!unconnected, + "Ill-formed query; an anchor cannot be specified for end of graph ($) in query at: {}", ctx.query); + // '#' newCtx.query.erase(0, 1); // drop '#' @@ -555,6 +600,9 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe // ('[' LAMBDA ']')? 
std::string lambda = ""; if (!newCtx.query.empty() && newCtx.query[0] == '[') { + AIDGE_ASSERT(!unconnected, + "Ill-formed query; a lambda cannot be specified for end of graph ($) in query at: {}", ctx.query); + // '[' newCtx.query.erase(0, 1); @@ -581,9 +629,64 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe } // Parsing is done, try to match the node - if (newCtx.firstSequence && newCtx.firstNode) { + if (unconnected) { + for (auto it = newMatches.begin(); it != newMatches.end(); ) { + bool found = false; + + if (newCtx.lookForChild) { + const auto outputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) + ? ((newCtx.edgeLeftIdx < it->startNode->nbOutputs()) + ? std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>(1, std::vector<std::pair<NodePtr, IOIndex_t>>(it->startNode->output(newCtx.edgeLeftIdx))) + : std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>()) + : it->startNode->outputs(); + + for (const auto& output : outputs) { + for (const auto& node : output) { + if (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx) { + if (mGraph->inView(node.first) && !it->graph->inView(node.first)) { + found = true; + break; + } + } + } + + if (found) { + break; + } + } + } + else { + const auto inputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) + ? ((newCtx.edgeLeftIdx < it->startNode->nbInputs()) + ? 
std::vector<std::pair<NodePtr, IOIndex_t>>(1, it->startNode->input(newCtx.edgeLeftIdx)) + : std::vector<std::pair<NodePtr, IOIndex_t>>()) + : it->startNode->inputs(); + + for (const auto& input : inputs) { + if (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx) { + if (mGraph->inView(input.first) && !it->graph->inView(input.first)) { + found = true; + break; + } + } + } + } + + if (found) { + it = newMatches.erase(it); + } + else { + ++it; + } + } + + Log::debug("{}node $, found: {}", std::string(2*newCtx.depth + 2, ' '), newMatches.size()); + } + else if (newCtx.firstSequence && newCtx.firstNode) { // First node of first sequence = root node - for (auto node : mGraph->getNodes()) { + const auto nodes = (newCtx.startNode) ? std::set<NodePtr>{newCtx.startNode} : mGraph->getNodes(); + + for (auto node : nodes) { if ((type.empty() || node->type() == type) && (lambda.empty() || mLambda.at(lambda)(node))) { @@ -627,7 +730,9 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe if (newCtx.lookForChild) { const auto outputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) - ? std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>(1, std::vector<std::pair<NodePtr, IOIndex_t>>(it->startNode->output(newCtx.edgeLeftIdx))) + ? ((newCtx.edgeLeftIdx < it->startNode->nbOutputs()) + ? std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>(1, std::vector<std::pair<NodePtr, IOIndex_t>>(it->startNode->output(newCtx.edgeLeftIdx))) + : std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>()) : it->startNode->outputs(); for (const auto& output : outputs) { @@ -664,7 +769,9 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe } else { const auto inputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) - ? std::vector<std::pair<NodePtr, IOIndex_t>>(1, it->startNode->input(newCtx.edgeLeftIdx)) + ? ((newCtx.edgeLeftIdx < it->startNode->nbInputs()) + ? 
std::vector<std::pair<NodePtr, IOIndex_t>>(1, it->startNode->input(newCtx.edgeLeftIdx)) + : std::vector<std::pair<NodePtr, IOIndex_t>>()) : it->startNode->inputs(); for (const auto& input : inputs) { diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 7fe155b5a2b9b42a1504dbb592b2326d13b99c1e..1035deb366a9c5df6ff08cd87ebd65a11c2b6e78 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -29,8 +29,13 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) mIdInChildren(std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), mIdOutParents( - std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { + std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) +{ // ctor + if (op) { + mForward.push_back([this](){ this->mOperator->forward(); return true; }); + mBackward.push_back([this](){ this->mOperator->backward(); return true; }); + } } /////////////////////////////////////////////////////// @@ -82,13 +87,27 @@ std::string Aidge::Node::createUniqueName(std::string name){ /////////////////////////////////////////////////////// void Aidge::Node::forward() { - assert((mOperator != nullptr) && "No Operator interface provided, can't run forward().\n"); - mOperator->forward(); + for (auto it = mForward.begin(); it != mForward.end(); ) { + const auto keep = (*it)(); + if (!keep) { + it = mForward.erase(it); + } + else { + ++it; + } + } } void Aidge::Node::backward() { - assert((mOperator != nullptr) && "No Operator interface provided, can't run backward().\n"); - mOperator->backward(); + for (auto it = mBackward.begin(); it != mBackward.end(); ) { + const auto keep = (*it)(); + if (!keep) { + it = mBackward.erase(it); + } + else { + ++it; + } + } } /////////////////////////////////////////////////////// @@ -196,7 +215,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) auto originalParent = input(inId); 
// remove original parent reference to child // find the output ID for original Parent - // find first occurence of child in the output's children + // find first occurrence of child in the output's children originalParent.first->removeChild(shared_from_this(), originalParent.second); } mIdOutParents[inId] = newNodeoutId; diff --git a/src/operator/Abs.cpp b/src/operator/Abs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a8ee706f6c993362e2569b6be86f5e17545ae679 --- /dev/null +++ b/src/operator/Abs.cpp @@ -0,0 +1,25 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/operator/Abs.hpp" + +#include <string> + +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +const std::string Aidge::Abs_Op::Type = "Abs"; + +void Aidge::Abs_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + SET_IMPL_MACRO(Abs_Op, *this, name); + mOutputs[0]->setBackend(name, device); +} diff --git a/src/recipes/FuseBatchNorm.cpp b/src/recipes/FuseBatchNorm.cpp index e1553fda551795a0b6f0334ccf1dbd3d2b760085..34722c19f8c0fddaffa7357136f1512a027e1617 100644 --- a/src/recipes/FuseBatchNorm.cpp +++ b/src/recipes/FuseBatchNorm.cpp @@ -90,13 +90,13 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, meanVariance += b_var.get<float>(outChId); ++count; } else { - fmt::print("Zero-variance: {} [{}]\n", convNode->name(), outChId); + Log::notice("Zero-variance: {} [{}]\n", convNode->name(), outChId); } } if (count > 0) meanVariance /= count; else { - fmt::print("Warning: variance < 1e-12 for all 
outputs! Is the network correctly trained?\n"); + Log::notice("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n"); } // Add bias if it is non existant, as there will be a bias after the fuse diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index 6abb4d37114d0952feb13c6cfbee66bd65dc5748..2fdcd611d378ceb6c3dbdc853920eecf92c31141 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -18,6 +18,8 @@ #include "aidge/graph/Testing.hpp" #include "aidge/graph/OpArgs.hpp" #include "aidge/operator/Add.hpp" +#include "aidge/operator/BatchNorm.hpp" +#include "aidge/operator/FC.hpp" #include "aidge/operator/ReLU.hpp" #include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/operator/Producer.hpp" @@ -27,7 +29,7 @@ using namespace Aidge; void checkMatches(const std::set<SinglePassGraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) { - REQUIRE(results.size() == expected.size()); + CHECK(results.size() == expected.size()); for (const auto& result : results) { const auto found = nodePtrTo(result.graph->getNodes(), nodePtrToName); @@ -347,6 +349,94 @@ TEST_CASE("[core/graph] Matching") { }); } + auto g2 = Sequential({ + Producer({16, 3, 512, 512}, "dataProvider"), + Conv(3, 4, {5, 5}, "conv1"), + BatchNorm<2>(4, 1.0e-5, 0.1, "bn1"), + Conv(4, 4, {5, 5}, "conv2"), + ReLU("relu2"), + Conv(4, 4, {5, 5}, "conv3"), + BatchNorm<2>(4, 1.0e-5, 0.1, "bn3"), + FC(4, 4, false, "fc1"), + FC(4, 4, false, "fc2"), + FC(4, 4, false, "fc3"), + ReLU("relu3"), + Conv(1, 4, {5, 5}, "conv4") + }); + + SECTION("((Conv#->(.[exBN]|$))|(FC#->(.[exFC])*->$))") { + auto gm = SinglePassGraphMatching(g2); + gm.addNodeLambda("exBN", [](const NodePtr& node) { + return (node->type() != "BatchNorm"); + }); + gm.addNodeLambda("exFC", [](const NodePtr& node) { + return (node->type() != "FC"); + }); + + const auto results = 
gm.match("((Conv#->(.[exBN]|$))|(FC#->(.[exFC])*->$))"); + + checkMatches(results, { + {"conv2", {"conv2", "relu2"}}, + {"conv4", {"conv4"}}, + {"fc3", {"fc3", "relu3", "conv4"}} + }); + } + + // Find last node of a type + SECTION("FC#->(.[exFC])*->$") { + auto gm = SinglePassGraphMatching(g2); + gm.addNodeLambda("exFC", [](const NodePtr& node) { + return (node->type() != "FC"); + }); + + const auto results = gm.match("FC#->(.[exFC])*->$"); + + checkMatches(results, { + {"fc3", {"fc3", "relu3", "conv4"}} + }); + } + + SECTION("Conv#->(.[exConv])*->$") { + auto gm = SinglePassGraphMatching(g2); + gm.addNodeLambda("exConv", [](const NodePtr& node) { + return (node->type() != "Conv"); + }); + + const auto results = gm.match("Conv#->(.[exConv])*->$"); + + checkMatches(results, { + {"conv4", {"conv4"}} + }); + } + + // Find first node of a type + SECTION("FC#<-(.[exFC])*<-$") { + auto gm = SinglePassGraphMatching(g2); + gm.addNodeLambda("exFC", [](const NodePtr& node) { + return (node->type() != "FC"); + }); + + const auto results = gm.match("FC#<-(.[exFC])*<-$"); + + checkMatches(results, { + {"fc1", {"fc1", "bn3", "conv3", "relu2", "conv2", "bn1", "conv1", "dataProvider"}} + }); + } + + SECTION("(((FC#|Conv#)<-(.[exParam])*<-$)|((FC#|Conv#)->(.[exParam])*->$));(FC#|Conv#)<1-Producer#") { + auto gm = SinglePassGraphMatching(g2); + gm.addNodeLambda("exParam", [](const NodePtr& node) { + return (node->type() != "FC" && node->type() != "Conv"); + }); + + const auto results = gm.match("(((FC#|Conv#)<-(.[exParam])*<-$)|((FC#|Conv#)->(.[exParam])*->$));(FC#|Conv#)<1-Producer#"); + + checkMatches(results, { + {"conv1", {"conv1", "conv1_w", "dataProvider"}}, + {"conv4", {"conv4", "conv4_w"}} + }); + } + SECTION("Conv->ReLU [perf]") { const size_t nbTests = 3; std::mt19937::result_type seed(1);