Skip to content
Snippets Groups Projects
Commit 70860ee1 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added unit test for adaptFCParamsFormat() recipe

parent 8dab18c1
No related branches found
No related tags found
1 merge request!177Improve export
Pipeline #75751 failed
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/recipes/Recipes.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/utils/TensorUtils.hpp"
#include <cstddef>
using namespace Aidge;
TEST_CASE("[AdaptFCParamsFormat] forward", "[AdaptFCParamsFormat][forward][CPU]") {
// generate the original GraphView
auto w = std::make_shared<Tensor>(Array2D<float,2,2*3*3>{{
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18},
{19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}}});
auto fc = FC(2*3*3, 2, true, "fc");
fc->getOperator()->setInput(1, w);
auto input = Producer(std::make_shared<Tensor>(Array4D<float,1,3,3,2>{{{{{1, 11}, {2, 12}, {3, 13}}, {{4, 14}, {5, 15}, {6, 16}}, {{7, 17}, {8, 18}, {9, 19}}}}}), "input", true);
input->getOperator()->setDataFormat(DataFormat::NHWC);
input->addChild(fc, 0, 0);
auto g = std::make_shared<GraphView>();
g->add({input, fc});
g->setBackend("cpu");
g->save("adaptFCParamsFormat_before");
auto fc2 = FC(2*3*3, 2, true, "fc");
fc2->getOperator()->setInput(1, w);
auto input2 = Producer(std::make_shared<Tensor>(Array4D<float,1,2,3,3>{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{11, 12, 13}, {14, 15, 16}, {17, 18, 19}}}}}), "input", true);
input2->getOperator()->setDataFormat(DataFormat::NCHW);
input2->addChild(fc2, 0, 0);
auto g2 = std::make_shared<GraphView>();
g2->add({input2, fc2});
g2->setBackend("cpu");
SECTION("unfolded") {
adaptFCParamsFormat(g, false);
g->forwardDims({}, true);
g->save("adaptFCParamsFormat_after");
REQUIRE(g->getNodes().size() == 8);
auto scheduler = SequentialScheduler(g);
scheduler.forward();
auto scheduler2 = SequentialScheduler(g2);
scheduler2.forward();
auto fcOp = std::static_pointer_cast<FC_Op>(fc->getOperator());
auto fc2Op = std::static_pointer_cast<FC_Op>(fc2->getOperator());
REQUIRE(approxEq<float>(*(fcOp->getOutput(0)), *(fc2Op->getOutput(0))));
}
SECTION("folded") {
adaptFCParamsFormat(g);
g->forwardDims({}, true);
REQUIRE(g->getNodes().size() == 3);
auto scheduler = SequentialScheduler(g);
scheduler.forward();
auto scheduler2 = SequentialScheduler(g2);
scheduler2.forward();
auto fcOp = std::static_pointer_cast<FC_Op>(fc->getOperator());
auto fc2Op = std::static_pointer_cast<FC_Op>(fc2->getOperator());
REQUIRE(approxEq<float>(*(fcOp->getOutput(0)), *(fc2Op->getOutput(0))));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment