diff --git a/unit_tests/operator/Test_WeightInterleavingImpl.cpp b/unit_tests/operator/Test_WeightInterleavingImpl.cpp index 8d4a6ac5ca4daf81f4f1bbaf502edbcf3c25fce0..9bd9f14681115d8e7d77fbc651596eec714d7b1e 100644 --- a/unit_tests/operator/Test_WeightInterleavingImpl.cpp +++ b/unit_tests/operator/Test_WeightInterleavingImpl.cpp @@ -13,6 +13,8 @@ #include "aidge/data/Tensor.hpp" #include "aidge/operator/WeightInterleaving.hpp" +#include "aidge/recipes/Recipes.hpp" +#include "aidge/utils/TensorUtils.hpp" #include "aidge/backend/cpu.hpp" @@ -77,12 +79,12 @@ TEST_CASE("[cpu/operator] WeightInterleaving", "[WeightInterleaving][CPU]") { }); expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC); - expectedWeightInterleaving->setDataType(Aidge::DataType::Int4); + expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type); std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving(); auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator()); op->associateInput(0,weight); - op->setDataType(DataType::Int4); + op->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type); op->setDataFormat(DataFormat::NHWC); op->setBackend("cpu"); myWeightInterleavingNode->forward(); @@ -106,12 +108,12 @@ TEST_CASE("[cpu/operator] WeightInterleaving", "[WeightInterleaving][CPU]") { }); expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC); - expectedWeightInterleaving->setDataType(Aidge::DataType::Int3); + expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int3>::type); std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving(); auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator()); op->associateInput(0,weight); - op->setDataType(DataType::Int3); + op->setDataType(WeightInterleavingType<Aidge::DataType::Int3>::type); op->setDataFormat(DataFormat::NHWC); op->setBackend("cpu"); myWeightInterleavingNode->forward(); @@ -134,12 +136,12 @@ TEST_CASE("[cpu/operator] WeightInterleaving", "[WeightInterleaving][CPU]") { }); expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC); - expectedWeightInterleaving->setDataType(Aidge::DataType::Int2); + expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int2>::type); std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving(); auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator()); op->associateInput(0,weight); - op->setDataType(DataType::Int2); + op->setDataType(WeightInterleavingType<Aidge::DataType::Int2>::type); op->setDataFormat(DataFormat::NHWC); op->setBackend("cpu"); myWeightInterleavingNode->forward(); @@ -159,12 +161,12 @@ TEST_CASE("[cpu/operator] WeightInterleaving", "[WeightInterleaving][CPU]") { }); expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC); - expectedWeightInterleaving->setDataType(Aidge::DataType::Int4); + expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type); std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving(); auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator()); op->associateInput(0,weight); - op->setDataType(DataType::Int4); + op->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type); op->setDataFormat(DataFormat::NHWC); op->setBackend("cpu"); myWeightInterleavingNode->forward(); @@ -187,12 +189,12 @@ TEST_CASE("[cpu/operator] WeightInterleaving", "[WeightInterleaving][CPU]") { }); expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC); - expectedWeightInterleaving->setDataType(Aidge::DataType::Int4); + expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type); std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving(); auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator()); op->associateInput(0,weight); - op->setDataType(DataType::Int4); + op->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type); op->setDataFormat(DataFormat::NHWC); op->setBackend("cpu"); myWeightInterleavingNode->forward(); @@ -217,12 +219,12 @@ TEST_CASE("[cpu/operator] WeightInterleaving", "[WeightInterleaving][CPU]") { }); expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC); - expectedWeightInterleaving->setDataType(Aidge::DataType::Int3); + expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int3>::type); std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving(); auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator()); op->associateInput(0,weight); - op->setDataType(DataType::Int3); + op->setDataType(WeightInterleavingType<Aidge::DataType::Int3>::type); op->setDataFormat(DataFormat::NHWC); op->setBackend("cpu"); myWeightInterleavingNode->forward(); @@ -315,16 +317,120 @@ TEST_CASE("[cpu/operator] WeightInterleaving", "[WeightInterleaving][CPU]") { weight->setDataType(Aidge::DataType::Int4); expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC); - expectedWeightInterleaving->setDataType(Aidge::DataType::Int4); + expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type); std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving(); auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator()); op->associateInput(0,weight); - op->setDataType(DataType::Int4); + op->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type); op->setDataFormat(DataFormat::NHWC); op->setBackend("cpu"); myWeightInterleavingNode->forward(); REQUIRE(*(op->getOutput(0)) == *expectedWeightInterleaving); } + SECTION("Recipie ApplyWeightInterleaving") { + + // Weight [Cout = 2, H = 3, W = 3, Cin = 4]: + std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(Array4D<std::int8_t,2,3,3,4> { + { + { + { + {-6, 0, 5, -8}, // 'A' '0' '5' '8' in hexadecimal format + { 5, 5, 4, -5}, // '5' '5' '4' 'B' in hexadecimal format + {-7, -1, 4, -7} // '9' 'F' '4' '9' in hexadecimal format + }, + { + { 3, -3, -3, -3}, // '3' 'D' 'D' 'D' in hexadecimal format + { 1, 3, 1, -1}, // '1' '3' '1' 'F' in hexadecimal format + { 7, -3, -1, 4} // '7' 'D' 'F' '4' in hexadecimal format + }, + { + {-1, 3, 5, 6}, // 'F' '3' '5' '6' in hexadecimal format + {-8, 4, 7, 1}, // '8' '4' '7' '1' in hexadecimal format + {-5, 0, -1, -2} // 'B' '0' 'F' 'E' in hexadecimal format + } + }, + { + { + { 2, -7, 7, -4}, // '2' '9' '7' 'C' in hexadecimal format + {-7, 3, 0, 2}, // '9' '3' '0' '2' in hexadecimal format + { 1, -1, 2, 3} // '1' 'F' '2' '3' in hexadecimal format + }, + { + {-1, -5, -3, -7}, // 'F' 'B' 'D' '9' in hexadecimal format + {-8, 3, 5, -1}, // '8' '3' '5' 'F' in hexadecimal format + {-7, -4, -6, -1} // '9' 'C' 'A' 'F' in hexadecimal format + }, + { + { 1, 7, 5, -1}, // '1' '7' '5' 'F' in hexadecimal format + { 1, -8, 1, 2}, // '1' '8' '1' '2' in hexadecimal format + {-1, -6, -3, 0} // 'F' 'A' 'D' '0' in hexadecimal format + } + } + } + }); + + std::shared_ptr<Tensor> expectedWeightInterleaving = std::make_shared<Tensor>(Array4D<std::int8_t,2,3,3,2> { + { + { + { + {static_cast<int8_t>(0xA0), static_cast<int8_t>(0x58)}, // 'A' '0' '5' '8' in hexadecimal format + {static_cast<int8_t>(0x55), static_cast<int8_t>(0x4B)}, // '5' '5' '4' 'B' in hexadecimal format + {static_cast<int8_t>(0x9F), static_cast<int8_t>(0x49)} // '9' 'F' '4' '9' in hexadecimal format + }, + { + {static_cast<int8_t>(0x3D), static_cast<int8_t>(0xDD)}, // '3' 'D' 'D' 'D' in hexadecimal format + {static_cast<int8_t>(0x13), static_cast<int8_t>(0x1F)}, // '1' '3' '1' 'F' in hexadecimal format + {static_cast<int8_t>(0x7D), static_cast<int8_t>(0xF4)} // '7' 'D' 'F' '4' in hexadecimal format + }, + { + {static_cast<int8_t>(0xF3), static_cast<int8_t>(0x56)}, // 'F' '3' '5' '6' in hexadecimal format + {static_cast<int8_t>(0x84), static_cast<int8_t>(0x71)}, // '8' '4' '7' '1' in hexadecimal format + {static_cast<int8_t>(0xB0), static_cast<int8_t>(0xFE)} // 'B' '0' 'F' 'E' in hexadecimal format + } + }, + { + { + {static_cast<int8_t>(0x29), static_cast<int8_t>(0x7C)}, // '2' '9' '7' 'C' in hexadecimal format + {static_cast<int8_t>(0x93), static_cast<int8_t>(0x02)}, // '9' '3' '0' '2' in hexadecimal format + {static_cast<int8_t>(0x1F), static_cast<int8_t>(0x23)} // '1' 'F' '2' '3' in hexadecimal format + }, + { + {static_cast<int8_t>(0xFB), static_cast<int8_t>(0xD9)}, // 'F' 'B' 'D' '9' in hexadecimal format + {static_cast<int8_t>(0x83), static_cast<int8_t>(0x5F)}, // '8' '3' '5' 'F' in hexadecimal format + {static_cast<int8_t>(0x9C), static_cast<int8_t>(0xAF)} // '9' 'C' 'A' 'F' in hexadecimal format + }, + { + {static_cast<int8_t>(0x17), static_cast<int8_t>(0x5F)}, // '1' '7' '5' 'F' in hexadecimal format + {static_cast<int8_t>(0x18), static_cast<int8_t>(0x12)}, // '1' '8' '1' '2' in hexadecimal format + {static_cast<int8_t>(0xFA), static_cast<int8_t>(0xD0)} // 'F' 'A' 'D' '0' in hexadecimal format + } + } + } + }); + + expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC); + expectedWeightInterleaving->setDataType(Aidge::DataType::Dual_Int4); + + // Create convolution node + std::shared_ptr<Node> conv = Conv(4, 2, {3, 3}, "conv1"); + + // Place the weight tensor in the weight producer of the conv + auto weightProducer = conv->getParent(1); + weightProducer->getOperator()->setOutput(0, weight); + + // Set dataType, dataformat and backend of convolution + conv->getOperator()->setDataFormat(Aidge::DataFormat::NHWC); + conv->getOperator()->setDataType(Aidge::DataType::Int4); + conv->getOperator()->setBackend("cpu"); + + // Apply recipie + applyWeightInterleaving(conv); + + // Compare the weight producer output tensor with the expected weights with interleaving + auto newProdOp = std::static_pointer_cast<OperatorTensor>(conv->getParent(1)->getOperator()); + REQUIRE(*(newProdOp->getOutput(0)) == *expectedWeightInterleaving); + } + }