Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
StmFactory.cpp 5.13 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/graphmatching/StmFactory.hpp"

using namespace Aidge;

StmFactory::StmFactory(const std::map<std::string, NodeRegex *> &nodesRegex)
    : mNodesRegex(nodesRegex) {}

/**
 * @brief Duplicate an existing state machine.
 *
 * Delegates to SeqStm::duplicateStm() and returns its result unchanged.
 *
 * @param stm the state machine to duplicate; it is dereferenced, so it must
 *        not be null.
 * @return the pointer returned by SeqStm::duplicateStm().
 *         NOTE(review): presumably a new heap object owned by the caller —
 *         confirm against SeqStm.
 */
SeqStm *StmFactory::duplicateStm(SeqStm *stm) {
  SeqStm *const duplicate = stm->duplicateStm();
  return duplicate;
}

/**
 * @brief Build a new SeqStm for the sequential regex @p sequRegex.
 *
 * Steps:
 * 1. Parse the regex with initParsingSequRegex().
 * 2. Derive its transition table with initTransitionMatrix().
 * 3. Allocate a SeqStm tagged with the current value of the factory counter
 *    (mCmptStm). It receives three empty node-tracking sets, the literal
 *    arguments 0 and false, and the factory's node regexes.
 *
 * The counter is advanced only after the SeqStm has been constructed.
 * NOTE(review): the meaning of the `0` / `false` arguments is defined by
 * SeqStm (presumably initial state index and a flag) — confirm there.
 *
 * @param sequRegex the sequential regex, e.g. "A->B+;".
 * @return a SeqStm allocated with `new`.
 *         NOTE(review): caller presumably owns it — confirm who deletes it.
 * @throws std::runtime_error propagated from parsing or matrix construction
 *         when @p sequRegex is malformed.
 */
SeqStm *StmFactory::makeNewStm(const std::string &sequRegex) {
  ParsingReturn parsed = initParsingSequRegex(sequRegex);
  std::vector<std::vector<int>> matrix = initTransitionMatrix(parsed);

  // Per-machine bookkeeping, all initially empty.
  std::set<NodeTmp> validated;
  std::set<NodeTmp> tested;
  std::set<std::pair<NodeTmp, std::string>> commonNodes;

  SeqStm *created =
      new SeqStm(static_cast<int>(mCmptStm), matrix, mNodesRegex,
                 parsed.typeToIdxTransition, 0, validated, tested,
                 commonNodes, false);
  ++mCmptStm;
  return created;
}

/**
 * @brief Parse a sequential regex into its ordered list of transitions.
 *
 * Expected syntax:
 * - The regex is a chain of elements of the form `Type[#digits][+|*]`.
 * - Each element is followed by `->` (more elements follow) or `;`.
 * - Whitespace is optional around each element.
 * - Example: "Conv#1->BN->ReLU+;".
 *
 * The input is read character by character. Whenever the text accumulated
 * since the previous element forms exactly one complete element, that
 * element is recorded and accumulation restarts.
 *
 * @param sequRegex the sequential regex to parse.
 * @return a ParsingReturn with two fields:
 *         - `transition`: every element in input order, as
 *           ((type, commonTag), quantifier).
 *         - `typeToIdxTransition`: each distinct (type, commonTag) key
 *           numbered 0, 1, 2, ... in order of first appearance.
 * @throws std::runtime_error in any of these cases:
 *         - an element has both a common tag and a quantifier;
 *         - the same (type, common tag) pair appears more than once;
 *         - no element could be parsed;
 *         - non-whitespace text is left over after the last complete
 *           element (e.g. a missing final `;` or a malformed element).
 */
ParsingReturn StmFactory::initParsingSequRegex(const std::string &sequRegex) {
  // Compiled once: building a std::regex is expensive, and matching against
  // a const regex does not modify it.
  static const std::regex re("\\s*([A-Za-z]+)(#\\d*)?([+*])?\\s*(->|;)");

  ParsingReturn parsing;
  std::string toMatch;
  std::smatch matches;
  int idxType = 0;
  // (type, commonTag) pairs already used with a non-empty common tag: a
  // given common node may appear only once per sequence. Tracking full pairs
  // (rather than one tag per type) catches every repetition, including
  // e.g. "A#1->A#2->A#2;".
  std::set<NodeTypeKey> seenCommonNodes;

  for (const char c : sequRegex) {
    toMatch += c;
    if (!std::regex_match(toMatch, matches, re)) {
      continue; // current element is not complete yet
    }

    const std::string type = matches.str(1);
    const std::string commonTag = matches.str(2);
    const std::string quantification = matches.str(3);

    if (!commonTag.empty() && !quantification.empty()) {
      throw std::runtime_error("bad commonTag and quantification");
    }

    const NodeTypeKey typeTag = std::make_pair(type, commonTag);

    // Give the next index only to keys that have not been seen before.
    if (parsing.typeToIdxTransition.emplace(typeTag, idxType).second) {
      ++idxType;
    }

    if (!commonTag.empty() && !seenCommonNodes.insert(typeTag).second) {
      throw std::runtime_error("same common node in the sequ regex");
    }

    parsing.transition.push_back(std::make_pair(typeTag, quantification));
    toMatch.clear();
  }

  if (parsing.transition.empty()) {
    throw std::runtime_error("Bad Parsing SequRegex ");
  }
  // Anything other than whitespace after the last complete element was
  // never parsed. Dropping it silently would build a state machine that
  // lacks those nodes.
  if (toMatch.find_first_not_of(" \t\n\v\f\r") != std::string::npos) {
    throw std::runtime_error(
        "Bad Parsing SequRegex: unparsed trailing text \"" + toMatch + "\"");
  }
  return parsing;
}

/**
 * @brief Build the state-transition table for the transitions in @p parsing.
 *
 * Layout of the table:
 * - Row i is state i; state 0 is the start state.
 * - Column j corresponds to the (type, commonTag) key that
 *   parsing.typeToIdxTransition maps to j.
 * - A cell holds the destination state number; -1 marks a cell with no
 *   transition set.
 *
 * Each transition extends the most recently created state as follows:
 * - "*" : adds a self-loop on the current state; no new state is created.
 * - "+" : moves to a new state, which also loops on itself.
 * - otherwise : moves to a new state.
 *
 * NOTE(review): a later transition on the same column of the same state
 * overwrites an earlier one (e.g. "A+->A;"); confirm this is intended.
 *
 * @param parsing output of initParsingSequRegex(). It is taken by non-const
 *        reference because map operator[] is used for lookups; a key missing
 *        from typeToIdxTransition would be inserted with index 0.
 * @return the transition table, one row per state.
 * @throws std::runtime_error if parsing.typeToIdxTransition is empty.
 */
std::vector<std::vector<int>>
StmFactory::initTransitionMatrix(ParsingReturn &parsing) {
  const std::size_t columnCount = parsing.typeToIdxTransition.size();
  if (columnCount == 0) {
    throw std::runtime_error("Bad number Of Type ");
  }

  const std::vector<int> blankRow(columnCount, -1);
  // Start with only the start state (row 0).
  std::vector<std::vector<int>> matrix(1, blankRow);

  // Row of the state currently being extended; it is also that state's number.
  std::size_t current = 0;
  for (const auto &entry : parsing.transition) {
    const std::size_t column = parsing.typeToIdxTransition[entry.first];
    const std::string &quant = entry.second;

    if (quant == "*") {
      // Zero or more: loop on the current state.
      matrix[current][column] = static_cast<int>(current);
      continue;
    }

    // "+" or no quantifier: step into a freshly created state.
    const int next = static_cast<int>(current) + 1;
    matrix[current][column] = next;
    matrix.push_back(blankRow);
    ++current;
    if (quant == "+") {
      // One or more: the new state also loops on the same key.
      matrix[current][column] = next;
    }
  }
  return matrix;
}