Commit 6c460029 authored by Maxence Naud
Merge branch 'tiling' into 'main'

Update fuseBatchNorm to include ConvDepthWise and MetaOperators

See merge request !62
Parents: 3d6d2bcf, ab59233c
Merge requests: !63 Temporary main branch, !62 Update fuseBatchNorm to include ConvDepthWise and MetaOperators
Pipeline #36713 passed
@@ -96,7 +96,7 @@ public:
      * specified location.
      * @param path
      */
-    void save(std::string path, bool verbose = false) const;
+    void save(std::string path, bool verbose = false, bool showProducers = true) const;
     inline bool inView(NodePtr nodePtr) const {
         return mNodes.find(nodePtr) != mNodes.end();
...
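The new showProducers flag lets callers omit Producer nodes (weight and bias constants) from the rendered Mermaid graph. A minimal usage sketch — the exportGraph helper and the graph gv are illustrative, not part of the commit:

#include "aidge/graph/GraphView.hpp"
#include <memory>

// Sketch: only the save() calls demonstrate the changed API.
void exportGraph(const std::shared_ptr<Aidge::GraphView>& gv) {
    gv->save("model_full");  // default: producers shown, as before this change
    gv->save("model_ops_only", /*verbose=*/false, /*showProducers=*/false);  // hide weight/bias Producers
}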
@@ -51,12 +51,9 @@ public:
         return std::make_shared<Erf_Op>(*this);
     }
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Erf_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
     static const std::vector<std::string> getInputsName(){
...
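The same change — a DeviceIdx_t device = 0 parameter forwarded to the output tensor, with the temporary workaround that forced input tensors onto the backend removed — is applied identically to Gather, ReduceMean, Reshape and Transpose in the hunks below. A hedged sketch of what it enables; the "cuda" backend name and device index are assumptions for illustration:

// Sketch: place an operator and its output on a specific device of a backend.
auto erf = Aidge::Erf("erf0");              // assuming the usual node factory
erf->getOperator()->setBackend("cuda", 1);  // the output tensor follows onto device 1
// Inputs are no longer forcibly moved: they keep their producer's backend/device.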
@@ -40,7 +40,7 @@ public:
     Gather_Op() = delete;
     using Attributes_ = StaticAttributes<GatherAttr, int>;
     template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
     Gather_Op(int axis)
@@ -70,13 +70,9 @@ public:
     void computeOutputDims() override final;
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Gather_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
-        getInput(1)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
     static const std::vector<std::string> getInputsName(){
...
@@ -46,11 +46,11 @@ public:
         return std::make_shared<MetaOperator_Op>(*this);
     }
-    const std::shared_ptr<GraphView>& getMicroGraph() const {
+    inline const std::shared_ptr<GraphView>& getMicroGraph() const noexcept {
         return mGraph;
     }
-    const std::shared_ptr<SequentialScheduler>& getMicroGraphScheduler() const {
+    inline const std::shared_ptr<SequentialScheduler>& getMicroGraphScheduler() const noexcept {
         return mScheduler;
     }
...
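These accessors are what the updated fuseBatchNorm recipe below relies on to look inside a meta-operator such as PaddedConv. A minimal sketch; the node variable is an assumed NodePtr:

// Sketch: reach the inner node that produces a meta-operator's output.
auto metaOp = std::dynamic_pointer_cast<Aidge::MetaOperator_Op>(node->getOperator());
if (metaOp) {
    const auto& microGraph = metaOp->getMicroGraph();
    for (const auto& output : microGraph->getOrderedOutputs()) {
        // output.first is the inner node driving this output of the meta-operator
        printf("inner output node type: %s\n", output.first->type().c_str());
    }
}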
@@ -73,12 +73,18 @@ public:
 public:
     virtual std::shared_ptr<Operator> clone() const = 0;
+    /**
+     * @brief Set the specified input with a shallow copy.
+     * @param inputIdx Index of the input to set.
+     * @param data Data to copy.
+     */
     virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
     /**
      * @brief Set the specified input by performing a deep copy of the given data.
      * The pointer itself is not changed, thus keeping the current connections.
      * @param inputIdx Index of the input to set.
+     * @param data Data to copy.
      */
     virtual void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
     virtual void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) = 0;
...
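The new documentation makes the contrast explicit: associateInput stores a shallow pointer copy (the input slot is rewired to point at the given data), while setInput deep-copies the data into the existing input tensor, leaving the pointer and therefore the graph connections untouched. A hedged illustration; op and t are an assumed operator and Tensor:

// Sketch of the two input-setting semantics.
std::shared_ptr<Aidge::Tensor> t = /* ... */;
op->associateInput(0, t);  // shallow: input 0 now points at t itself
op->setInput(0, t);        // deep: t's content is copied into the existing input;
                           // the input pointer (and the wiring) is unchanged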
@@ -89,7 +89,7 @@ class ReduceMean_Op : public OperatorTensor,
             }
             else
                 outDims.push_back(getInput(0)->dims()[d]);
         }
         if(outDims.size()>0)
             mOutputs[0]->resize(outDims);
         else
@@ -97,12 +97,9 @@ class ReduceMean_Op : public OperatorTensor,
         }
     }
-    void setBackend(const std::string &name) override {
+    void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<ReduceMean_Op<DIM>>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
     static const std::vector<std::string> getInputsName(){
...
@@ -66,12 +66,9 @@ public:
     void computeOutputDims() override final;
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Reshape_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
     static const std::vector<std::string> getInputsName(){
...
@@ -79,12 +79,9 @@ class Transpose_Op : public OperatorTensor,
         }
     }
-    void setBackend(const std::string &name) override {
+    void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Transpose_Op<DIM>>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
     static const std::vector<std::string> getInputsName(){
...
@@ -23,7 +23,7 @@ namespace Aidge {
 void init_GraphView(py::module& m) {
     py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView")
         .def(py::init<>())
-        .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false,
+        .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false, py::arg("show_producers") = true,
         R"mydelimiter(
         Save the GraphView as a Mermaid graph in a .md file at the specified location.
@@ -97,7 +97,7 @@ void init_GraphView(py::module& m) {
         .def("get_nodes", &GraphView::getNodes)
         .def("get_node", &GraphView::getNode, py::arg("node_name"))
         .def("forward_dims", &GraphView::forwardDims)
-        .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"))
+        .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0)
         .def("__call__", &GraphView::operator(), py::arg("connectors"))
         .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
         .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
...
@@ -55,7 +55,7 @@ std::string Aidge::GraphView::name() const { return mName; }
 void Aidge::GraphView::setName(const std::string &name) { mName = name; }
-void Aidge::GraphView::save(std::string path, bool verbose) const {
+void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) const {
     FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
     std::fprintf(fp,
             "%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
@@ -68,7 +68,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
     for (const std::shared_ptr<Node> &node_ptr : mNodes) {
         const std::string currentType = node_ptr->type();
         if (typeCounter.find(currentType) == typeCounter.end())
             typeCounter[currentType] = 0;
         ++typeCounter[currentType];
         std::string givenName =
@@ -83,13 +83,18 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
                     givenName.c_str());
         }
         else {
-            std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
-                    givenName.c_str());
+            if ((currentType != "Producer") || showProducers) {
+                std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
+                        givenName.c_str());
+            }
         }
     }
     // Write every link
     for (const std::shared_ptr<Node> &node_ptr : mNodes) {
+        if ((node_ptr->type() == "Producer") && !showProducers) {
+            continue;
+        }
         IOIndex_t outputIdx = 0;
         for (auto childs : node_ptr->getOrderedChildren()) {
             for (auto child : childs) {
...
#include "aidge/graphRegex/GraphRegex.hpp" #include "aidge/graphRegex/GraphRegex.hpp"
using namespace Aidge; using namespace Aidge;
void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){
...@@ -27,7 +27,7 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){ ...@@ -27,7 +27,7 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){
// void GraphRegex::addQuery(const std::string query){ // void GraphRegex::addQuery(const std::string query){
// //TODO one query only but the same string is a same query but // //TODO one query only but the same string is a same query but
// //2 different string it's maybe the same query , we need to check the AST // //2 different string it's maybe the same query , we need to check the AST
// mQueryRecipe[query] = nullptr; // mQueryRecipe[query] = nullptr;
// } // }
...@@ -52,7 +52,7 @@ void GraphRegex::_generateCombinationsStart(const std::set<NodePtr>& elements, s ...@@ -52,7 +52,7 @@ void GraphRegex::_generateCombinationsStart(const std::set<NodePtr>& elements, s
} }
} }
// factorial(n) tree searched optimized with a stopping condition
void GraphRegex::_findLargestCompatibleSet( void GraphRegex::_findLargestCompatibleSet(
const std::vector<std::shared_ptr<MatchSolution>>& solutions, const std::vector<std::shared_ptr<MatchSolution>>& solutions,
std::set<std::shared_ptr<MatchSolution>>& currentSet, std::set<std::shared_ptr<MatchSolution>>& currentSet,
...@@ -75,6 +75,10 @@ void GraphRegex::_findLargestCompatibleSet( ...@@ -75,6 +75,10 @@ void GraphRegex::_findLargestCompatibleSet(
currentSet.insert(solutions[i]); currentSet.insert(solutions[i]);
_findLargestCompatibleSet(solutions, currentSet, largestSet, i + 1); _findLargestCompatibleSet(solutions, currentSet, largestSet, i + 1);
currentSet.erase(solutions[i]); currentSet.erase(solutions[i]);
// cut the size of the graph of possibilities
if ((currentSet.size() + solutions.size() - currentIndex) <= largestSet.size()) {
return;
}
} }
} }
} }
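The added guard is a classic branch-and-bound cut: past currentIndex at most solutions.size() - currentIndex candidates remain, so if even taking all of them cannot exceed the best set found so far, the branch is abandoned. A self-contained sketch of the same idea on plain integers; the compatibility predicate is invented for illustration:

#include <cstdlib>
#include <vector>

// Toy analogue of _findLargestCompatibleSet: grow the largest subset of
// pairwise "compatible" items (here: values that differ by more than 1).
static bool compatible(const std::vector<int>& current, int candidate) {
    for (int v : current)
        if (std::abs(v - candidate) <= 1) return false;
    return true;
}

static void search(const std::vector<int>& items, std::size_t index,
                   std::vector<int>& current, std::vector<int>& best) {
    if (current.size() > best.size()) best = current;
    for (std::size_t i = index; i < items.size(); ++i) {
        if (!compatible(current, items[i])) continue;
        current.push_back(items[i]);
        search(items, i + 1, current, best);
        current.pop_back();
        // Branch-and-bound cut, same shape as the recipe's new guard:
        // at most items.size() - index candidates remain on this branch.
        if (current.size() + items.size() - index <= best.size()) return;
    }
}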
@@ -101,14 +105,14 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
     std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest);
     std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();
     // generate all the start possibilities
     std::size_t nb_startSt = fsm->getNbStart();
     std::set<std::vector<NodePtr>> combinations;
     std::vector<NodePtr> current;
     _generateCombinationsStart(ref->getNodes(), nb_startSt, 0, current, combinations);
     // all starts
     for (const auto& combination : combinations) {
         std::vector<std::shared_ptr<MatchSolution>> solution = fsm->test(combination);
         solutions.insert(solutions.end(), solution.begin(), solution.end());
@@ -133,7 +137,7 @@ void GraphRegex::setNodeKey(const std::string key, const std::string conditional
 void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f){
     // we could apply it to every key, but that would not be efficient
     if(mAllLambda.find(key) != mAllLambda.end()){
         throw std::runtime_error(key + " is already defined");
     }
@@ -142,7 +146,7 @@ void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f
 }
 void GraphRegex::_majConditionalInterpreterLambda(){
     for (const auto& test : mAllTest) {
         for (const auto& pair : mAllLambda) {
             const std::string& key = pair.first;
@@ -151,7 +155,7 @@ void GraphRegex::_majConditionalInterpreterLambda(){
             if(!test->isLambdaRegister(key)){
                 test->insertLambda(key,lambda);
             }
         }
     }
 }
...
@@ -8,30 +8,66 @@
  * SPDX-License-Identifier: EPL-2.0
  *
  ********************************************************************************/
-#include <set>
 #include <cassert>
 #include <memory>
+#include <set>
 #include <string>
-#include "aidge/operator/FC.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 #include "aidge/operator/Conv.hpp"
+#include "aidge/operator/ConvDepthWise.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/operator/MetaOperator.hpp"
 #include "aidge/recipies/Recipies.hpp"
-#include "aidge/graph/GraphView.hpp"
-#include "aidge/graph/Node.hpp"
-#include "aidge/operator/Producer.hpp"
-#include "aidge/operator/GenericOperator.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Types.h"
 // Graph Regex
 #include "aidge/graphRegex/GraphRegex.hpp"
-void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr<Aidge::Node> batchnormNode) {
+void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
+                          std::shared_ptr<Aidge::Node> batchnormNode) {
+    // Case: convNode is a MetaOperator ending with a Convolution,
+    // e.g. PaddedConv
+    if (!(convNode->getOperator()->isAtomic())) {
+        const std::shared_ptr<MetaOperator_Op> metaNode = std::static_pointer_cast<MetaOperator_Op>(convNode->getOperator());
+        const std::shared_ptr<GraphView> metanodeGraph = metaNode->getMicroGraph();
+        const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputNodes = metanodeGraph->getOrderedOutputs();
+        if (outputNodes.size() != 1) {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Bad MetaOperator argument for fuseBatchNorm recipie.");
+        }
+        convNode = outputNodes[0].first;
+    }
+    AIDGE_ASSERT(((convNode->type() == Conv_Op<2>::Type) || (convNode->type() == ConvDepthWise_Op<2>::Type)), "Wrong type");
+    AIDGE_ASSERT(batchnormNode->type() == BatchNorm_Op<2>::Type, "Wrong type for batchnorm node.");
     // TODO: Find a way to remove the template
     // A feature map with 2 dimensions is assumed
-    const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
-    const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
+    const std::shared_ptr<BatchNorm_Op<2>> batchOp =
+        std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
+    DimSize_t convNbOutChannels;
+    DimSize_t channelsSize;
+    std::array<DimSize_t, 2> kernelDims;
+    std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
+    if (convNode->type() == Conv_Op<2>::Type) {
+        const std::shared_ptr<Conv_Op<2>> convOpPtr =
+            std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
+        convNbOutChannels = convOpPtr->getAttr<DimSize_t>("OutChannels");
+        channelsSize = convOpPtr->getAttr<DimSize_t>("InChannels");
+        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
+    }
+    else if (convNode->type() == ConvDepthWise_Op<2>::Type) {
+        const std::shared_ptr<ConvDepthWise_Op<2>> convOpPtr =
+            std::static_pointer_cast<ConvDepthWise_Op<2>>(convNode->getOperator());
+        convNbOutChannels = convOpPtr->getAttr<DimSize_t>("Channels");
+        channelsSize = 1;
+        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
+    }
     std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
     const Tensor& scale = batchOp->getInput(1)->refCastFrom(scaleBuf, DataType::Float32, "cpu");
@@ -39,20 +75,12 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     const Tensor& b_mean = batchOp->getInput(3)->refCastFrom(b_meanBuf, DataType::Float32, "cpu");
     const Tensor& b_var = batchOp->getInput(4)->refCastFrom(b_varBuf, DataType::Float32, "cpu");
-    const float epsilon = batchOp -> getAttr<float>("Epsilon");
-    const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels");
-    const DimSize_t channelsSize = convOp -> getAttr<DimSize_t>("InChannels");
-    const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims");
-    assert(scale.size() == convNbOutChannels);
-    assert(shift.size() == convNbOutChannels);
-    assert(b_mean.size() == convNbOutChannels);
-    assert(b_var.size() == convNbOutChannels);
+    const float epsilon = batchOp->getAttr<float>("Epsilon");
     assert(epsilon > 0.0);
     // TODO : no no_bias attribute ?
     float meanVariance = 0.0;
     unsigned int count = 0;
@@ -60,8 +88,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
         if (b_var.get<float>(outChId) > 1.0e-12) {
             meanVariance += b_var.get<float>(outChId);
             ++count;
-        }
-        else {
+        } else {
             printf("Zero-variance: %s [%lu]\n", convNode->name().c_str(), outChId);
         }
     }
@@ -86,8 +113,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
         // Weights adjustments
         for (std::size_t channel = 0; channel < channelsSize; ++channel) {
             // TODO: assumes kernelDims has 2 dimensions
-            for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){
-                for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){
+            for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
+                for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
                     std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1};
                     float weightValue = weight.get<float>(currentIdx);
                     weight.set<float>(currentIdx, weightValue*factor); // TODO: check that this updates the Conv weights in place
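For reference, the rescaling applied by these loops is the standard batch-norm folding identity. With per-channel BN scale \gamma_o, shift \beta_o, running mean \mu_o, variance \sigma_o^2 and the Epsilon attribute \varepsilon, the fused convolution computes the same function as the Conv+BN pair when the weights and bias of output channel o are set to (this is the `factor` computed in the code):

\text{factor}_o = \frac{\gamma_o}{\sqrt{\sigma_o^2 + \varepsilon}}, \qquad
w_o' = \text{factor}_o \, w_o, \qquad
b_o' = \beta_o + \text{factor}_o \,(b_o - \mu_o)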
@@ -119,24 +146,37 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
 }
 void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
     assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
     assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
     for (const auto& op : solution->at("OP")) {
-        for (const auto& batchNorm : solution->at("BatchNorm")) {
-            fuseBatchNorm(op,batchNorm);
+        if (op->getOperator()->isAtomic()) {
+            for (const auto& batchNorm : solution->at("BatchNorm")) {
+                fuseBatchNorm(op, batchNorm);
+            }
+        } else { // op is a MetaOperator
+            auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
+            if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) &&
+                ((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
+                  Conv_Op<2>::Type) ||
+                 (metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
+                  ConvDepthWise_Op<2>::Type))) {
+                for (const auto& batchNorm : solution->at("BatchNorm")) {
+                    fuseBatchNorm(op, batchNorm);
+                }
+            }
         }
     }
 }
 void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
     std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
-    regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'");
-    regex->setNodeKey("OP","getType($) =='Conv'");// || getType($) =='FC' ");
+    regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
+    printf("\n============================\nSearching for solutions\n==============================\n");
+    regex->setNodeKey(
+        "OP",
+        "getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
+    // || getType($) =='FC' ");
     regex->addQuery("OP -> BatchNorm");
...
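End to end, the recipe now covers plain and depthwise convolutions, wrapped or not in a padding meta-operator. A hedged usage sketch; the model GraphView is assumed to contain Conv/ConvDepthWise (or PaddedConv/PaddedConvDepthWise) nodes followed by BatchNorm:

#include "aidge/recipies/Recipies.hpp"

// Sketch: match every (Padded)Conv(DepthWise) -> BatchNorm pair and fold the BN away.
Aidge::fuseBatchNorm(model);
// Inspect the result without the weight/bias Producers cluttering the diagram.
model->save("model_fused", /*verbose=*/false, /*showProducers=*/false);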