Commit a3c90cca authored by Thibault Allenet

Merge branch 'tiling' into dataloader

parents 363c8929 2ed4315d
2 merge requests: !105 version 0.2.0, !4 Dataloader
......@@ -96,7 +96,7 @@ public:
* specified location.
* @param path Output file path; the ".mmd" extension is appended.
*/
-    void save(std::string path, bool verbose = false) const;
+    void save(std::string path, bool verbose = false, bool showProducers = true) const;
inline bool inView(NodePtr nodePtr) const {
return mNodes.find(nodePtr) != mNodes.end();
......
......@@ -51,12 +51,9 @@ public:
return std::make_shared<Erf_Op>(*this);
}
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
        mImpl = Registrar<Erf_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
    }
static const std::vector<std::string> getInputsName(){
......
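Every setBackend change in this commit follows the same shape: the signature gains a DeviceIdx_t parameter defaulting to 0, and only the output tensors are re-targeted, dropping the old "FIXME: temporary workaround" that also re-targeted the inputs. A minimal sketch of the pattern, assuming a hypothetical MyOp_Op registered for the requested backend:

    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
        // Instantiate the backend implementation for this operator.
        mImpl = Registrar<MyOp_Op>::create(name)(*this);
        // Re-target only the outputs; inputs now inherit their backend
        // from whatever node produces them.
        mOutputs[0]->setBackend(name, device);
    }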
......@@ -40,7 +40,7 @@ public:
Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, int>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(int axis)
......@@ -70,13 +70,9 @@ public:
void computeOutputDims() override final;
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
        mImpl = Registrar<Gather_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
-        getInput(1)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
    }
static const std::vector<std::string> getInputsName(){
......
......@@ -46,11 +46,11 @@ public:
return std::make_shared<MetaOperator_Op>(*this);
}
-    const std::shared_ptr<GraphView>& getMicroGraph() const {
+    inline const std::shared_ptr<GraphView>& getMicroGraph() const noexcept {
        return mGraph;
    }
-    const std::shared_ptr<SequentialScheduler>& getMicroGraphScheduler() const {
+    inline const std::shared_ptr<SequentialScheduler>& getMicroGraphScheduler() const noexcept {
        return mScheduler;
    }
......
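These accessors are what the fuseBatchNorm recipe further down relies on to peek inside a fused operator such as PaddedConv. A hedged usage sketch, assuming node holds a non-atomic (meta) operator:

    // Inspect the internal graph of a meta-operator.
    auto metaOp = std::static_pointer_cast<MetaOperator_Op>(node->getOperator());
    const std::shared_ptr<GraphView>& microGraph = metaOp->getMicroGraph();
    // The ordered outputs identify which internal node produces each result.
    const auto outputs = microGraph->getOrderedOutputs();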
......@@ -73,12 +73,18 @@ public:
public:
virtual std::shared_ptr<Operator> clone() const = 0;
/**
* @brief Set the specified input with a shallow copy.
* @param inputIdx Index of the input to set.
* @param data Data to copy.
*/
virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
/**
* @brief Set the specified input by performing a deep copy of the given data.
* The pointer itself is not changed, thus keeping the current connections.
* @param inputIdx Index of the input to set.
* @param data Data to copy.
*/
virtual void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
virtual void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) = 0;
......
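A short sketch contrasting the two calls documented above; op and the default-constructed Tensor are hypothetical:

    std::shared_ptr<Data> inputTensor = std::make_shared<Tensor>();
    // Shallow: input 0 now points at inputTensor itself.
    op->associateInput(0, inputTensor);
    // Deep: copies inputTensor's content into the existing input tensor,
    // leaving the pointer (and thus the graph connections) unchanged.
    op->setInput(0, inputTensor);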
......@@ -89,7 +89,7 @@ class ReduceMean_Op : public OperatorTensor,
}
else
outDims.push_back(getInput(0)->dims()[d]);
}
}
if(outDims.size()>0)
mOutputs[0]->resize(outDims);
else
......@@ -97,12 +97,9 @@ class ReduceMean_Op : public OperatorTensor,
}
}
-    void setBackend(const std::string &name) override {
+    void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
        mImpl = Registrar<ReduceMean_Op<DIM>>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
    }
static const std::vector<std::string> getInputsName(){
......
......@@ -66,12 +66,9 @@ public:
void computeOutputDims() override final;
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
        mImpl = Registrar<Reshape_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
    }
static const std::vector<std::string> getInputsName(){
......
......@@ -79,12 +79,9 @@ class Transpose_Op : public OperatorTensor,
}
}
-    void setBackend(const std::string &name) override {
+    void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
        mImpl = Registrar<Transpose_Op<DIM>>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
    }
static const std::vector<std::string> getInputsName(){
......
......@@ -23,7 +23,7 @@ namespace Aidge {
void init_GraphView(py::module& m) {
py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView")
.def(py::init<>())
.def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false,
.def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false, py::arg("show_producers") = true,
R"mydelimiter(
Save the GraphView as a Mermaid graph in a .mmd file at the specified location.
......@@ -97,7 +97,7 @@ void init_GraphView(py::module& m) {
.def("get_nodes", &GraphView::getNodes)
.def("get_node", &GraphView::getNode, py::arg("node_name"))
.def("forward_dims", &GraphView::forwardDims)
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"))
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0)
.def("__call__", &GraphView::operator(), py::arg("connectors"))
.def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
......
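The new device argument mirrors the C++ side; a hedged sketch of the equivalent C++ calls, with graphView assumed to be an existing GraphView:

    // compile(backend, datatype, device): device now defaults to 0,
    // matching the binding above.
    graphView->compile("cpu", DataType::Float32, 0);
    graphView->setBackend("cpu", 0);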
......@@ -55,7 +55,7 @@ std::string Aidge::GraphView::name() const { return mName; }
void Aidge::GraphView::setName(const std::string &name) { mName = name; }
-void Aidge::GraphView::save(std::string path, bool verbose) const {
+void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) const {
FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
std::fprintf(fp,
"%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
......@@ -68,7 +68,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
const std::string currentType = node_ptr->type();
if (typeCounter.find(currentType) == typeCounter.end())
typeCounter[currentType] = 0;
++typeCounter[currentType];
std::string givenName =
......@@ -83,13 +83,18 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
givenName.c_str());
}
else {
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
if ((currentType != "Producer") || showProducers) {
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
}
}
// Write every link
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
+        if ((node_ptr->type() == "Producer") && !showProducers) {
+            continue;
+        }
IOIndex_t outputIdx = 0;
for (auto childs : node_ptr->getOrderedChildren()) {
for (auto child : childs) {
......
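A hedged usage sketch of the new showProducers flag; file names are hypothetical, and save() appends ".mmd" as shown above:

    // Full graph, Producer nodes (weights, biases) included.
    graphView->save("model_full", false, true);
    // Same graph with Producer nodes hidden, for a lighter Mermaid diagram.
    graphView->save("model_clean", false, false);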
#include "aidge/graphRegex/GraphRegex.hpp"
using namespace Aidge;
void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){
......@@ -27,7 +27,7 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){
// void GraphRegex::addQuery(const std::string query){
//     // TODO: one query only, but the same string is the same query;
//     // two different strings may still be the same query, so we need to check the AST.
//     mQueryRecipe[query] = nullptr;
// }
......@@ -52,7 +52,7 @@ void GraphRegex::_generateCombinationsStart(const std::set<NodePtr>& elements, s
}
}
// factorial(n)-sized search tree, optimized with a stopping condition
void GraphRegex::_findLargestCompatibleSet(
const std::vector<std::shared_ptr<MatchSolution>>& solutions,
std::set<std::shared_ptr<MatchSolution>>& currentSet,
......@@ -75,6 +75,10 @@ void GraphRegex::_findLargestCompatibleSet(
currentSet.insert(solutions[i]);
_findLargestCompatibleSet(solutions, currentSet, largestSet, i + 1);
currentSet.erase(solutions[i]);
+            // Prune: even taking every remaining solution could not beat the largest set found so far.
+            if ((currentSet.size() + solutions.size() - currentIndex) <= largestSet.size()) {
+                return;
+            }
}
}
}
......@@ -101,14 +105,14 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest);
std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();
// generate all the possible start combinations
std::size_t nb_startSt = fsm->getNbStart();
std::set<std::vector<NodePtr>> combinations;
std::vector<NodePtr> current;
_generateCombinationsStart(ref->getNodes(), nb_startSt, 0, current, combinations);
// test every start combination
for (const auto& combination : combinations) {
std::vector<std::shared_ptr<MatchSolution>> solution = fsm->test(combination);
solutions.insert(solutions.end(), solution.begin(), solution.end());
......@@ -133,7 +137,7 @@ void GraphRegex::setNodeKey(const std::string key, const std::string conditional
void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f){
// we could apply it to every key, but that would be inefficient
if(mAllLambda.find(key) != mAllLambda.end()){
throw std::runtime_error(key + " is already defined");
}
......@@ -142,7 +146,7 @@ void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f
}
void GraphRegex::_majConditionalInterpreterLambda(){
for (const auto& test : mAllTest) {
for (const auto& pair : mAllLambda) {
const std::string& key = pair.first;
......@@ -151,7 +155,7 @@ void GraphRegex::_majConditionalInterpreterLambda(){
if(!test->isLambdaRegister(key)){
test->insertLambda(key,lambda);
}
}
}
}
......
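Putting the GraphRegex pieces above together, a minimal matching sketch; graphView and the "Named" key are hypothetical:

    auto regex = std::make_shared<GraphRegex>();
    // Lambda key, registered through setNodeKey(key, f) as defined above.
    regex->setNodeKey("Named", [](NodePtr node) { return !node->name().empty(); });
    regex->addQuery("Named -> Named");
    for (const auto& solution : regex->match(graphView)) {
        // Each MatchSolution maps a query key to its matched nodes.
        const auto& matched = solution->at("Named");
    }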
......@@ -8,30 +8,66 @@
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
-#include <set>
+#include <cassert>
+#include <memory>
+#include <set>
+#include <string>
-#include "aidge/operator/FC.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/BatchNorm.hpp"
+#include "aidge/operator/Conv.hpp"
+#include "aidge/operator/ConvDepthWise.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/operator/MetaOperator.hpp"
#include "aidge/recipies/Recipies.hpp"
-#include "aidge/graph/GraphView.hpp"
-#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
-#include "aidge/operator/GenericOperator.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Types.h"
-//Graph Regex
+// Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
-void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr<Aidge::Node> batchnormNode) {
+void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
+                          std::shared_ptr<Aidge::Node> batchnormNode) {
+    // Case: convNode is a MetaOperator ending with a Convolution,
+    // e.g. PaddedConv
+    if (!(convNode->getOperator()->isAtomic())) {
+        const std::shared_ptr<MetaOperator_Op> metaNode = std::static_pointer_cast<MetaOperator_Op>(convNode->getOperator());
+        const std::shared_ptr<GraphView> metanodeGraph = metaNode->getMicroGraph();
+        const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputNodes = metanodeGraph->getOrderedOutputs();
+        if (outputNodes.size() != 1) {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Bad MetaOperator argument for fuseBatchNorm recipe.");
+        }
+        convNode = outputNodes[0].first;
+    }
+    AIDGE_ASSERT(((convNode->type() == Conv_Op<2>::Type) || (convNode->type() == ConvDepthWise_Op<2>::Type)), "Wrong type for the convolution node.");
+    AIDGE_ASSERT(batchnormNode->type() == BatchNorm_Op<2>::Type, "Wrong type for the batchnorm node.");
// TODO: Find a way to remove the template
// A feature map with 2 dimensions is assumed
-    const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
-    const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
+    const std::shared_ptr<BatchNorm_Op<2>> batchOp =
+        std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
+    DimSize_t convNbOutChannels;
+    DimSize_t channelsSize;
+    std::array<DimSize_t, 2> kernelDims;
+    std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
+    if (convNode->type() == Conv_Op<2>::Type) {
+        const std::shared_ptr<Conv_Op<2>> convOpPtr =
+            std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
+        convNbOutChannels = convOpPtr->getAttr<DimSize_t>("OutChannels");
+        channelsSize = convOpPtr->getAttr<DimSize_t>("InChannels");
+        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
+    }
+    else if (convNode->type() == ConvDepthWise_Op<2>::Type) {
+        const std::shared_ptr<ConvDepthWise_Op<2>> convOpPtr =
+            std::static_pointer_cast<ConvDepthWise_Op<2>>(convNode->getOperator());
+        convNbOutChannels = convOpPtr->getAttr<DimSize_t>("Channels");
+        channelsSize = 1;
+        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
+    }
std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
const Tensor& scale = batchOp->getInput(1)->refCastFrom(scaleBuf, DataType::Float32, "cpu");
......@@ -39,20 +75,12 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
const Tensor& b_mean = batchOp->getInput(3)->refCastFrom(b_meanBuf, DataType::Float32, "cpu");
const Tensor& b_var = batchOp->getInput(4)->refCastFrom(b_varBuf, DataType::Float32, "cpu");
-    const float epsilon = batchOp -> getAttr<float>("Epsilon");
-    const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels");
-    const DimSize_t channelsSize = convOp -> getAttr<DimSize_t>("InChannels");
-    const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims");
+    const float epsilon = batchOp->getAttr<float>("Epsilon");
assert(scale.size() == convNbOutChannels);
assert(shift.size() == convNbOutChannels);
assert(b_mean.size() == convNbOutChannels);
assert(b_var.size() == convNbOutChannels);
assert(epsilon > 0.0);
// TODO : no no_bias attribute ?
float meanVariance = 0.0;
unsigned int count = 0;
......@@ -60,8 +88,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
if (b_var.get<float>(outChId) > 1.0e-12) {
meanVariance += b_var.get<float>(outChId);
++count;
-        }
-        else {
+        } else {
printf("Zero-variance: %s [%lu]\n", convNode->name().c_str(), outChId);
}
}
......@@ -86,8 +113,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
// Weights adjustments
for (std::size_t channel = 0; channel < channelsSize; ++channel) {
// TODO : Suppose kerneldims = 2
-        for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){
-            for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){
+        for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
+            for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1};
float weightValue = weight.get<float>(currentIdx);
                weight.set<float>(currentIdx, weightValue * factor); // updates the Conv weights in place
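For reference, the folding applied here is the standard batch-norm fusion; per output channel c, with scale gamma, shift beta, running mean mu, variance sigma^2 and the Epsilon attribute read above (the bias update sits in the truncated part of this hunk):

    \mathrm{factor}_c = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}, \qquad
    W'_c = W_c \cdot \mathrm{factor}_c, \qquad
    b'_c = \beta_c + (b_c - \mu_c)\,\mathrm{factor}_c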
......@@ -119,24 +146,37 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
for (const auto& op : solution->at("OP")) {
-        for (const auto& batchNorm : solution->at("BatchNorm")) {
-            fuseBatchNorm(op,batchNorm);
+        if (op->getOperator()->isAtomic()) {
+            for (const auto& batchNorm : solution->at("BatchNorm")) {
+                fuseBatchNorm(op, batchNorm);
+            }
+        } else { // op is a MetaOperator
+            auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
+            if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) &&
+                ((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
+                  Conv_Op<2>::Type) ||
+                 (metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
+                  ConvDepthWise_Op<2>::Type))) {
+                for (const auto& batchNorm : solution->at("BatchNorm")) {
+                    fuseBatchNorm(op, batchNorm);
+                }
+            }
+        }
}
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'");
regex->setNodeKey("OP","getType($) =='Conv'");// || getType($) =='FC' ");
regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
printf("\n============================\nSearching for solutions\n==============================\n");
regex->setNodeKey(
"OP",
"getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
// || getType($) =='FC' ");
regex->addQuery("OP -> BatchNorm");
......
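End to end, the recipe is applied in a single call; a hedged sketch, with model assumed to be a GraphView containing Conv/PaddedConv + BatchNorm sequences:

    #include "aidge/recipies/Recipies.hpp"

    void optimize(std::shared_ptr<Aidge::GraphView> model) {
        // Matches every "OP -> BatchNorm" pair via the GraphRegex above and
        // folds each BatchNorm into the preceding convolution.
        Aidge::fuseBatchNorm(model);
    }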