Skip to content

Add a mechanism to inject faults in graph

Jerome Hue requested to merge jeromeh/aidge_core:ft-dev into dev

Context

This MR implements a fault injection mechanism to enable testing neural network inference under hardware fault conditions. Specifically, it focuses on simulating Single Event Upset (SEU) faults through bit flips in weights.

Detailed Major Changes

The fault injection mechanism works by inserting some fault operator nodes between existing nodes in the computation graph. A key feature is the ability to distribute faults proportionally across different tensors based on their sizes.

Future work

This MR defines a FaultLocation class, work has to be done so that the user could create a method returning the faults locations, and inject faults to theses specific locations. More work has to be done to support different types of faults, and to be able to have a unified function that that accepts different types of faults : inject_fault(graph, fault_type, fault_attributes) -> void.

Files Changed

  • New Header Files:

    • include/aidge/faults/FaultClass.hpp - Core fault types and location definitions
    • include/aidge/operator/NBitFlip.hpp - Random bit flip operator.
    • include/aidge/operator/FixedNBitFlip.hpp - Same as NBitFlip operator, but the affected weights do not change after creation of the operator.
  • New Implementation Files:

    • src/operator/NBitFlip.cpp
    • src/operator/FixedNBitFlip.cpp
    • src/recipes/InjectFault.cpp
  • Python Binding Files:

    • python_binding/operator/pybind_Fault.cpp - Python bindings for fault operators
  • Modified Files:

    • include/aidge/aidge.hpp - Added includes for new fault operators
    • include/aidge/recipes/Recipes.hpp - Added fault injection function declarations
    • python_binding/pybind_core.cpp - Added initialization for fault module
    • python_binding/recipes/pybind_Recipes.cpp - Added Python bindings for fault injection functions

Reference

The script shown in the next section was used to replicate the results of the original serie from Figure 3 (Average prediction accuracy of VGG16 with different compressions vs. BERs for all layers) of the following paper : https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8782505

Usage

Assuming you have a vgg16 onnx model :

import torch
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader

import aidge_core
import aidge_backend_cpu
import aidge_onnx

import numpy as np
import os
import argparse

MAX_ITER = 20
BATCH_SIZE = 32
data_dir = '/home/data/dataset/'
MODEL_PATH = "vgg16.onnx"
NB_RUNS = 1000

def torch_tensor_to_aidge(torch_tensor: torch.Tensor) -> aidge_core.Tensor:
    aidge_tensor = None
    numpy_tensor = torch_tensor.cpu().detach().numpy()
    aidge_tensor = aidge_core.Tensor(numpy_tensor)
    return aidge_tensor

def get_model_output(model: aidge_core.GraphView) -> np.array:
    output_aidge = np.array(list(model.get_output_nodes())[0].get_operator().get_output(0))
    return output_aidge

def prepare_test_loader():

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])

    test_dataset = datasets.CIFAR10(root=data_dir, train=False,
                                    transform=transform_test, download=False)

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, num_workers=2)

    return test_loader


def main():

    parser = argparse.ArgumentParser(description="Number of bits to flip")
    parser.add_argument("number", type=int) 
    nb_faults = parser.parse_args().number

    print(f"Testing vgg16 with {nb_faults} faults")

    test_loader = prepare_test_loader()

    for run in range(NB_RUNS):

        vgg16 = aidge_onnx.load_onnx(MODEL_PATH)
        vgg16.set_datatype(aidge_core.dtype.float32)
        vgg16.set_backend("cpu")

        # Inject faults and recompile graph
        aidge_core.inject_fixed_bitflip(vgg16, nb_faults)
        vgg16.set_datatype(aidge_core.dtype.float32)
        vgg16.set_backend("cpu")

        # Fix shape output being set to int64 by default
        matchs = aidge_core.SinglePassGraphMatching(vgg16).match("Shape->Gather")
        for match_result in matchs:
            matched_graph = match_result.graph
            cast = aidge_core.Cast(aidge_core.dtype.float32)
            out_node = matched_graph.get_ordered_outputs()[0][0]
            vgg16.insert_parent(out_node, cast, 0, 0, 0)

        scheduler = aidge_core.SequentialScheduler(vgg16)
        print("--- Loaded and set backend for faults", flush=True)

        # Setup statistic variables
        device = torch.device('cpu')
        total_size = 0
        total_correct = 0
        iter_count  = 0;
        nb_catastrophic = 0;

        for images, labels in test_loader:
        
            iter_count = iter_count + 1 

            images, labels = images.to(device), labels.to(device)
            aidge_images = torch_tensor_to_aidge(images)
            aidge_images.set_datatype(aidge_core.dtype.float32)
            aidge_images.set_backend("cpu")

            scheduler.forward(data=[aidge_images])
            output_aidge = get_model_output(vgg16)
        
            # Compute accuracy
            batch_correct = (output_aidge.argmax(axis=1) == labels.flatten()).sum().item()
            accuracy = batch_correct / labels.size(0) 
            total_size += labels.size(0)
            total_correct += batch_correct
            total_acc = total_correct / total_size
        
            if(accuracy <= 0.10):
                nb_catastrophic = nb_catastrophic + 1
        
            ratio_catastrophic = nb_catastrophic / iter_count
            
            print(f"-------------------------- Iteration {iter_count}/{MAX_ITER} | {run} --------------------------", flush=True)
            print(f"Accuracy (this run)             : {accuracy} ({batch_correct}/{labels.size(0)})", flush=True)
            print(f"Accuracy (all runs)             : {total_acc} ({total_correct}/{total_size})", flush=True)
            print(f"Catastrophic accuracty rate     : {ratio_catastrophic}", flush=True)
        
            # Break and go to next run
            if (iter_count == MAX_ITER):
                break
        
if __name__ == "__main__":
    main()
Edited by Jerome Hue

Merge request reports

Loading