// Test_WeightInterleavingImpl.cpp — CPU unit tests for the WeightInterleaving operator.
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <catch2/catch_test_macros.hpp>

#include "aidge/data/Tensor.hpp"
#include "aidge/operator/WeightInterleaving.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/TensorUtils.hpp"

#include "aidge/backend/cpu.hpp"

#include <memory>

using namespace Aidge;

// Unit tests for the CPU implementation of the WeightInterleaving operator.
// Covers two things:
//   1. compactDataSize(dataSize, nbBits): the number of bytes required to
//      store `dataSize` sub-byte values once packed.
//   2. forward(): the actual bit-packing of low-precision (Int2/Int3/Int4)
//      weights, including edge cases (single element, element counts that do
//      not fill the last byte), a full NHWC convolution weight tensor, and
//      the applyWeightInterleaving recipe applied to a Conv's weight producer.
TEST_CASE("[cpu/operator] WeightInterleaving", "[WeightInterleaving][CPU]") {
    
    std::shared_ptr<Node> myWeightInterleaving = WeightInterleaving();
    auto opWeightInterleaving = std::static_pointer_cast<WeightInterleaving_Op>(myWeightInterleaving -> getOperator());

    SECTION("CompactDataSize - Single element cases") {
        REQUIRE(opWeightInterleaving->compactDataSize(1, 1) == 1);  // 1 bit, needs 1 byte
        REQUIRE(opWeightInterleaving->compactDataSize(1, 7) == 1);  // 7 bits, needs 1 byte
    }

    // The assertions below show the packing granularity: values are stored in
    // power-of-two-sized bit slots, so 3-bit values consume the same space as
    // 4-bit ones (8/4 = 2 values per byte).
    SECTION("CompactDataSize - Boundary cases for different nb_bits values") {
        REQUIRE(opWeightInterleaving->compactDataSize(8, 1) == 1);  // 8 elements at 1 bit each, fits in 1 byte
        REQUIRE(opWeightInterleaving->compactDataSize(8, 2) == 2);  // 8 elements at 2 bits each, needs 2 bytes
        REQUIRE(opWeightInterleaving->compactDataSize(8, 3) == 4);  // 8 elements at 3 bits each, needs 4 bytes (3-bit values occupy 4-bit slots)
        REQUIRE(opWeightInterleaving->compactDataSize(8, 4) == 4);  // 8 elements at 4 bits each, needs 4 bytes
    }

    SECTION("CompactDataSize - Larger dataSize values") {
        REQUIRE(opWeightInterleaving->compactDataSize(16, 1) == 2);  // 16 elements at 1 bit each, fits in 2 bytes
        REQUIRE(opWeightInterleaving->compactDataSize(16, 2) == 4);  // 16 elements at 2 bits each, needs 4 bytes
        REQUIRE(opWeightInterleaving->compactDataSize(16, 3) == 8);  // 16 elements at 3 bits each, needs 8 bytes (3-bit values occupy 4-bit slots)
        REQUIRE(opWeightInterleaving->compactDataSize(16, 4) == 8);  // 16 elements at 4 bits each, needs 8 bytes
    }

    SECTION("CompactDataSize - Odd dataSize values with varying nb_bits") {
        REQUIRE(opWeightInterleaving->compactDataSize(7, 1) == 1);  // 7 elements at 1 bit each, fits in 1 byte
        REQUIRE(opWeightInterleaving->compactDataSize(7, 2) == 2);  // 7 elements at 2 bits each, needs 2 bytes
        REQUIRE(opWeightInterleaving->compactDataSize(7, 3) == 4);  // 7 elements at 3 bits each, needs 4 bytes
        REQUIRE(opWeightInterleaving->compactDataSize(7, 4) == 4);  // 7 elements at 4 bits each, needs 4 bytes
    }

    SECTION("CompactDataSize - Minimum and maximum values for nb_bits") {
        REQUIRE(opWeightInterleaving->compactDataSize(5, 1) == 1);  // 5 elements at 1 bit each, fits in 1 byte
    }

    SECTION("CompactDataSize - Edge Case - dataSize of 0 should result in 0 required size") {
        REQUIRE(opWeightInterleaving->compactDataSize(0, 1) == 0);  // No data elements
    }


    SECTION("CompactData - 4-bit compaction") {
        // Packing layout, as demonstrated by the data below: only the low
        // nibble of each int8 carries the Int4 value; two consecutive values
        // share one output byte, the first one in the HIGH nibble.
        // Input low nibbles F, 5, 3, C  ->  output bytes 0xF5, 0x3C.
        std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(Array1D<std::int8_t, 4>{
                                                                {static_cast<std::int8_t>(0x0F), 
                                                                static_cast<std::int8_t>(0xF5), 
                                                                static_cast<std::int8_t>(0xB3), 
                                                                static_cast<std::int8_t>(0x9C)}
                                                                });

        weight->setDataFormat(Aidge::DataFormat::NHWC);
        weight->setDataType(Aidge::DataType::Int4);

        std::shared_ptr<Tensor> expectedWeightInterleaving = std::make_shared<Tensor>(Array1D<std::int8_t, 2>{
                                                                {static_cast<int8_t>(0xF5), 
                                                                static_cast<int8_t>(0x3C)}
                                                                });

        expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC);
        expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type);

        std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving();
        auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator());
        op->associateInput(0,weight);
        op->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type);
        op->setDataFormat(DataFormat::NHWC);
        op->setBackend("cpu");
        myWeightInterleavingNode->forward();
        REQUIRE(*(op->getOutput(0)) == *expectedWeightInterleaving);
    }

    SECTION("CompactData - 3-bit compaction") {
        // 3-bit values are packed into 4-bit slots (two per byte): low 3 bits
        // of the inputs are 7, 5, 4, 3  ->  output bytes 0x75, 0x43.
        std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(Array1D<std::int8_t, 4>{
                                                                {static_cast<int8_t>(0x0F), 
                                                                static_cast<int8_t>(0x05), 
                                                                static_cast<int8_t>(0x04),
                                                                static_cast<int8_t>(0xD3)}
                                                                });

        weight->setDataFormat(Aidge::DataFormat::NHWC);
        weight->setDataType(Aidge::DataType::Int3);

        std::shared_ptr<Tensor> expectedWeightInterleaving = std::make_shared<Tensor>(Array1D<std::int8_t, 2>{
                                                                {static_cast<int8_t>(0x75), 
                                                                static_cast<int8_t>(0x43)}
                                                                });

        expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC);
        expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int3>::type);

        std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving();
        auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator());
        op->associateInput(0,weight);
        op->setDataType(WeightInterleavingType<Aidge::DataType::Int3>::type);
        op->setDataFormat(DataFormat::NHWC);
        op->setBackend("cpu");
        myWeightInterleavingNode->forward();
        REQUIRE(*(op->getOutput(0)) == *expectedWeightInterleaving);
    }

    SECTION("CompactData - 2-bit compaction") {
        // Four 2-bit values per byte, first value in the most-significant
        // slot: inputs 3, 2, 1, 0  ->  0b11'10'01'00 == 0xE4.
        std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(Array1D<std::int8_t, 4>{
                                                                {static_cast<std::int8_t>(0x03),
                                                                 static_cast<std::int8_t>(0x02),
                                                                 static_cast<std::int8_t>(0x01), 
                                                                 static_cast<std::int8_t>(0x00)}
                                                                 });

        weight->setDataFormat(Aidge::DataFormat::NHWC);
        weight->setDataType(Aidge::DataType::Int2);

        std::shared_ptr<Tensor> expectedWeightInterleaving = std::make_shared<Tensor>(Array1D<std::int8_t, 1>{
                                                                {static_cast<int8_t>(0xE4)}
                                                                });

        expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC);
        expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int2>::type);
        std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving();
        auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator());
        op->associateInput(0,weight);
        op->setDataType(WeightInterleavingType<Aidge::DataType::Int2>::type);
        op->setDataFormat(DataFormat::NHWC);
        op->setBackend("cpu");
        myWeightInterleavingNode->forward();
        REQUIRE(*(op->getOutput(0)) == *expectedWeightInterleaving);
    }

    SECTION("CompactData - Edge Cases - Single element data") {
        // A lone 4-bit value goes into the HIGH nibble; the unused low slot
        // is zero-padded: 0x0F -> 0xF0.
        std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(Array1D<std::int8_t, 1>{
                                                                {static_cast<int8_t>(0x0F)}
                                                                });

        weight->setDataFormat(Aidge::DataFormat::NHWC);
        weight->setDataType(Aidge::DataType::Int4);

        std::shared_ptr<Tensor> expectedWeightInterleaving = std::make_shared<Tensor>(Array1D<std::int8_t, 1>{
                                                                {static_cast<int8_t>(0xF0)}
                                                                });

        expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC);
        expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type);

        std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving();
        auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator());
        op->associateInput(0,weight);
        op->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type);
        op->setDataFormat(DataFormat::NHWC);
        op->setBackend("cpu");
        myWeightInterleavingNode->forward();
        REQUIRE(*(op->getOutput(0)) == *expectedWeightInterleaving);
    }

    SECTION("CompactData - Edge Cases - Non-divisible dataSize for nbSlot with nbbits=4") {
        // Three 4-bit values need two bytes; the trailing (odd) value lands in
        // the high nibble of the last byte with zero padding: F, 5, 4 -> 0xF5, 0x40.
        std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(Array1D<std::int8_t, 3>{
                                                                {static_cast<int8_t>(0x0F), 
                                                                static_cast<int8_t>(0xA5), 
                                                                static_cast<int8_t>(0x34)}
                                                                });

        weight->setDataFormat(Aidge::DataFormat::NHWC);
        weight->setDataType(Aidge::DataType::Int4);

        std::shared_ptr<Tensor> expectedWeightInterleaving = std::make_shared<Tensor>(Array1D<std::int8_t, 2>{
                                                                {static_cast<int8_t>(0xF5), 
                                                                static_cast<int8_t>(0x40)}
                                                                });

        expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC);
        expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type);

        std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving();
        auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator());
        op->associateInput(0,weight);
        op->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type);
        op->setDataFormat(DataFormat::NHWC);
        op->setBackend("cpu");
        myWeightInterleavingNode->forward();
        REQUIRE(*(op->getOutput(0)) == *expectedWeightInterleaving);

    }

    SECTION("CompactData - Edge Cases - Non-divisible dataSize for nbSlot with nbbits=3") {
        // Same trailing-value padding as the 4-bit case, but with 3-bit values
        // in 4-bit slots: low 3 bits 7, 5, 4 -> 0x75, 0x40.
        std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(Array1D<std::int8_t, 3>{
                                                                {static_cast<int8_t>(0x0F), 
                                                                static_cast<int8_t>(0x05), 
                                                                static_cast<int8_t>(0x04)}
                                                                });

        weight->setDataFormat(Aidge::DataFormat::NHWC);
        weight->setDataType(Aidge::DataType::Int3);

        std::shared_ptr<Tensor> expectedWeightInterleaving = std::make_shared<Tensor>(Array1D<std::int8_t, 2>{
                                                                {static_cast<int8_t>(0x75), 
                                                                static_cast<int8_t>(0x40)}
                                                                });

        expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC);
        expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int3>::type);

        std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving();
        auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator());
        op->associateInput(0,weight);
        op->setDataType(WeightInterleavingType<Aidge::DataType::Int3>::type);
        op->setDataFormat(DataFormat::NHWC);
        op->setBackend("cpu");
        myWeightInterleavingNode->forward();
        REQUIRE(*(op->getOutput(0)) == *expectedWeightInterleaving);

    }

    SECTION("Forward Op - Convolution weight interleaving") {
        // Full NHWC weight tensor: interleaving packs along the innermost
        // (Cin) dimension, two Int4 values per output byte, so the channel
        // dimension shrinks from 4 to 2 while N/H/W are preserved.

        // Weight [Cout = 2, H = 3, W = 3, Cin = 4]:
        std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(Array4D<std::int8_t,2,3,3,4> {
            {
                {
                    {
                        {-6,  0,  5, -8}, // 'A' '0' '5' '8' in hexadecimal format
                        { 5,  5,  4, -5}, // '5' '5' '4' 'B' in hexadecimal format
                        {-7, -1,  4, -7}  // '9' 'F' '4' '9' in hexadecimal format
                    },
                    {
                        { 3, -3, -3, -3}, // '3' 'D' 'D' 'D' in hexadecimal format
                        { 1,  3,  1, -1}, // '1' '3' '1' 'F' in hexadecimal format
                        { 7, -3, -1,  4}  // '7' 'D' 'F' '4' in hexadecimal format
                    },
                    {
                        {-1,  3,  5,  6}, // 'F' '3' '5' '6' in hexadecimal format
                        {-8,  4,  7,  1}, // '8' '4' '7' '1' in hexadecimal format
                        {-5,  0, -1, -2}  // 'B' '0' 'F' 'E' in hexadecimal format
                    }
                },
                {
                    {
                        { 2, -7,  7, -4}, // '2' '9' '7' 'C' in hexadecimal format
                        {-7,  3,  0,  2}, // '9' '3' '0' '2' in hexadecimal format
                        { 1, -1,  2,  3}  // '1' 'F' '2' '3' in hexadecimal format
                    },
                    {
                        {-1, -5, -3, -7}, // 'F' 'B' 'D' '9' in hexadecimal format
                        {-8,  3,  5, -1}, // '8' '3' '5' 'F' in hexadecimal format
                        {-7, -4, -6, -1}  // '9' 'C' 'A' 'F' in hexadecimal format
                    },
                    {
                        { 1,  7,  5, -1}, // '1' '7' '5' 'F' in hexadecimal format
                        { 1, -8,  1,  2}, // '1' '8' '1' '2' in hexadecimal format
                        {-1, -6, -3,  0}  // 'F' 'A' 'D' '0' in hexadecimal format
                    }
                }
            } 
        });
        
        std::shared_ptr<Tensor> expectedWeightInterleaving = std::make_shared<Tensor>(Array4D<std::int8_t,2,3,3,2> {
            {
                {
                    {
                        {static_cast<int8_t>(0xA0), static_cast<int8_t>(0x58)}, // 'A' '0' '5' '8' in hexadecimal format
                        {static_cast<int8_t>(0x55), static_cast<int8_t>(0x4B)}, // '5' '5' '4' 'B' in hexadecimal format
                        {static_cast<int8_t>(0x9F), static_cast<int8_t>(0x49)}  // '9' 'F' '4' '9' in hexadecimal format
                    },
                    {
                        {static_cast<int8_t>(0x3D), static_cast<int8_t>(0xDD)}, // '3' 'D' 'D' 'D' in hexadecimal format
                        {static_cast<int8_t>(0x13), static_cast<int8_t>(0x1F)}, // '1' '3' '1' 'F' in hexadecimal format
                        {static_cast<int8_t>(0x7D), static_cast<int8_t>(0xF4)}  // '7' 'D' 'F' '4' in hexadecimal format
                    },
                    {
                        {static_cast<int8_t>(0xF3), static_cast<int8_t>(0x56)}, // 'F' '3' '5' '6' in hexadecimal format
                        {static_cast<int8_t>(0x84), static_cast<int8_t>(0x71)}, // '8' '4' '7' '1' in hexadecimal format
                        {static_cast<int8_t>(0xB0), static_cast<int8_t>(0xFE)}  // 'B' '0' 'F' 'E' in hexadecimal format
                    }
                },
                {
                    {
                        {static_cast<int8_t>(0x29), static_cast<int8_t>(0x7C)}, // '2' '9' '7' 'C' in hexadecimal format
                        {static_cast<int8_t>(0x93), static_cast<int8_t>(0x02)}, // '9' '3' '0' '2' in hexadecimal format
                        {static_cast<int8_t>(0x1F), static_cast<int8_t>(0x23)}  // '1' 'F' '2' '3' in hexadecimal format
                    },
                    {
                        {static_cast<int8_t>(0xFB), static_cast<int8_t>(0xD9)}, // 'F' 'B' 'D' '9' in hexadecimal format
                        {static_cast<int8_t>(0x83), static_cast<int8_t>(0x5F)}, // '8' '3' '5' 'F' in hexadecimal format
                        {static_cast<int8_t>(0x9C), static_cast<int8_t>(0xAF)}  // '9' 'C' 'A' 'F' in hexadecimal format
                    },
                    {
                        {static_cast<int8_t>(0x17), static_cast<int8_t>(0x5F)}, // '1' '7' '5' 'F' in hexadecimal format
                        {static_cast<int8_t>(0x18), static_cast<int8_t>(0x12)}, // '1' '8' '1' '2' in hexadecimal format
                        {static_cast<int8_t>(0xFA), static_cast<int8_t>(0xD0)}  // 'F' 'A' 'D' '0' in hexadecimal format
                    }
                }
            } 
        });

        weight->setDataFormat(Aidge::DataFormat::NHWC);
        weight->setDataType(Aidge::DataType::Int4);

        expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC);
        expectedWeightInterleaving->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type);

        std::shared_ptr<Node> myWeightInterleavingNode = WeightInterleaving();
        auto op = std::static_pointer_cast<OperatorTensor>(myWeightInterleavingNode -> getOperator());
        op->associateInput(0,weight);
        op->setDataType(WeightInterleavingType<Aidge::DataType::Int4>::type);
        op->setDataFormat(DataFormat::NHWC);
        op->setBackend("cpu");
        myWeightInterleavingNode->forward();
        REQUIRE(*(op->getOutput(0)) == *expectedWeightInterleaving);
    }

    // NOTE(review): "Recipie" in the section name below is a typo for
    // "Recipe"; it is kept as-is here since renaming would change the Catch2
    // section identifier (and any filter expressions that reference it).
    SECTION("Recipie ApplyWeightInterleaving") {
        // End-to-end check of the applyWeightInterleaving recipe: it must
        // replace a Conv's weight producer output with the interleaved tensor
        // (same data as the previous section, expected type Dual_Int4).

        // Weight [Cout = 2, H = 3, W = 3, Cin = 4]:
        std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(Array4D<std::int8_t,2,3,3,4> {
            {
                {
                    {
                        {-6,  0,  5, -8}, // 'A' '0' '5' '8' in hexadecimal format
                        { 5,  5,  4, -5}, // '5' '5' '4' 'B' in hexadecimal format
                        {-7, -1,  4, -7}  // '9' 'F' '4' '9' in hexadecimal format
                    },
                    {
                        { 3, -3, -3, -3}, // '3' 'D' 'D' 'D' in hexadecimal format
                        { 1,  3,  1, -1}, // '1' '3' '1' 'F' in hexadecimal format
                        { 7, -3, -1,  4}  // '7' 'D' 'F' '4' in hexadecimal format
                    },
                    {
                        {-1,  3,  5,  6}, // 'F' '3' '5' '6' in hexadecimal format
                        {-8,  4,  7,  1}, // '8' '4' '7' '1' in hexadecimal format
                        {-5,  0, -1, -2}  // 'B' '0' 'F' 'E' in hexadecimal format
                    }
                },
                {
                    {
                        { 2, -7,  7, -4}, // '2' '9' '7' 'C' in hexadecimal format
                        {-7,  3,  0,  2}, // '9' '3' '0' '2' in hexadecimal format
                        { 1, -1,  2,  3}  // '1' 'F' '2' '3' in hexadecimal format
                    },
                    {
                        {-1, -5, -3, -7}, // 'F' 'B' 'D' '9' in hexadecimal format
                        {-8,  3,  5, -1}, // '8' '3' '5' 'F' in hexadecimal format
                        {-7, -4, -6, -1}  // '9' 'C' 'A' 'F' in hexadecimal format
                    },
                    {
                        { 1,  7,  5, -1}, // '1' '7' '5' 'F' in hexadecimal format
                        { 1, -8,  1,  2}, // '1' '8' '1' '2' in hexadecimal format
                        {-1, -6, -3,  0}  // 'F' 'A' 'D' '0' in hexadecimal format
                    }
                }
            } 
        });
        
        std::shared_ptr<Tensor> expectedWeightInterleaving = std::make_shared<Tensor>(Array4D<std::int8_t,2,3,3,2> {
            {
                {
                    {
                        {static_cast<int8_t>(0xA0), static_cast<int8_t>(0x58)}, // 'A' '0' '5' '8' in hexadecimal format
                        {static_cast<int8_t>(0x55), static_cast<int8_t>(0x4B)}, // '5' '5' '4' 'B' in hexadecimal format
                        {static_cast<int8_t>(0x9F), static_cast<int8_t>(0x49)}  // '9' 'F' '4' '9' in hexadecimal format
                    },
                    {
                        {static_cast<int8_t>(0x3D), static_cast<int8_t>(0xDD)}, // '3' 'D' 'D' 'D' in hexadecimal format
                        {static_cast<int8_t>(0x13), static_cast<int8_t>(0x1F)}, // '1' '3' '1' 'F' in hexadecimal format
                        {static_cast<int8_t>(0x7D), static_cast<int8_t>(0xF4)}  // '7' 'D' 'F' '4' in hexadecimal format
                    },
                    {
                        {static_cast<int8_t>(0xF3), static_cast<int8_t>(0x56)}, // 'F' '3' '5' '6' in hexadecimal format
                        {static_cast<int8_t>(0x84), static_cast<int8_t>(0x71)}, // '8' '4' '7' '1' in hexadecimal format
                        {static_cast<int8_t>(0xB0), static_cast<int8_t>(0xFE)}  // 'B' '0' 'F' 'E' in hexadecimal format
                    }
                },
                {
                    {
                        {static_cast<int8_t>(0x29), static_cast<int8_t>(0x7C)}, // '2' '9' '7' 'C' in hexadecimal format
                        {static_cast<int8_t>(0x93), static_cast<int8_t>(0x02)}, // '9' '3' '0' '2' in hexadecimal format
                        {static_cast<int8_t>(0x1F), static_cast<int8_t>(0x23)}  // '1' 'F' '2' '3' in hexadecimal format
                    },
                    {
                        {static_cast<int8_t>(0xFB), static_cast<int8_t>(0xD9)}, // 'F' 'B' 'D' '9' in hexadecimal format
                        {static_cast<int8_t>(0x83), static_cast<int8_t>(0x5F)}, // '8' '3' '5' 'F' in hexadecimal format
                        {static_cast<int8_t>(0x9C), static_cast<int8_t>(0xAF)}  // '9' 'C' 'A' 'F' in hexadecimal format
                    },
                    {
                        {static_cast<int8_t>(0x17), static_cast<int8_t>(0x5F)}, // '1' '7' '5' 'F' in hexadecimal format
                        {static_cast<int8_t>(0x18), static_cast<int8_t>(0x12)}, // '1' '8' '1' '2' in hexadecimal format
                        {static_cast<int8_t>(0xFA), static_cast<int8_t>(0xD0)}  // 'F' 'A' 'D' '0' in hexadecimal format
                    }
                }
            } 
        });

        expectedWeightInterleaving->setDataFormat(Aidge::DataFormat::NHWC);
        expectedWeightInterleaving->setDataType(Aidge::DataType::Dual_Int4);

        // Create convolution node
        std::shared_ptr<Node> conv = Conv(4, 2, {3, 3}, "conv1");
        
        // Place the weight tensor in the weight producer of the conv
        // (input index 1 is the weight producer)
        auto weightProducer = conv->getParent(1);
        weightProducer->getOperator()->setOutput(0, weight);

        // Set dataType, dataformat and backend of convolution 
        conv->getOperator()->setDataFormat(Aidge::DataFormat::NHWC);
        conv->getOperator()->setDataType(Aidge::DataType::Int4);
        conv->getOperator()->setBackend("cpu");

        // Apply recipe: rewires the graph so the weight producer now emits
        // the interleaved tensor
        applyWeightInterleaving(conv);

        // Compare the weight producer output tensor with the expected weights with interleaving
        // (re-fetch the parent: the recipe may have replaced the producer node)
        auto newProdOp = std::static_pointer_cast<OperatorTensor>(conv->getParent(1)->getOperator());
        REQUIRE(*(newProdOp->getOutput(0)) == *expectedWeightInterleaving);
    }

}