Skip to content
Snippets Groups Projects
Commit e4db07cb authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd][WIP] Horizontal Tiling test

parent 78330990
No related branches found
No related tags found
No related merge requests found
...@@ -9,29 +9,174 @@ ...@@ -9,29 +9,174 @@
* *
********************************************************************************/ ********************************************************************************/
#include <catch2/catch_test_macros.hpp> // #include <catch2/catch_test_macros.hpp>
#include <set> // #include <set>
#include "aidge/graph/GraphView.hpp" // #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp" // #include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp" // #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp" // #include "aidge/operator/ReLU.hpp"
#include "aidge/recipies/Recipies.hpp" // #include "aidge/recipies/Recipies.hpp"
namespace Aidge { // namespace Aidge {
TEST_CASE("[core/recipies] Tiling(horizontal)", "[Tiling][HorizontalTiling][Recipies]") { // TEST_CASE("[core/recipies] Tiling(transformation)", "[Tiling][Recipies]") {
// SECTION("Transform a pre-generated GraphView") { // SECTION("Transform a pre-generated GraphView") {
// std::shared_ptr<GraphView> g = Sequential({
// Conv(3, 16, {3,3}, "conv1"), // SECTION("Simple Node: Conv") {
// ReLU("relu1"), // std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv");
// Conv(16, 32, {1,1}, "conv2"), // myConv->getOperator()->setDatatype(DataType::Int32);
// Conv(32, 16, {1,1}, "conv3"), // myConv->getOperator()->setBackend("cpu");
// Conv(16, 10, {3,3}, "conv4"), // std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<int,4,3,3,3> {
// ReLU("relu2") // {
// }); // {
// {{ 0, 1, 2},
// { 3, 4, 5},
// { 6, 7, 8}},
// {{ 9, 10, 11},
// { 12, 13, 14},
// { 15, 16, 17}},
// {{ 18, 19, 20},
// { 21, 22, 23},
// { 24, 25, 26}}
// },
// {
// {{ 27, 28, 29},
// { 30, 31, 32},
// { 33, 34, 35}},
// {{ 36, 37, 38},
// { 39, 40, 41},
// { 42, 43, 44}},
// {{ 45, 46, 47},
// { 48, 49, 50},
// { 51, 52, 53}}
// },
// {
// {{ 54, 55, 56},
// { 57, 58, 59},
// { 60, 61, 62}},
// {{ 63, 64, 65},
// { 66, 67, 68},
// { 69, 70, 71}},
// {{ 72, 73, 74},
// { 75, 76, 77},
// { 78, 79, 80}}
// },
// {
// {{ 81, 82, 83},
// { 84, 85, 86},
// { 87, 88, 89}},
// {{ 90, 91, 92},
// { 93, 94, 95},
// { 96, 97, 98}},
// {{ 99, 100, 101},
// {102, 103, 104},
// {105, 106, 107}}
// }
// }
// });
// std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<int,4> {{7,0,9,0}});
// std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<int,2,3,5,5> { //NCHW
// {
// {
// {{ 0, 1, 2, 3, 4},
// { 5, 6, 7, 8, 9},
// { 10, 11, 12, 13, 14},
// { 15, 16, 17, 18, 19},
// { 20, 21, 22, 23, 24}},
// {{ 25, 26, 27, 28, 29},
// { 30, 31, 32, 33, 34},
// { 35, 36, 37, 38, 39},
// { 40, 41, 42, 43, 44},
// { 45, 46, 47, 48, 49}},
// {{ 50, 51, 52, 53, 54},
// { 55, 56, 57, 58, 59},
// { 60, 61, 62, 63, 64},
// { 65, 66, 67, 68, 69},
// { 70, 71, 72, 73, 74}}
// },
// {
// {{ 75, 76, 77, 78, 79},
// { 80, 81, 82, 83, 84},
// { 85, 86, 87, 88, 89},
// { 90, 91, 92, 93, 94},
// { 95, 96, 97, 98, 99}},
// {{100, 101, 102, 103, 104},
// {105, 106, 107, 108, 109},
// {110, 111, 112, 113, 114},
// {115, 116, 117, 118, 119},
// {120, 121, 122, 123, 124}},
// {{125, 126, 127, 128, 129},
// {130, 131, 132, 133, 134},
// {135, 136, 137, 138, 139},
// {140, 141, 142, 143, 144},
// {145, 146, 147, 148, 149}}
// }
// }
// });
// std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<int,2,4,3,3> {
// {
// {
// {{ 15226, 15577, 15928},
// { 16981, 17332, 17683},
// { 18736, 19087, 19438}},
// {{ 37818, 38898, 39978},
// { 43218, 44298, 45378},
// { 48618, 49698, 50778}},
// {{ 60426, 62235, 64044},
// { 69471, 71280, 73089},
// { 78516, 80325, 82134}},
// {{ 83016, 85554, 88092},
// { 95706, 98244, 100782},
// {108396, 110934, 113472}}
// },
// {
// {{ 41551, 41902, 42253},
// { 43306, 43657, 44008},
// { 45061, 45412, 45763}},
// {{118818, 119898, 120978},
// {124218, 125298, 126378},
// {129618, 130698, 131778}},
// {{196101, 197910, 199719},
// {205146, 206955, 208764},
// {214191, 216000, 217809}},
// {{273366, 275904, 278442},
// {286056, 288594, 291132},
// {298746, 301284, 303822}}
// }
// }
// });
// myConv->getOperator()->associateInput(0,myInput);
// myConv->getOperator()->associateInput(1,myWeights);
// myConv->getOperator()->associateInput(2,myBias);
// myConv->getOperator()->computeOutputDims();
// std::shared_ptr<GraphView> g;
// g->add(myConv);
// auto tiledConv = horizontalTile({myConv}, 3);
// g->replace({myConv}, {tiledConv});
// SequentialScheduler s(g);
// s->forward();
// // myConv->getOperator()->getOutput(0)->print();
// REQUIRE(*(myConv->getOperator()->getOutput(0)) == *myOutput);
// }
// std::shared_ptr<GraphView> g = Sequential({
// Conv(3, 16, {3,3}, "conv1"),
// ReLU("relu1"),
// Conv(16, 32, {1,1}, "conv2"),
// Conv(32, 16, {1,1}, "conv3"),
// Conv(16, 10, {3,3}, "conv4"),
// ReLU("relu2")
// });
// for (auto& individualConv : g->match("Conv")) { // for (auto& individualConv : g->match("Conv")) {
// auto tiledConv = horizontalTiling(individualConv); // auto tiledConv = horizontalTiling(individualConv);
...@@ -44,5 +189,5 @@ TEST_CASE("[core/recipies] Tiling(horizontal)", "[Tiling][HorizontalTiling][Reci ...@@ -44,5 +189,5 @@ TEST_CASE("[core/recipies] Tiling(horizontal)", "[Tiling][HorizontalTiling][Reci
// g->addChild(horizontalTiling(Conv())) // g->addChild(horizontalTiling(Conv()))
// } // }
} // }
} // namespace Aidge // } // namespace Aidge
\ No newline at end of file \ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment