/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <catch2/catch_test_macros.hpp>

#include <cstddef>
#include <memory>
#include <set>

#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipies/Recipies.hpp"

namespace Aidge {

/**
 * @brief Checks that fuseMulAdd() rewrites every MatMul -> Add pair of a
 *        GraphView into a single FC node, in place.
 *
 * Builds the chain input -> matmul0 -> add0 -> matmul1 -> add1, where each
 * MatMul takes a weight Producer on input #1 and each Add takes a bias
 * Producer on input #1, runs the recipe, then checks that only Producer and
 * FC nodes remain.
 */
TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
    // generate the original GraphView
    auto matmul0 = MatMul("matmul0");
    // NOTE(review): Add factory assumed to take (nbInputs, name); some Aidge
    // releases expose the templated form Add<2>(name) instead — confirm
    // against aidge/operator/Add.hpp.
    auto add0 = Add(2, "add0");
    auto matmul1 = MatMul("matmul1");
    auto add1 = Add(2, "add1");

    auto b0 = Producer({5}, "B0");
    auto w0 = Producer({5, 5}, "W0");
    auto b1 = Producer({5}, "B1");
    auto w1 = Producer({5, 5}, "W1");
    auto input = Producer({2, 5}, "input");

    input->addChild(matmul0, 0, 0);
    w0->addChild(matmul0, 0, 1);

    matmul0->addChild(add0, 0, 0);
    b0->addChild(add0, 0, 1);

    add0->addChild(matmul1, 0, 0);
    w1->addChild(matmul1, 0, 1);

    matmul1->addChild(add1, 0, 0);
    b1->addChild(add1, 0, 1);

    // `input` is wired to matmul0 but not added to `g`, so it is not part of
    // the node sets compared below.
    const std::set<std::shared_ptr<Node>> originalNodes{w0, matmul0, b0, add0, w1, matmul1, b1, add1};

    auto g = std::make_shared<GraphView>();
    g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1});

    // Check original graph
    REQUIRE(g->getNodes() == originalNodes);
    REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0)));
    REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0)));
    REQUIRE(((matmul1->getParent(0) == add0) && (matmul1->getParent(1) == w1)));
    REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1)));

    // Transform GraphView inplace
    fuseMulAdd(g);

    // Check new GraphView: the test expects both MatMul/Add pairs to have
    // collapsed into FC nodes, leaving 6 nodes of which exactly 2 are FC.
    const std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
    REQUIRE(newNodes != originalNodes);
    REQUIRE(newNodes.size() == 6);

    std::size_t nbFC = 0;
    for (const auto& node : newNodes) {
        REQUIRE(((node->type() == "Producer") || (node->type() == "FC")));
        if (node->type() == "FC") {
            ++nbFC;
        }
    }
    // Guards against only one of the two pairs being fused.
    REQUIRE(nbFC == 2);
}

}  // namespace Aidge