Skip to content
Snippets Groups Projects
Commit 9c5c8aae authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed convToMatMul + unit test

parent 0adac8c0
No related branches found
No related tags found
2 merge requests!93Release v0.3.0,!72Im2col
Pipeline #49974 canceled
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/recipes/Recipes.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/graph/OpArgs.hpp"
#include <cstddef>
using namespace Aidge;
TEST_CASE("[ConvToMatMul] conv") {
auto conv1 = Conv(3, 4, {3, 3}, "conv1");
auto conv2 = Conv(4, 7, {3, 3}, "conv2", {1, 1}, {1, 1}, true);
auto conv3 = Conv(7, 10, {1, 1}, "conv3", {2, 2});
auto g1 = Sequential({
Producer({2, 3, 13, 24}, "dataProvider"),
conv1,
conv2,
conv3
});
g1->setBackend("cpu");
g1->forwardDims();
// Random initialization of input and weights
uniformFiller<float>(std::static_pointer_cast<OperatorTensor>(conv1->getOperator())->getInput(0), -10.0, 10.0);
uniformFiller<float>(std::static_pointer_cast<OperatorTensor>(conv1->getOperator())->getInput(1), -10.0, 10.0);
uniformFiller<float>(std::static_pointer_cast<OperatorTensor>(conv1->getOperator())->getInput(2), -10.0, 10.0);
uniformFiller<float>(std::static_pointer_cast<OperatorTensor>(conv2->getOperator())->getInput(1), -10.0, 10.0);
uniformFiller<float>(std::static_pointer_cast<OperatorTensor>(conv3->getOperator())->getInput(1), -10.0, 10.0);
uniformFiller<float>(std::static_pointer_cast<OperatorTensor>(conv3->getOperator())->getInput(2), -10.0, 10.0);
auto s1 = SequentialScheduler(g1);
s1.forward();
g1->save("convToMatMul_before");
auto g2 = g1->clone();
g2->forwardDims();
REQUIRE(convToMatMul(g2) == 3);
g2->setBackend("cpu");
auto s2 = SequentialScheduler(g2);
s2.forward();
g2->save("convToMatMul_after");
auto g1OutOp = std::static_pointer_cast<OperatorTensor>((*g1->outputNodes().cbegin())->getOperator());
auto g2OutOp = std::static_pointer_cast<OperatorTensor>((*g1->outputNodes().cbegin())->getOperator());
REQUIRE(*(g1OutOp->getOutput(0)) == *(g2OutOp->getOutput(0)));
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment