Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
ShapeFolding.cpp 1.28 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/
// #include <cassert>
#include <memory>
#include <set>
#include <string>

#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Shape.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/Log.hpp"
// #include "aidge/utils/Types.h"

/**
 * @brief Alternate dimension forwarding and constant folding on @p graph until
 *        a folding pass no longer changes it.
 *
 * The recipe does nothing (and returns false) when @p graph contains no node
 * whose type is Shape_Op::Type. Otherwise it repeatedly calls
 * `graph->forwardDims(dims, true)` followed by `constantFolding(graph, true)`,
 * stopping as soon as constantFolding reports no modification. A warning is
 * logged if the last forwardDims call reported failure.
 *
 * @param graph GraphView to transform in place.
 * @param dims  Input dimensions passed to GraphView::forwardDims.
 * @return true if at least one constantFolding pass modified the graph,
 *         false otherwise (including when no Shape node is present).
 */
bool Aidge::constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims) {
    // Only graphs containing a Shape operator are handled by this recipe.
    // Bind by reference to avoid an atomic refcount bump per node, and stop
    // scanning at the first match.
    bool shapePresent = false;
    for (const auto& nodePtr : graph->getNodes()) {
        if (nodePtr->type() == Shape_Op::Type) {
            shapePresent = true;
            break;
        }
    }
    if (!shapePresent) {
        return false;
    }

    bool forwarded   = false;
    bool modified    = false;
    bool anyModified = false;
    do {
        // Dims are re-forwarded on every iteration so each folding pass works
        // on the graph as left by the previous pass.
        // NOTE(review): the `true` flags are kept as-is; presumably they enable
        // data-dependent dims / constant-shape folding respectively — confirm
        // against the forwardDims and constantFolding declarations.
        forwarded = graph->forwardDims(dims, true);
        modified  = constantFolding(graph, true);
        // The loop only exits once `modified` is false, so returning it
        // directly would always yield false. Remember whether any pass
        // changed the graph instead.
        anyModified = anyModified || modified;
    } while (modified);

    if (!forwarded) {
        Log::warn("Failed to forward GraphView.");
    }

    return anyModified;
}