diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 0fe66e4b64e4113901db2bcd525e1895e642c6de..813301a144682ba3e99de31ae324ffaedcc5209f 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -96,7 +96,7 @@ public:
      * specified location.
      * @param path
      */
-    void save(std::string path, bool verbose = false) const;
+    void save(std::string path, bool verbose = false, bool showProducers = true) const;
 
     inline bool inView(NodePtr nodePtr) const {
         return mNodes.find(nodePtr) != mNodes.end();
diff --git a/include/aidge/operator/Erf.hpp b/include/aidge/operator/Erf.hpp
index 6395756f3b08c5838d390ab45d38fa9c03cb91cb..6995cea5e4af9a17cf3d24516d9840850e701669 100644
--- a/include/aidge/operator/Erf.hpp
+++ b/include/aidge/operator/Erf.hpp
@@ -51,12 +51,9 @@ public:
         return std::make_shared<Erf_Op>(*this);
     }
 
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Erf_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp
index f8276222811f6cc02c062d85e7ae99d72edead7a..20082eed28825ade9d62fb5d4e081840d3bd4442 100644
--- a/include/aidge/operator/Gather.hpp
+++ b/include/aidge/operator/Gather.hpp
@@ -40,7 +40,7 @@ public:
 
     Gather_Op() = delete;
 
-
+    using Attributes_ = StaticAttributes<GatherAttr, int>;
     template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
 
     Gather_Op(int axis)
@@ -70,13 +70,9 @@ public:
 
     void computeOutputDims() override final;
 
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Gather_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
-        getInput(1)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp
index 1fe050b295e102bcdd4e5bd3651d126754b79618..5955d860a2e9a0db9bb296552927c40eb411f30d 100644
--- a/include/aidge/operator/MetaOperator.hpp
+++ b/include/aidge/operator/MetaOperator.hpp
@@ -46,11 +46,11 @@ public:
         return std::make_shared<MetaOperator_Op>(*this);
     }
 
-    const std::shared_ptr<GraphView>& getMicroGraph() const {
+    inline const std::shared_ptr<GraphView>& getMicroGraph() const noexcept {
         return mGraph;
     }
 
-    const std::shared_ptr<SequentialScheduler>& getMicroGraphScheduler() const {
+    inline const std::shared_ptr<SequentialScheduler>& getMicroGraphScheduler() const noexcept {
         return mScheduler;
     }
 
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index dd4ad16441f536fd786036672d57817b892cf155..cebc2d54041bb38c6e7f3434f12b559cec3d80af 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -73,12 +73,18 @@ public:
 public:
     virtual std::shared_ptr<Operator> clone() const = 0;
 
+    /**
+     * @brief Set the specified input with a shallow copy.
+     * @param inputIdx Index of the input to set.
+     * @param data Data to associate (the pointer is shared, not deep-copied).
+     */
     virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
 
     /**
     * @brief Set the specified input by performing a deep copy of the given data.
     * The pointer itself is not changed, thus keeping the current connections.
     * @param inputIdx Index of the input to set.
+    * @param data Data to copy.
     */
     virtual void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
     virtual void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) = 0;
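Note: the recurring change in the operator headers above, and again in ReduceMean, Reshape and Transpose below, is that setBackend() gains a defaulted DeviceIdx_t device parameter and stops forcing the input tensors onto the backend (the removed "FIXME: temporary workaround" blocks). A minimal standalone sketch, not Aidge code, of why the defaulted parameter keeps existing call sites source-compatible:

    #include <cstdint>
    #include <iostream>
    #include <string>

    using DeviceIdx_t = std::uint8_t; // stand-in alias for the type name used in the diff

    struct ToyOperator {
        // Mirrors the new signature: callers may omit the device index.
        virtual void setBackend(const std::string& name, DeviceIdx_t device = 0) {
            std::cout << "backend=" << name << " device=" << int(device) << '\n';
        }
        virtual ~ToyOperator() = default;
    };

    int main() {
        ToyOperator op;
        op.setBackend("cpu");      // legacy call site still compiles, device = 0
        op.setBackend("cuda", 1);  // new call sites can select a device explicitly
        return 0;
    }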
diff --git a/include/aidge/operator/ReduceMean.hpp b/include/aidge/operator/ReduceMean.hpp
index 0acd21b28fac54e7e6d30e8219ead0e04ef777f6..52d0118743373c23a4afe4a51d3f22adbe9e6848 100644
--- a/include/aidge/operator/ReduceMean.hpp
+++ b/include/aidge/operator/ReduceMean.hpp
@@ -89,7 +89,7 @@ class ReduceMean_Op : public OperatorTensor,
                 }
                 else
                     outDims.push_back(getInput(0)->dims()[d]);
-            }
+            }
             if(outDims.size()>0)
                 mOutputs[0]->resize(outDims);
             else
@@ -97,12 +97,9 @@
         }
     }
 
-    void setBackend(const std::string &name) override {
+    void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<ReduceMean_Op<DIM>>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp
index 1ffa045960037f35167ae2d6e8904c49e2c55560..32d71d5adc3cfd92c9840dcb5bc61bfb6399c6db 100644
--- a/include/aidge/operator/Reshape.hpp
+++ b/include/aidge/operator/Reshape.hpp
@@ -66,12 +66,9 @@ public:
 
     void computeOutputDims() override final;
 
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Reshape_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Transpose.hpp b/include/aidge/operator/Transpose.hpp
index f111be76cd712265e92e2e4c3e0220f79e13b1f7..2262bec14bd2f00cda643ade0709f7f9d509fa22 100644
--- a/include/aidge/operator/Transpose.hpp
+++ b/include/aidge/operator/Transpose.hpp
@@ -79,12 +79,9 @@ class Transpose_Op : public OperatorTensor,
         }
     }
 
-    void setBackend(const std::string &name) override {
+    void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Transpose_Op<DIM>>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index 32151a66a46f7d7da73473c90effa760ebc93891..8e0da01c89767844040fcbc7b48e727800436daa 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -23,7 +23,7 @@ namespace Aidge {
 void init_GraphView(py::module& m) {
     py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView")
           .def(py::init<>())
-          .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false,
+          .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false, py::arg("show_producers") = true,
           R"mydelimiter(
           Save the GraphView as a Mermaid graph in a
           .md file at the specified location.
@@ -97,7 +97,7 @@ void init_GraphView(py::module& m) {
           .def("get_nodes", &GraphView::getNodes)
           .def("get_node", &GraphView::getNode, py::arg("node_name"))
           .def("forward_dims", &GraphView::forwardDims)
-          .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"))
+          .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0)
           .def("__call__", &GraphView::operator(), py::arg("connectors"))
           .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
           .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
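Note: the two bindings above only forward the new defaulted C++ parameters, so save(..., show_producers=...) and compile(..., device=...) become optional keyword arguments in Python. A condensed pybind11 sketch of that pattern, with a toy module and function rather than the Aidge bindings:

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    // pybind11 does not see C++ default arguments, so the default must be
    // restated on the py::arg; the argument then becomes optional in Python:
    // toy.compile(3) == toy.compile(3, device=0).
    int compile(int backend, int device = 0) { return backend * 10 + device; }

    PYBIND11_MODULE(toy, m) {
        m.def("compile", &compile, py::arg("backend"), py::arg("device") = 0);
    }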
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index c2439a459dcbe1b53d6aa31fd467ca3cd137aa23..968e98e75cc587977eb3033fe7f25936880755a4 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -55,7 +55,7 @@ std::string Aidge::GraphView::name() const { return mName; }
 
 void Aidge::GraphView::setName(const std::string &name) { mName = name; }
 
-void Aidge::GraphView::save(std::string path, bool verbose) const {
+void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) const {
   FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
   std::fprintf(fp,
                "%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
@@ -68,7 +68,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
   for (const std::shared_ptr<Node> &node_ptr : mNodes) {
     const std::string currentType = node_ptr->type();
     if (typeCounter.find(currentType) == typeCounter.end())
-      typeCounter[currentType] = 0;
+      typeCounter[currentType] = 0;
     ++typeCounter[currentType];
 
     std::string givenName =
@@ -83,13 +83,18 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
                    givenName.c_str());
     }
     else {
-      std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
-                   givenName.c_str());
+      if ((currentType != "Producer") || showProducers) {
+        std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
+                     givenName.c_str());
+      }
     }
   }
 
   // Write every link
   for (const std::shared_ptr<Node> &node_ptr : mNodes) {
+    if ((node_ptr->type() == "Producer") && !showProducers) {
+      continue;
+    }
     IOIndex_t outputIdx = 0;
     for (auto childs : node_ptr->getOrderedChildren()) {
       for (auto child : childs) {
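Note: save() now filters Producer nodes twice, once when declaring the Mermaid nodes and once before emitting edges, so weight and constant producers can be hidden from large graph renders. A standalone sketch of the same filtering logic with toy types, not the Aidge implementation:

    #include <cstdio>
    #include <string>
    #include <vector>

    struct ToyNode { std::string id; std::string type; };

    void saveMermaid(const std::vector<ToyNode>& nodes, bool showProducers) {
        std::printf("flowchart TB\n");
        for (const auto& n : nodes) {
            if ((n.type == "Producer") && !showProducers) {
                continue; // same cut as in the diff: hidden producers are skipped entirely
            }
            std::printf("%s(%s)\n", n.id.c_str(), n.type.c_str());
        }
    }

    int main() {
        saveMermaid({{"w0", "Producer"}, {"conv1", "Conv"}}, /*showProducers=*/false);
        return 0;
    }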
diff --git a/src/graphRegex/GraphRegex.cpp b/src/graphRegex/GraphRegex.cpp
index 9a9b53da615f77dbdb8e597763411a2e84920b2a..00a031e3fa9b03ff1870446b9ae58e8d3eb65bf7 100644
--- a/src/graphRegex/GraphRegex.cpp
+++ b/src/graphRegex/GraphRegex.cpp
@@ -1,5 +1,5 @@
 #include "aidge/graphRegex/GraphRegex.hpp"
-using namespace Aidge; 
+using namespace Aidge;
 
 
 void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){
@@ -27,7 +27,7 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){
 
 
 // void GraphRegex::addQuery(const std::string query){
-//     //TODO one query only but the same string is a same query but 
+//     //TODO one query only but the same string is a same query but
 //     //2 different string it's maybe the same query , we need to check the AST
 //     mQueryRecipe[query] = nullptr;
 // }
@@ -52,7 +52,7 @@ void GraphRegex::_generateCombinationsStart(const std::set<NodePtr>& elements, s
     }
 }
 
-
+// factorial(n) tree search, optimized with a stopping condition
 void GraphRegex::_findLargestCompatibleSet(
     const std::vector<std::shared_ptr<MatchSolution>>& solutions,
     std::set<std::shared_ptr<MatchSolution>>& currentSet,
@@ -75,6 +75,10 @@ void GraphRegex::_findLargestCompatibleSet(
             currentSet.insert(solutions[i]);
             _findLargestCompatibleSet(solutions, currentSet, largestSet, i + 1);
             currentSet.erase(solutions[i]);
+            // prune: the remaining candidates can no longer beat the largest set found
+            if ((currentSet.size() + solutions.size() - currentIndex) <= largestSet.size()) {
+                return;
+            }
         }
     }
 }
@@ -101,14 +105,14 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
         std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest);
         std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();
 
-        // generate all the start possibility 
+        // generate all the possible start combinations
         std::size_t nb_startSt = fsm->getNbStart();
         std::set<std::vector<NodePtr>> combinations;
         std::vector<NodePtr> current;
         _generateCombinationsStart(ref->getNodes(), nb_startSt, 0, current, combinations);
-        
-        // all start 
+
+        // test every start combination
         for (const auto& combination : combinations) {
             std::vector<std::shared_ptr<MatchSolution>> solution = fsm->test(combination);
             solutions.insert(solutions.end(), solution.begin(), solution.end());
         }
@@ -133,7 +137,7 @@ void GraphRegex::setNodeKey(const std::string key, const std::string conditional
 
 
 void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f){
-    //we can applied to all key but it's not efficient 
+    // we could apply this to every key, but it would not be efficient
     if(mAllLambda.find(key) != mAllLambda.end()){
         throw std::runtime_error(key + " is define");
     }
@@ -142,7 +146,7 @@ void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f
 }
 
 void GraphRegex::_majConditionalInterpreterLambda(){
-    
+
     for (const auto& test : mAllTest) {
         for (const auto& pair : mAllLambda) {
             const std::string& key = pair.first;
@@ -151,7 +155,7 @@ void GraphRegex::_majConditionalInterpreterLambda(){
             if(!test->isLambdaRegister(key)){
                 test->insertLambda(key,lambda);
             }
-            
+
         }
     }
 }
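Note: the early return added to _findLargestCompatibleSet() is a branch-and-bound cut: once the current set plus all candidates not yet examined can no longer exceed the best set found, the whole branch is abandoned. A standalone sketch of the same bound, using integers and a stand-in compatibility predicate instead of MatchSolution:

    #include <cstdio>
    #include <set>
    #include <vector>

    bool compatible(int a, int b) { return (a % 2) == (b % 2); } // stand-in predicate

    void findLargest(const std::vector<int>& xs, std::set<int>& cur,
                     std::set<int>& best, std::size_t start) {
        if (cur.size() > best.size()) best = cur;
        for (std::size_t i = start; i < xs.size(); ++i) {
            bool ok = true;
            for (int m : cur) ok = ok && compatible(m, xs[i]);
            if (ok) {
                cur.insert(xs[i]);
                findLargest(xs, cur, best, i + 1);
                cur.erase(xs[i]);
                // Bound: even if every remaining candidate were compatible,
                // this branch cannot beat `best` any more, so stop exploring.
                if (cur.size() + xs.size() - i <= best.size()) return;
            }
        }
    }

    int main() {
        std::set<int> cur, best;
        findLargest({1, 2, 3, 5, 7, 8}, cur, best, 0);
        std::printf("largest compatible set: %zu elements\n", best.size());
        return 0;
    }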
diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp
index 9c4cad3f7a444c627f2324f729cb3bc3d8517f49..2fb017567550ada083d0d79d0323b0b45998026f 100644
--- a/src/recipies/FuseBatchNorm.cpp
+++ b/src/recipies/FuseBatchNorm.cpp
@@ -8,30 +8,66 @@
  * SPDX-License-Identifier: EPL-2.0
  *
  ********************************************************************************/
-#include <set>
 #include <cassert>
 #include <memory>
+#include <set>
 #include <string>
 
-#include "aidge/operator/FC.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 #include "aidge/operator/Conv.hpp"
+#include "aidge/operator/ConvDepthWise.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/operator/MetaOperator.hpp"
 #include "aidge/recipies/Recipies.hpp"
-#include "aidge/graph/GraphView.hpp"
-#include "aidge/graph/Node.hpp"
-#include "aidge/operator/Producer.hpp"
-#include "aidge/operator/GenericOperator.hpp"
-
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Types.h"
 
-//Graph Regex
+// Graph Regex
 #include "aidge/graphRegex/GraphRegex.hpp"
 
-void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr<Aidge::Node> batchnormNode) {
+void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
+                          std::shared_ptr<Aidge::Node> batchnormNode) {
+    // Case: convNode is a MetaOperator ending with a convolution,
+    // e.g. PaddedConv
+    if (!(convNode -> getOperator() -> isAtomic())) {
+        const std::shared_ptr<MetaOperator_Op> metaNode = std::static_pointer_cast<MetaOperator_Op>(convNode -> getOperator());
+        const std::shared_ptr<GraphView> metanodeGraph = metaNode -> getMicroGraph();
+        const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputNodes = metanodeGraph -> getOrderedOutputs();
+        if (outputNodes.size() != 1) {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Bad MetaOperator argument for fuseBatchNorm recipe.");
+        }
+        convNode = outputNodes[0].first;
+    }
+
+    AIDGE_ASSERT(((convNode->type() == Conv_Op<2>::Type) || (convNode->type() == ConvDepthWise_Op<2>::Type)), "Wrong type for convolution node.");
+    AIDGE_ASSERT(batchnormNode->type() == BatchNorm_Op<2>::Type, "Wrong type for batchnorm node.");
 
     // TODO: Find a way to remove the template
     // A feature map with 2 dimensions is assumed
-    const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
-    const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
+    const std::shared_ptr<BatchNorm_Op<2>> batchOp =
+        std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
+
+    DimSize_t convNbOutChannels;
+    DimSize_t channelsSize;
+    std::array<DimSize_t, 2> kernelDims;
+    std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
+    if (convNode->type() == Conv_Op<2>::Type) {
+        const std::shared_ptr<Conv_Op<2>> convOpPtr =
+            std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
+        convNbOutChannels = convOpPtr->getAttr<DimSize_t>("OutChannels");
+        channelsSize = convOpPtr->getAttr<DimSize_t>("InChannels");
+        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
+    }
+    else if (convNode->type() == ConvDepthWise_Op<2>::Type) {
+        const std::shared_ptr<ConvDepthWise_Op<2>> convOpPtr =
+            std::static_pointer_cast<ConvDepthWise_Op<2>>(convNode->getOperator());
+        convNbOutChannels = convOpPtr->getAttr<DimSize_t>("Channels");
+        channelsSize = 1;
+        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
+    }
 
     std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
     const Tensor& scale = batchOp->getInput(1)->refCastFrom(scaleBuf, DataType::Float32, "cpu");
@@ -39,20 +75,12 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     const Tensor& shift = batchOp->getInput(2)->refCastFrom(shiftBuf, DataType::Float32, "cpu");
     const Tensor& b_mean = batchOp->getInput(3)->refCastFrom(b_meanBuf, DataType::Float32, "cpu");
     const Tensor& b_var = batchOp->getInput(4)->refCastFrom(b_varBuf, DataType::Float32, "cpu");
 
-    const float epsilon = batchOp -> getAttr<float>("Epsilon");
-    const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels");
-    const DimSize_t channelsSize = convOp -> getAttr<DimSize_t>("InChannels");
-    const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims");
+    const float epsilon = batchOp->getAttr<float>("Epsilon");
 
-    assert(scale.size() == convNbOutChannels);
-    assert(shift.size() == convNbOutChannels);
-    assert(b_mean.size() == convNbOutChannels);
-    assert(b_var.size() == convNbOutChannels);
     assert(epsilon > 0.0);
     // TODO : no no_bias attribute ?
-
 
     float meanVariance = 0.0;
     unsigned int count = 0;
 
@@ -60,8 +88,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
         if (b_var.get<float>(outChId) > 1.0e-12) {
            meanVariance += b_var.get<float>(outChId);
             ++count;
-        }
-        else {
+        } else {
             printf("Zero-variance: %s [%lu]\n", convNode->name().c_str(), outChId);
         }
     }
@@ -86,8 +113,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
         // Weights adjustments
         for (std::size_t channel = 0; channel < channelsSize; ++channel) {
             // TODO : Suppose kerneldims = 2
-            for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){
-                for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){
+            for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
+                for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
                     std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1};
                     float weightValue = weight.get<float>(currentIdx);
                     weight.set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights
@@ -119,24 +146,36 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
 }
 
 void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
-
     assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
     assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
 
     for (const auto& op : solution->at("OP")) {
-        for (const auto& batchNorm : solution->at("BatchNorm")) {
-            fuseBatchNorm(op,batchNorm);
+        if (op->getOperator()->isAtomic()) {
+            for (const auto& batchNorm : solution->at("BatchNorm")) {
+                fuseBatchNorm(op, batchNorm);
+            }
+        } else { // op is a MetaOperator
+            auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
+            if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) &&
+                ((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
+                  Conv_Op<2>::Type) ||
+                 (metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
+                  ConvDepthWise_Op<2>::Type))) {
+                for (const auto& batchNorm : solution->at("BatchNorm")) {
+                    fuseBatchNorm(op, batchNorm);
+                }
+            }
         }
     }
-
 }
 
 void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
-
-
     std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
-    regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'");
-    regex->setNodeKey("OP","getType($) =='Conv'");// || getType($) =='FC' ");
+    regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
+    regex->setNodeKey(
+        "OP",
+        "getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
+    // || getType($) =='FC' ");
 
     regex->addQuery("OP -> BatchNorm");
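Note: the hunks above only show the bookkeeping around fuseBatchNorm(); the arithmetic it relies on is the standard batch-norm folding. Per output channel, with factor = scale / sqrt(var + eps), applying batch normalization to conv(x, w) + b is equivalent to conv(x, w * factor) + (shift + (b - mean) * factor). A standalone numeric check of that identity with a scalar 1x1 "convolution" and made-up constants:

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
        const float w = 0.5f, b = 0.2f;                    // conv weight and bias
        const float scale = 1.5f, shift = -0.1f;           // batchnorm affine terms
        const float mean = 0.3f, var = 0.04f, eps = 1e-5f; // batchnorm running stats
        const float factor = scale / std::sqrt(var + eps);

        const float x = 2.0f;
        const float bn    = scale * ((w * x + b) - mean) / std::sqrt(var + eps) + shift;
        const float fused = (w * factor) * x + (shift + (b - mean) * factor);

        std::printf("conv+batchnorm: %f, fused conv: %f\n", bn, fused);
        assert(std::fabs(bn - fused) < 1e-4f);
        return 0;
    }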