Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
postgres.go 7.35 KiB
package database

import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"regexp"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/sethvargo/go-retry"
	log "github.com/sirupsen/logrus"
	entity2 "gitlab.eclipse.org/eclipse/xfsc/organisational-credential-manager-w-stack/status-list-service/internal/entity"
)

// TODO: queries as constants?

// postgresConnection is the PostgreSQL implementation of DbConnection that
// newPostgresConnection returns. Each tenant's blocks live in their own table,
// whose name is produced by createTableName.
type postgresConnection struct {
	// conn is the pgx connection pool used by every query and transaction.
	conn             *pgxpool.Pool
	// blockSizeInBytes is the size passed to entity2.NewBlock whenever a new
	// block row is created for a tenant.
	blockSizeInBytes int
}

// newPostgresConnection opens a pgx connection pool for the given server and
// database. It then blocks until the server answers a ping.
//
// The connection URL is built with net/url. Usernames, passwords and database
// names that contain reserved characters (such as '@', ':', '/' or '#') are
// therefore percent-encoded instead of corrupting the URL.
//
// Pings are retried with a jittered Fibonacci backoff. It starts at 500ms,
// caps each wait at 30s and gives up after 2 minutes in total. Each failed
// attempt is logged. If the server never becomes reachable, the pool is
// closed and the last ping error is returned.
func newPostgresConnection(username string, password string, host string, port int, db string, blockSizeInBytes int) (DbConnection, error) {
	connURL := url.URL{
		Scheme: "postgres",
		User:   url.UserPassword(username, password),
		Host:   fmt.Sprintf("%s:%d", host, port),
		Path:   "/" + db,
	}

	ctx := context.Background()
	pool, err := pgxpool.New(ctx, connURL.String())
	if err != nil {
		return nil, err
	}

	backoff := retry.NewFibonacci(500 * time.Millisecond)
	backoff = retry.WithCappedDuration(30*time.Second, backoff)
	backoff = retry.WithJitter(50*time.Millisecond, backoff)
	backoff = retry.WithMaxDuration(2*time.Minute, backoff)

	err = retry.Do(ctx, backoff, func(retryCtx context.Context) error {
		if pingErr := pool.Ping(retryCtx); pingErr != nil {
			pingErr = fmt.Errorf("connection check failed: %w", pingErr)
			log.Error(pingErr)
			return retry.RetryableError(pingErr)
		}
		return nil
	})
	if err != nil {
		// The pool is unusable and is not returned to the caller. Release it
		// here so its resources are not leaked.
		pool.Close()
		return nil, err
	}

	return &postgresConnection{
		conn:             pool,
		blockSizeInBytes: blockSizeInBytes,
	}, nil
}

// AllocateIndexInCurrentBlock reserves the next free status index for
// tenantId. It returns the index together with the ID of the block that
// holds it.
//
// Everything runs in one read-committed transaction. A block with free > 0 is
// selected FOR UPDATE, so concurrent allocators wait on that row lock instead
// of handing out the same index. The index is then taken from that block, and
// the block is written back.
//
// When the tenant has no block with free space, a new block of
// pc.blockSizeInBytes is created instead. The index is allocated from it and
// the block is inserted, and its generated blockID is returned.
//
// Validation, database and allocator errors are returned, wrapped with
// context where this function adds any.
func (pc *postgresConnection) AllocateIndexInCurrentBlock(ctx context.Context, tenantId string) (*entity2.StatusData, error) {
	// The table name is interpolated into SQL, so validate it before doing
	// any database work.
	tableName, err := createTableName(tenantId)
	if err != nil {
		return nil, err
	}

	tx, err := pc.conn.BeginTx(ctx, pgx.TxOptions{
		IsoLevel:       pgx.ReadCommitted,
		AccessMode:     pgx.ReadWrite,
		DeferrableMode: pgx.NotDeferrable,
	})
	if err != nil {
		return nil, fmt.Errorf("could not start transaction: %w", err)
	}
	// Discards partial work on every early return. After a successful Commit
	// this is a no-op, and its error is deliberately ignored.
	defer tx.Rollback(ctx)

	selectQuery := fmt.Sprintf("SELECT blockID, block, free FROM %s WHERE free > 0 FOR UPDATE LIMIT 1", tableName)
	rows, err := tx.Query(ctx, selectQuery)
	if err != nil {
		return nil, fmt.Errorf("error while select current block from the database: %w", err)
	}
	// RowToStructByName relies on reflection. That cost is acceptable here
	// because at most one row is returned.
	databaseRows, err := pgx.CollectRows(rows, pgx.RowToStructByName[entity2.Block])
	if err != nil {
		return nil, fmt.Errorf("error while collecting current block from rows: %w", err)
	}

	if len(databaseRows) == 0 {
		// No block has free space: create a new one and allocate from it.
		newBlock := entity2.NewBlock(pc.blockSizeInBytes)

		index, err := newBlock.AllocateNextFreeIndex()
		if err != nil {
			return nil, fmt.Errorf("error allocating next free index from new block: %w", err)
		}

		insertQuery := fmt.Sprintf("INSERT INTO %s (block, free) VALUES ($1, $2) RETURNING blockID", tableName)
		var blockId int
		if err := tx.QueryRow(ctx, insertQuery, newBlock.Block, newBlock.Free).Scan(&blockId); err != nil {
			return nil, fmt.Errorf("error inserting new block into the database: %w", err)
		}

		if err := tx.Commit(ctx); err != nil {
			return nil, fmt.Errorf("error commiting transaction: %w", err)
		}

		return entity2.NewStatusData(index, blockId), nil
	}

	// Allocate from the locked block and persist its new state.
	currentBlock := databaseRows[0]

	index, err := currentBlock.AllocateNextFreeIndex()
	if err != nil {
		return nil, fmt.Errorf("error allocating next free index from current block: %w", err)
	}

	updateQuery := fmt.Sprintf("UPDATE %s SET block = $1, free = $2 WHERE blockID = $3", tableName)
	if _, err := tx.Exec(ctx, updateQuery, currentBlock.Block, currentBlock.Free, currentBlock.BlockId); err != nil {
		return nil, fmt.Errorf("error updating block in the database: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return nil, fmt.Errorf("error commiting transaction: %w", err)
	}

	return entity2.NewStatusData(index, currentBlock.BlockId), nil
}

// RevokeCredentialInSpecifiedBlock loads block blockId from tenantId's table
// and applies RevokeAtIndex(index) to it. It then writes the block back, all
// inside one read-committed transaction.
//
// The block row is selected FOR UPDATE, so concurrent writers to the same
// block are serialised. An error is returned in these cases:
//   - the derived table name is invalid
//   - the block does not exist
//   - any database step fails
func (pc *postgresConnection) RevokeCredentialInSpecifiedBlock(ctx context.Context, tenantId string, blockId int, index int) error {
	// The table name is interpolated into SQL, so validate it before doing
	// any database work.
	tableName, err := createTableName(tenantId)
	if err != nil {
		return err
	}

	tx, err := pc.conn.BeginTx(ctx, pgx.TxOptions{
		IsoLevel:       pgx.ReadCommitted,
		AccessMode:     pgx.ReadWrite,
		DeferrableMode: pgx.NotDeferrable,
	})
	if err != nil {
		return fmt.Errorf("could not start transaction: %w", err)
	}
	// Discards partial work on every early return. After a successful Commit
	// this is a no-op, and its error is deliberately ignored.
	defer tx.Rollback(ctx)

	selectQuery := fmt.Sprintf("SELECT blockID, block, free FROM %s WHERE blockID = $1 FOR UPDATE LIMIT 1", tableName)
	rows, err := tx.Query(ctx, selectQuery, blockId)
	if err != nil {
		return fmt.Errorf("error while select specified block from the database: %w", err)
	}
	databaseRows, err := pgx.CollectRows(rows, pgx.RowToStructByName[entity2.Block])
	if err != nil {
		return fmt.Errorf("error while getting specified block from rows: %w", err)
	}

	if len(databaseRows) == 0 {
		return fmt.Errorf("blockId %d does not exist in database", blockId)
	}

	specifiedBlock := databaseRows[0]

	// NOTE(review): index is not range-checked here. Confirm that
	// RevokeAtIndex handles out-of-range values safely.
	specifiedBlock.RevokeAtIndex(index)

	updateQuery := fmt.Sprintf("UPDATE %s SET block = $1 WHERE blockID = $2", tableName)
	if _, err := tx.Exec(ctx, updateQuery, specifiedBlock.Block, specifiedBlock.BlockId); err != nil {
		return fmt.Errorf("error updating block in the database: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("error commiting transaction: %w", err)
	}

	return nil
}

// CreateTableForTenantIdIfNotExists ensures that tenantId has its own block
// table. If the table is missing, it is created and seeded with one empty
// block of pc.blockSizeInBytes. Both steps happen in the same transaction, so
// a table never exists without its initial block.
//
// Concurrent calls for the same tenant are serialised by a transaction-scoped
// advisory lock. Only one of them creates the table; the others then see it
// as existing and do nothing.
func (pc *postgresConnection) CreateTableForTenantIdIfNotExists(ctx context.Context, tenantId string) error {
	// The table name is interpolated into SQL, so validate it before doing
	// any database work.
	tableName, err := createTableName(tenantId)
	if err != nil {
		return err
	}

	tx, err := pc.conn.BeginTx(ctx, pgx.TxOptions{
		IsoLevel:       pgx.ReadCommitted,
		AccessMode:     pgx.ReadWrite,
		DeferrableMode: pgx.NotDeferrable,
	})
	if err != nil {
		return fmt.Errorf("could not start transaction: %w", err)
	}
	// Discards partial work on every early return. After a successful Commit
	// this is a no-op, and its error is deliberately ignored.
	defer tx.Rollback(ctx)

	// Serialise creators of this particular table. The lock is keyed on a
	// hash of the table name and is released automatically at commit or
	// rollback. It replaces an EXCLUSIVE lock on information_schema.tables,
	// a view over the system catalogs that could stall unrelated DDL.
	if _, err = tx.Exec(ctx, "SELECT pg_advisory_xact_lock(hashtext($1))", tableName); err != nil {
		return fmt.Errorf("could not lock table: %w", err)
	}

	// to_regclass resolves the name exactly as the CREATE TABLE below does:
	// unquoted identifiers fold to lower case, and search_path is honoured.
	// Mixed-case tenant IDs are therefore recognised as already existing, and
	// tables with the same name in other schemas are ignored. The query
	// returns no row when the table does not exist.
	const tableExistQuery = "SELECT 1 WHERE to_regclass($1) IS NOT NULL"
	exists := true
	var n int64
	if err = tx.QueryRow(ctx, tableExistQuery, tableName).Scan(&n); err != nil {
		if !errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("error query for table name: %w", err)
		}
		exists = false
	}

	if !exists {
		createTableQuery := fmt.Sprintf("CREATE TABLE %s (blockID SERIAL PRIMARY KEY, block BYTEA, free INT)", tableName)
		if _, err = tx.Exec(ctx, createTableQuery); err != nil {
			return fmt.Errorf("could not create new table for tenantID: %w", err)
		}

		newBlock := entity2.NewBlock(pc.blockSizeInBytes)

		insertQuery := fmt.Sprintf("INSERT INTO %s (block, free) VALUES ($1, $2)", tableName)
		if _, err = tx.Exec(ctx, insertQuery, newBlock.Block, newBlock.Free); err != nil {
			return fmt.Errorf("error inserting new block into the database: %w", err)
		}
	}

	if err = tx.Commit(ctx); err != nil {
		return fmt.Errorf("error commiting transaction: %w", err)
	}

	return nil
}

// Close releases the pgx connection pool that backs this connection.
func (pc *postgresConnection) Close() {
	pool := pc.conn
	pool.Close()
}

// tableNamePattern accepts only non-empty runs of ASCII letters, digits and
// underscores. It is compiled once at package init, not on every call.
var tableNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// createTableName returns the per-tenant table name, TablePrefix + tenantId.
//
// Callers interpolate the result directly into SQL statements, because
// identifiers cannot be bind parameters. For that reason any name containing
// characters outside [a-zA-Z0-9_] is rejected with an error.
//
// The error return is kept for API compatibility. It is non-nil only when the
// name is invalid.
func createTableName(tenantId string) (string, error) {
	tableName := TablePrefix + tenantId
	if !tableNamePattern.MatchString(tableName) {
		return "", fmt.Errorf("tableName '%s' is not valid", tableName)
	}
	return tableName, nil
}