diff --git a/unit_tests/recipies/Test_ExplicitConvert.cpp b/unit_tests/recipies/Test_ExplicitConvert.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80548aac027c9c719049240b18afefd4ca2eb678 --- /dev/null +++ b/unit_tests/recipies/Test_ExplicitConvert.cpp @@ -0,0 +1,46 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/recipies/Recipies.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/graph/OpArgs.hpp" +#include <cstddef> + +using namespace Aidge; + +TEST_CASE("[ExplicitConvert] conv") { + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3", {2, 2}); + + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + conv1, + conv2, + conv3 + }); + + g1->setBackend("cpu"); + conv1->getOperator()->setDataType(DataType::Int32); + conv3->getOperator()->setDataType(DataType::Float64); + + g1->save("ExplicitConvert_before"); + REQUIRE(g1->getNodes().size() == 10); + + g1->forwardDims(); + explicitConvert(g1); + + g1->save("ExplicitConvert_after"); + REQUIRE(g1->getNodes().size() == 5); +}