Skip to content
Snippets Groups Projects
Commit b31c6d04 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] horizontalTiling test

parent 7c94c8b0
No related branches found
No related tags found
1 merge request!23Tiling
Pipeline #35106 canceled
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <set>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/operator/Concat.hpp"
namespace Aidge {
TEST_CASE("[core/recipies] Tiling(transformation)", "[Tiling][Recipies]") {
SECTION("Transform a pre-generated GraphView") {
SECTION("Simple Node: Conv") {
std::shared_ptr<Node> myReLU = ReLU("myReLU");
std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv");
std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<int,4,3,3,3> {
{
{
{{ 0, 1, 2},
{ 3, 4, 5},
{ 6, 7, 8}},
{{ 9, 10, 11},
{ 12, 13, 14},
{ 15, 16, 17}},
{{ 18, 19, 20},
{ 21, 22, 23},
{ 24, 25, 26}}
},
{
{{ 27, 28, 29},
{ 30, 31, 32},
{ 33, 34, 35}},
{{ 36, 37, 38},
{ 39, 40, 41},
{ 42, 43, 44}},
{{ 45, 46, 47},
{ 48, 49, 50},
{ 51, 52, 53}}
},
{
{{ 54, 55, 56},
{ 57, 58, 59},
{ 60, 61, 62}},
{{ 63, 64, 65},
{ 66, 67, 68},
{ 69, 70, 71}},
{{ 72, 73, 74},
{ 75, 76, 77},
{ 78, 79, 80}}
},
{
{{ 81, 82, 83},
{ 84, 85, 86},
{ 87, 88, 89}},
{{ 90, 91, 92},
{ 93, 94, 95},
{ 96, 97, 98}},
{{ 99, 100, 101},
{102, 103, 104},
{105, 106, 107}}
}
}
});
std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<int,4> {{7,0,9,0}});
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<int,2,3,5,5> { //NCHW
{
{
{{ 0, 1, 2, 3, 4},
{ 5, 6, 7, 8, 9},
{ 10, 11, 12, 13, 14},
{ 15, 16, 17, 18, 19},
{ 20, 21, 22, 23, 24}},
{{ 25, 26, 27, 28, 29},
{ 30, 31, 32, 33, 34},
{ 35, 36, 37, 38, 39},
{ 40, 41, 42, 43, 44},
{ 45, 46, 47, 48, 49}},
{{ 50, 51, 52, 53, 54},
{ 55, 56, 57, 58, 59},
{ 60, 61, 62, 63, 64},
{ 65, 66, 67, 68, 69},
{ 70, 71, 72, 73, 74}}
},
{
{{ 75, 76, 77, 78, 79},
{ 80, 81, 82, 83, 84},
{ 85, 86, 87, 88, 89},
{ 90, 91, 92, 93, 94},
{ 95, 96, 97, 98, 99}},
{{100, 101, 102, 103, 104},
{105, 106, 107, 108, 109},
{110, 111, 112, 113, 114},
{115, 116, 117, 118, 119},
{120, 121, 122, 123, 124}},
{{125, 126, 127, 128, 129},
{130, 131, 132, 133, 134},
{135, 136, 137, 138, 139},
{140, 141, 142, 143, 144},
{145, 146, 147, 148, 149}}
}
}
});
std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<int,2,4,3,3> {
{
{
{{ 15226, 15577, 15928},
{ 16981, 17332, 17683},
{ 18736, 19087, 19438}},
{{ 37818, 38898, 39978},
{ 43218, 44298, 45378},
{ 48618, 49698, 50778}},
{{ 60426, 62235, 64044},
{ 69471, 71280, 73089},
{ 78516, 80325, 82134}},
{{ 83016, 85554, 88092},
{ 95706, 98244, 100782},
{108396, 110934, 113472}}
},
{
{{ 41551, 41902, 42253},
{ 43306, 43657, 44008},
{ 45061, 45412, 45763}},
{{118818, 119898, 120978},
{124218, 125298, 126378},
{129618, 130698, 131778}},
{{196101, 197910, 199719},
{205146, 206955, 208764},
{214191, 216000, 217809}},
{{273366, 275904, 278442},
{286056, 288594, 291132},
{298746, 301284, 303822}}
}
}
});
myReLU->getOperator()->associateInput(0, myInput);
myReLU->addChild(myConv, 0, 0);
myConv->getOperator()->setInput(1, myWeights);
myConv->getOperator()->setInput(2, myBias);
std::dynamic_pointer_cast<Conv_Op<2>>(myConv->getOperator())->computeOutputDims();
std::shared_ptr<GraphView> g = std::make_shared<GraphView>();
g->add({myReLU, myConv});
g->compile("cpu", DataType::Int32);
std::set<std::shared_ptr<Node>> tiledConv = getConvHorizontalTiling(myConv, 2, 3);
SequentialScheduler s(g);
s.forward();
REQUIRE(*(std::dynamic_pointer_cast<Conv_Op<2>>(myConv->getOperator())->getOutput(0)) == *myOutput);
GraphView::replace({myConv, myConv->getParent(1), myConv->getParent(2)}, tiledConv);
g->compile("cpu", DataType::Int32);
s.resetScheduling();
s.forward();
REQUIRE(*(std::dynamic_pointer_cast<OperatorTensor>((*g->outputNodes().begin())->getOperator())->getOutput(0)) == *myOutput);
}
}
}
}
// std::shared_ptr<GraphView> g = Sequential({
// Conv(3, 16, {3,3}, "conv1"),
// ReLU("relu1"),
// Conv(16, 32, {1,1}, "conv2"),
// Conv(32, 16, {1,1}, "conv3"),
// Conv(16, 10, {3,3}, "conv4"),
// ReLU("relu2")
// });
// for (auto& individualConv : g->match("Conv")) {
// auto tiledConv = horizontalTiling(individualConv);
// g->replace(individualConv, tiledConv);
// }
// }
// SECTION("Create the GraphView with tiled layers") {
// std::shared_ptr<GraphView> g;
// g->addChild(horizontalTiling(Conv()))
// }
// }
// } // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment