package firewall import ( "context" "errors" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "git.netcell-it.de/projekte/edgeguard-native/internal/models" ) var ErrAddressGroupNotFound = errors.New("address group not found") type AddressGroupsRepo struct { Pool *pgxpool.Pool } func NewAddressGroupsRepo(pool *pgxpool.Pool) *AddressGroupsRepo { return &AddressGroupsRepo{Pool: pool} } const addrGrpBaseSelect = ` SELECT id, name, description, created_at, updated_at FROM firewall_address_groups ` // List returns all groups with their MemberIDs populated via a // single follow-up query. Two roundtrips total — keeps the SQL // simple at the cost of one extra query for an inherently small // table. func (r *AddressGroupsRepo) List(ctx context.Context) ([]models.FirewallAddressGroup, error) { rows, err := r.Pool.Query(ctx, addrGrpBaseSelect+" ORDER BY name ASC") if err != nil { return nil, err } defer rows.Close() groups := []models.FirewallAddressGroup{} byID := map[int64]int{} for rows.Next() { g, err := scanAddrGrp(rows) if err != nil { return nil, err } byID[g.ID] = len(groups) groups = append(groups, *g) } if err := rows.Err(); err != nil { return nil, err } mRows, err := r.Pool.Query(ctx, `SELECT group_id, object_id FROM firewall_address_group_members`) if err != nil { return nil, err } defer mRows.Close() for mRows.Next() { var gid, oid int64 if err := mRows.Scan(&gid, &oid); err != nil { return nil, err } if idx, ok := byID[gid]; ok { groups[idx].MemberIDs = append(groups[idx].MemberIDs, oid) } } return groups, mRows.Err() } func (r *AddressGroupsRepo) Get(ctx context.Context, id int64) (*models.FirewallAddressGroup, error) { row := r.Pool.QueryRow(ctx, addrGrpBaseSelect+" WHERE id = $1", id) g, err := scanAddrGrp(row) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrAddressGroupNotFound } return nil, err } mRows, err := r.Pool.Query(ctx, `SELECT object_id FROM firewall_address_group_members WHERE group_id = $1 ORDER BY object_id`, id) if err != nil { return nil, err } defer mRows.Close() for mRows.Next() { var oid int64 if err := mRows.Scan(&oid); err != nil { return nil, err } g.MemberIDs = append(g.MemberIDs, oid) } return g, mRows.Err() } // Create inserts the group and (optionally) its members atomically. func (r *AddressGroupsRepo) Create(ctx context.Context, g models.FirewallAddressGroup) (*models.FirewallAddressGroup, error) { tx, err := r.Pool.Begin(ctx) if err != nil { return nil, err } defer tx.Rollback(ctx) row := tx.QueryRow(ctx, ` INSERT INTO firewall_address_groups (name, description) VALUES ($1, $2) RETURNING id, name, description, created_at, updated_at`, g.Name, g.Description) out, err := scanAddrGrp(row) if err != nil { return nil, err } if len(g.MemberIDs) > 0 { if err := insertAddrGrpMembers(ctx, tx, out.ID, g.MemberIDs); err != nil { return nil, err } out.MemberIDs = append([]int64{}, g.MemberIDs...) } return out, tx.Commit(ctx) } // Update replaces both the metadata and the membership set // atomically. Pass the desired complete member list; absent IDs are // removed. func (r *AddressGroupsRepo) Update(ctx context.Context, id int64, g models.FirewallAddressGroup) (*models.FirewallAddressGroup, error) { tx, err := r.Pool.Begin(ctx) if err != nil { return nil, err } defer tx.Rollback(ctx) row := tx.QueryRow(ctx, ` UPDATE firewall_address_groups SET name = $1, description = $2, updated_at = NOW() WHERE id = $3 RETURNING id, name, description, created_at, updated_at`, g.Name, g.Description, id) out, err := scanAddrGrp(row) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrAddressGroupNotFound } return nil, err } if _, err := tx.Exec(ctx, `DELETE FROM firewall_address_group_members WHERE group_id = $1`, id); err != nil { return nil, err } if len(g.MemberIDs) > 0 { if err := insertAddrGrpMembers(ctx, tx, id, g.MemberIDs); err != nil { return nil, err } out.MemberIDs = append([]int64{}, g.MemberIDs...) } return out, tx.Commit(ctx) } func (r *AddressGroupsRepo) Delete(ctx context.Context, id int64) error { tag, err := r.Pool.Exec(ctx, `DELETE FROM firewall_address_groups WHERE id = $1`, id) if err != nil { return err } if tag.RowsAffected() == 0 { return ErrAddressGroupNotFound } return nil } func insertAddrGrpMembers(ctx context.Context, tx pgx.Tx, gid int64, members []int64) error { for _, oid := range members { if _, err := tx.Exec(ctx, `INSERT INTO firewall_address_group_members (group_id, object_id) VALUES ($1, $2) ON CONFLICT DO NOTHING`, gid, oid); err != nil { return err } } return nil } func scanAddrGrp(row interface{ Scan(...any) error }) (*models.FirewallAddressGroup, error) { var g models.FirewallAddressGroup if err := row.Scan( &g.ID, &g.Name, &g.Description, &g.CreatedAt, &g.UpdatedAt, ); err != nil { return nil, err } return &g, nil }