diff --git a/internal/models/firewall_address_group.go b/internal/models/firewall_address_group.go new file mode 100644 index 0000000..816533b --- /dev/null +++ b/internal/models/firewall_address_group.go @@ -0,0 +1,18 @@ +package models + +import "time" + +type FirewallAddressGroup struct { + ID int64 `gorm:"primaryKey" json:"id"` + Name string `gorm:"column:name;uniqueIndex" json:"name"` + Description *string `gorm:"column:description" json:"description,omitempty"` + CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` + + // MemberIDs is filled by the repo's Get/List joiners — not a real + // column. JSON-omitted when empty so the bare-create response + // stays terse. + MemberIDs []int64 `gorm:"-" json:"member_ids,omitempty"` +} + +func (FirewallAddressGroup) TableName() string { return "firewall_address_groups" } diff --git a/internal/models/firewall_address_object.go b/internal/models/firewall_address_object.go new file mode 100644 index 0000000..637ff04 --- /dev/null +++ b/internal/models/firewall_address_object.go @@ -0,0 +1,15 @@ +package models + +import "time" + +type FirewallAddressObject struct { + ID int64 `gorm:"primaryKey" json:"id"` + Name string `gorm:"column:name;uniqueIndex" json:"name"` + Kind string `gorm:"column:kind" json:"kind"` // host|network|range|fqdn + Value string `gorm:"column:value" json:"value"` + Description *string `gorm:"column:description" json:"description,omitempty"` + CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` +} + +func (FirewallAddressObject) TableName() string { return "firewall_address_objects" } diff --git a/internal/models/firewall_nat_rule.go b/internal/models/firewall_nat_rule.go new file mode 100644 index 0000000..413f417 --- /dev/null +++ b/internal/models/firewall_nat_rule.go @@ -0,0 +1,40 @@ +package models + +import "time" + +// FirewallNATRule covers the three nft NAT shapes in one table: +// +// - kind=dnat: in_zone + match_dport_* → target_addr [+ target_port_*] +// (port-forward incoming traffic) +// - kind=snat: out_zone + match_src_cidr → target_addr +// (rewrite source IP to a fixed address) +// - kind=masquerade: out_zone [+ match_src_cidr] +// (rewrite source to out-iface IP — typical lan→wan) +// +// Validation of kind-specific field combinations lives in the +// handler. +type FirewallNATRule struct { + ID int64 `gorm:"primaryKey" json:"id"` + Name *string `gorm:"column:name" json:"name,omitempty"` + Priority int `gorm:"column:priority" json:"priority"` + Enabled bool `gorm:"column:enabled" json:"enabled"` + Kind string `gorm:"column:kind" json:"kind"` // dnat|snat|masquerade + + InZone *string `gorm:"column:in_zone" json:"in_zone,omitempty"` + OutZone *string `gorm:"column:out_zone" json:"out_zone,omitempty"` + Proto *string `gorm:"column:proto" json:"proto,omitempty"` + MatchSrcCIDR *string `gorm:"column:match_src_cidr" json:"match_src_cidr,omitempty"` + MatchDstCIDR *string `gorm:"column:match_dst_cidr" json:"match_dst_cidr,omitempty"` + MatchDPortStart *int `gorm:"column:match_dport_start" json:"match_dport_start,omitempty"` + MatchDPortEnd *int `gorm:"column:match_dport_end" json:"match_dport_end,omitempty"` + + TargetAddr *string `gorm:"column:target_addr" json:"target_addr,omitempty"` + TargetPortStart *int `gorm:"column:target_port_start" json:"target_port_start,omitempty"` + TargetPortEnd *int `gorm:"column:target_port_end" json:"target_port_end,omitempty"` + + Comment *string `gorm:"column:comment" json:"comment,omitempty"` + CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` +} + +func (FirewallNATRule) TableName() string { return "firewall_nat_rules" } diff --git a/internal/models/firewall_rule.go b/internal/models/firewall_rule.go new file mode 100644 index 0000000..74154f3 --- /dev/null +++ b/internal/models/firewall_rule.go @@ -0,0 +1,43 @@ +package models + +import "time" + +// FirewallRule is the v2 (Fortigate-style) policy row. Source and +// destination each carry exactly one of: +// - AddressObjectID → primitive address object +// - AddressGroupID → address group +// - CIDR → inline CIDR +// - all three nil → "any" +// +// Same rule applies to ServiceObjectID / ServiceGroupID — exactly one +// or both nil for "any service". +// +// Validation lives in the handler layer (DB doesn't enforce +// "exactly one" because expressive CHECK constraints get unwieldy). +type FirewallRule struct { + ID int64 `gorm:"primaryKey" json:"id"` + Name *string `gorm:"column:name" json:"name,omitempty"` + Priority int `gorm:"column:priority" json:"priority"` + Enabled bool `gorm:"column:enabled" json:"enabled"` + Action string `gorm:"column:action" json:"action"` // accept|drop|reject + + SrcZone string `gorm:"column:src_zone" json:"src_zone"` + SrcAddressObjectID *int64 `gorm:"column:src_address_object_id" json:"src_address_object_id,omitempty"` + SrcAddressGroupID *int64 `gorm:"column:src_address_group_id" json:"src_address_group_id,omitempty"` + SrcCIDR *string `gorm:"column:src_cidr" json:"src_cidr,omitempty"` + + DstZone string `gorm:"column:dst_zone" json:"dst_zone"` + DstAddressObjectID *int64 `gorm:"column:dst_address_object_id" json:"dst_address_object_id,omitempty"` + DstAddressGroupID *int64 `gorm:"column:dst_address_group_id" json:"dst_address_group_id,omitempty"` + DstCIDR *string `gorm:"column:dst_cidr" json:"dst_cidr,omitempty"` + + ServiceObjectID *int64 `gorm:"column:service_object_id" json:"service_object_id,omitempty"` + ServiceGroupID *int64 `gorm:"column:service_group_id" json:"service_group_id,omitempty"` + + Log bool `gorm:"column:log" json:"log"` + Comment *string `gorm:"column:comment" json:"comment,omitempty"` + CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` +} + +func (FirewallRule) TableName() string { return "firewall_rules" } diff --git a/internal/models/firewall_service.go b/internal/models/firewall_service.go new file mode 100644 index 0000000..38b9217 --- /dev/null +++ b/internal/models/firewall_service.go @@ -0,0 +1,17 @@ +package models + +import "time" + +type FirewallService struct { + ID int64 `gorm:"primaryKey" json:"id"` + Name string `gorm:"column:name;uniqueIndex" json:"name"` + Proto string `gorm:"column:proto" json:"proto"` // tcp|udp|icmp|icmpv6|any + PortStart *int `gorm:"column:port_start" json:"port_start,omitempty"` + PortEnd *int `gorm:"column:port_end" json:"port_end,omitempty"` + Builtin bool `gorm:"column:builtin" json:"builtin"` + Description *string `gorm:"column:description" json:"description,omitempty"` + CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` +} + +func (FirewallService) TableName() string { return "firewall_services" } diff --git a/internal/models/firewall_service_group.go b/internal/models/firewall_service_group.go new file mode 100644 index 0000000..8f66ef7 --- /dev/null +++ b/internal/models/firewall_service_group.go @@ -0,0 +1,15 @@ +package models + +import "time" + +type FirewallServiceGroup struct { + ID int64 `gorm:"primaryKey" json:"id"` + Name string `gorm:"column:name;uniqueIndex" json:"name"` + Description *string `gorm:"column:description" json:"description,omitempty"` + CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` + + MemberIDs []int64 `gorm:"-" json:"member_ids,omitempty"` +} + +func (FirewallServiceGroup) TableName() string { return "firewall_service_groups" } diff --git a/internal/services/firewall/addressgroups.go b/internal/services/firewall/addressgroups.go new file mode 100644 index 0000000..0c81ba9 --- /dev/null +++ b/internal/services/firewall/addressgroups.go @@ -0,0 +1,185 @@ +package firewall + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "git.netcell-it.de/projekte/edgeguard-native/internal/models" +) + +var ErrAddressGroupNotFound = errors.New("address group not found") + +type AddressGroupsRepo struct { + Pool *pgxpool.Pool +} + +func NewAddressGroupsRepo(pool *pgxpool.Pool) *AddressGroupsRepo { + return &AddressGroupsRepo{Pool: pool} +} + +const addrGrpBaseSelect = ` +SELECT id, name, description, created_at, updated_at +FROM firewall_address_groups +` + +// List returns all groups with their MemberIDs populated via a +// single follow-up query. Two roundtrips total — keeps the SQL +// simple at the cost of one extra query for an inherently small +// table. +func (r *AddressGroupsRepo) List(ctx context.Context) ([]models.FirewallAddressGroup, error) { + rows, err := r.Pool.Query(ctx, addrGrpBaseSelect+" ORDER BY name ASC") + if err != nil { + return nil, err + } + defer rows.Close() + groups := []models.FirewallAddressGroup{} + byID := map[int64]int{} + for rows.Next() { + g, err := scanAddrGrp(rows) + if err != nil { + return nil, err + } + byID[g.ID] = len(groups) + groups = append(groups, *g) + } + if err := rows.Err(); err != nil { + return nil, err + } + + mRows, err := r.Pool.Query(ctx, `SELECT group_id, object_id FROM firewall_address_group_members`) + if err != nil { + return nil, err + } + defer mRows.Close() + for mRows.Next() { + var gid, oid int64 + if err := mRows.Scan(&gid, &oid); err != nil { + return nil, err + } + if idx, ok := byID[gid]; ok { + groups[idx].MemberIDs = append(groups[idx].MemberIDs, oid) + } + } + return groups, mRows.Err() +} + +func (r *AddressGroupsRepo) Get(ctx context.Context, id int64) (*models.FirewallAddressGroup, error) { + row := r.Pool.QueryRow(ctx, addrGrpBaseSelect+" WHERE id = $1", id) + g, err := scanAddrGrp(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrAddressGroupNotFound + } + return nil, err + } + mRows, err := r.Pool.Query(ctx, + `SELECT object_id FROM firewall_address_group_members WHERE group_id = $1 ORDER BY object_id`, id) + if err != nil { + return nil, err + } + defer mRows.Close() + for mRows.Next() { + var oid int64 + if err := mRows.Scan(&oid); err != nil { + return nil, err + } + g.MemberIDs = append(g.MemberIDs, oid) + } + return g, mRows.Err() +} + +// Create inserts the group and (optionally) its members atomically. +func (r *AddressGroupsRepo) Create(ctx context.Context, g models.FirewallAddressGroup) (*models.FirewallAddressGroup, error) { + tx, err := r.Pool.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback(ctx) + + row := tx.QueryRow(ctx, ` +INSERT INTO firewall_address_groups (name, description) +VALUES ($1, $2) +RETURNING id, name, description, created_at, updated_at`, + g.Name, g.Description) + out, err := scanAddrGrp(row) + if err != nil { + return nil, err + } + if len(g.MemberIDs) > 0 { + if err := insertAddrGrpMembers(ctx, tx, out.ID, g.MemberIDs); err != nil { + return nil, err + } + out.MemberIDs = append([]int64{}, g.MemberIDs...) + } + return out, tx.Commit(ctx) +} + +// Update replaces both the metadata and the membership set +// atomically. Pass the desired complete member list; absent IDs are +// removed. +func (r *AddressGroupsRepo) Update(ctx context.Context, id int64, g models.FirewallAddressGroup) (*models.FirewallAddressGroup, error) { + tx, err := r.Pool.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback(ctx) + + row := tx.QueryRow(ctx, ` +UPDATE firewall_address_groups SET name = $1, description = $2, updated_at = NOW() +WHERE id = $3 +RETURNING id, name, description, created_at, updated_at`, + g.Name, g.Description, id) + out, err := scanAddrGrp(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrAddressGroupNotFound + } + return nil, err + } + if _, err := tx.Exec(ctx, `DELETE FROM firewall_address_group_members WHERE group_id = $1`, id); err != nil { + return nil, err + } + if len(g.MemberIDs) > 0 { + if err := insertAddrGrpMembers(ctx, tx, id, g.MemberIDs); err != nil { + return nil, err + } + out.MemberIDs = append([]int64{}, g.MemberIDs...) + } + return out, tx.Commit(ctx) +} + +func (r *AddressGroupsRepo) Delete(ctx context.Context, id int64) error { + tag, err := r.Pool.Exec(ctx, `DELETE FROM firewall_address_groups WHERE id = $1`, id) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return ErrAddressGroupNotFound + } + return nil +} + +func insertAddrGrpMembers(ctx context.Context, tx pgx.Tx, gid int64, members []int64) error { + for _, oid := range members { + if _, err := tx.Exec(ctx, + `INSERT INTO firewall_address_group_members (group_id, object_id) VALUES ($1, $2) ON CONFLICT DO NOTHING`, + gid, oid); err != nil { + return err + } + } + return nil +} + +func scanAddrGrp(row interface{ Scan(...any) error }) (*models.FirewallAddressGroup, error) { + var g models.FirewallAddressGroup + if err := row.Scan( + &g.ID, &g.Name, &g.Description, + &g.CreatedAt, &g.UpdatedAt, + ); err != nil { + return nil, err + } + return &g, nil +} diff --git a/internal/services/firewall/addressobjects.go b/internal/services/firewall/addressobjects.go new file mode 100644 index 0000000..72459a9 --- /dev/null +++ b/internal/services/firewall/addressobjects.go @@ -0,0 +1,103 @@ +package firewall + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "git.netcell-it.de/projekte/edgeguard-native/internal/models" +) + +var ErrAddressObjectNotFound = errors.New("address object not found") + +type AddressObjectsRepo struct { + Pool *pgxpool.Pool +} + +func NewAddressObjectsRepo(pool *pgxpool.Pool) *AddressObjectsRepo { + return &AddressObjectsRepo{Pool: pool} +} + +const addrObjBaseSelect = ` +SELECT id, name, kind, value, description, created_at, updated_at +FROM firewall_address_objects +` + +func (r *AddressObjectsRepo) List(ctx context.Context) ([]models.FirewallAddressObject, error) { + rows, err := r.Pool.Query(ctx, addrObjBaseSelect+" ORDER BY name ASC") + if err != nil { + return nil, err + } + defer rows.Close() + out := make([]models.FirewallAddressObject, 0, 8) + for rows.Next() { + o, err := scanAddrObj(rows) + if err != nil { + return nil, err + } + out = append(out, *o) + } + return out, rows.Err() +} + +func (r *AddressObjectsRepo) Get(ctx context.Context, id int64) (*models.FirewallAddressObject, error) { + row := r.Pool.QueryRow(ctx, addrObjBaseSelect+" WHERE id = $1", id) + o, err := scanAddrObj(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrAddressObjectNotFound + } + return nil, err + } + return o, nil +} + +func (r *AddressObjectsRepo) Create(ctx context.Context, o models.FirewallAddressObject) (*models.FirewallAddressObject, error) { + row := r.Pool.QueryRow(ctx, ` +INSERT INTO firewall_address_objects (name, kind, value, description) +VALUES ($1, $2, $3, $4) +RETURNING id, name, kind, value, description, created_at, updated_at`, + o.Name, o.Kind, o.Value, o.Description) + return scanAddrObj(row) +} + +func (r *AddressObjectsRepo) Update(ctx context.Context, id int64, o models.FirewallAddressObject) (*models.FirewallAddressObject, error) { + row := r.Pool.QueryRow(ctx, ` +UPDATE firewall_address_objects SET + name = $1, kind = $2, value = $3, description = $4, updated_at = NOW() +WHERE id = $5 +RETURNING id, name, kind, value, description, created_at, updated_at`, + o.Name, o.Kind, o.Value, o.Description, id) + out, err := scanAddrObj(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrAddressObjectNotFound + } + return nil, err + } + return out, nil +} + +func (r *AddressObjectsRepo) Delete(ctx context.Context, id int64) error { + tag, err := r.Pool.Exec(ctx, `DELETE FROM firewall_address_objects WHERE id = $1`, id) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return ErrAddressObjectNotFound + } + return nil +} + +func scanAddrObj(row interface{ Scan(...any) error }) (*models.FirewallAddressObject, error) { + var o models.FirewallAddressObject + if err := row.Scan( + &o.ID, &o.Name, &o.Kind, &o.Value, + &o.Description, &o.CreatedAt, &o.UpdatedAt, + ); err != nil { + return nil, err + } + return &o, nil +} diff --git a/internal/services/firewall/doc.go b/internal/services/firewall/doc.go new file mode 100644 index 0000000..aa784a2 --- /dev/null +++ b/internal/services/firewall/doc.go @@ -0,0 +1,7 @@ +// Package firewall holds the v2 (Fortigate-style) firewall data +// repos: address objects + groups, services + groups, policy rules, +// and NAT rules. Each entity has its own *.go file; the public +// surface is one Repo per entity, all sharing the same *pgxpool.Pool. +// +// Render-Logik (Joins zu nftables) wohnt in internal/firewall/. +package firewall diff --git a/internal/services/firewall/natrules.go b/internal/services/firewall/natrules.go new file mode 100644 index 0000000..0fdac76 --- /dev/null +++ b/internal/services/firewall/natrules.go @@ -0,0 +1,139 @@ +package firewall + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "git.netcell-it.de/projekte/edgeguard-native/internal/models" +) + +var ErrNATRuleNotFound = errors.New("nat rule not found") + +type NATRulesRepo struct { + Pool *pgxpool.Pool +} + +func NewNATRulesRepo(pool *pgxpool.Pool) *NATRulesRepo { return &NATRulesRepo{Pool: pool} } + +const natRuleBaseSelect = ` +SELECT id, name, priority, enabled, kind, + in_zone, out_zone, proto, + match_src_cidr, match_dst_cidr, match_dport_start, match_dport_end, + target_addr, target_port_start, target_port_end, + comment, created_at, updated_at +FROM firewall_nat_rules +` + +func (r *NATRulesRepo) List(ctx context.Context) ([]models.FirewallNATRule, error) { + rows, err := r.Pool.Query(ctx, natRuleBaseSelect+" ORDER BY priority DESC, id ASC") + if err != nil { + return nil, err + } + defer rows.Close() + out := make([]models.FirewallNATRule, 0, 8) + for rows.Next() { + x, err := scanNATRule(rows) + if err != nil { + return nil, err + } + out = append(out, *x) + } + return out, rows.Err() +} + +func (r *NATRulesRepo) Get(ctx context.Context, id int64) (*models.FirewallNATRule, error) { + row := r.Pool.QueryRow(ctx, natRuleBaseSelect+" WHERE id = $1", id) + x, err := scanNATRule(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNATRuleNotFound + } + return nil, err + } + return x, nil +} + +func (r *NATRulesRepo) Create(ctx context.Context, x models.FirewallNATRule) (*models.FirewallNATRule, error) { + row := r.Pool.QueryRow(ctx, ` +INSERT INTO firewall_nat_rules ( + name, priority, enabled, kind, + in_zone, out_zone, proto, + match_src_cidr, match_dst_cidr, match_dport_start, match_dport_end, + target_addr, target_port_start, target_port_end, + comment +) VALUES ( + $1, $2, $3, $4, + $5, $6, $7, + $8, $9, $10, $11, + $12, $13, $14, + $15 +) +RETURNING id, name, priority, enabled, kind, + in_zone, out_zone, proto, + match_src_cidr, match_dst_cidr, match_dport_start, match_dport_end, + target_addr, target_port_start, target_port_end, + comment, created_at, updated_at`, + x.Name, x.Priority, x.Enabled, x.Kind, + x.InZone, x.OutZone, x.Proto, + x.MatchSrcCIDR, x.MatchDstCIDR, x.MatchDPortStart, x.MatchDPortEnd, + x.TargetAddr, x.TargetPortStart, x.TargetPortEnd, + x.Comment) + return scanNATRule(row) +} + +func (r *NATRulesRepo) Update(ctx context.Context, id int64, x models.FirewallNATRule) (*models.FirewallNATRule, error) { + row := r.Pool.QueryRow(ctx, ` +UPDATE firewall_nat_rules SET + name = $1, priority = $2, enabled = $3, kind = $4, + in_zone = $5, out_zone = $6, proto = $7, + match_src_cidr = $8, match_dst_cidr = $9, match_dport_start = $10, match_dport_end = $11, + target_addr = $12, target_port_start = $13, target_port_end = $14, + comment = $15, updated_at = NOW() +WHERE id = $16 +RETURNING id, name, priority, enabled, kind, + in_zone, out_zone, proto, + match_src_cidr, match_dst_cidr, match_dport_start, match_dport_end, + target_addr, target_port_start, target_port_end, + comment, created_at, updated_at`, + x.Name, x.Priority, x.Enabled, x.Kind, + x.InZone, x.OutZone, x.Proto, + x.MatchSrcCIDR, x.MatchDstCIDR, x.MatchDPortStart, x.MatchDPortEnd, + x.TargetAddr, x.TargetPortStart, x.TargetPortEnd, + x.Comment, id) + out, err := scanNATRule(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNATRuleNotFound + } + return nil, err + } + return out, nil +} + +func (r *NATRulesRepo) Delete(ctx context.Context, id int64) error { + tag, err := r.Pool.Exec(ctx, `DELETE FROM firewall_nat_rules WHERE id = $1`, id) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return ErrNATRuleNotFound + } + return nil +} + +func scanNATRule(row interface{ Scan(...any) error }) (*models.FirewallNATRule, error) { + var x models.FirewallNATRule + if err := row.Scan( + &x.ID, &x.Name, &x.Priority, &x.Enabled, &x.Kind, + &x.InZone, &x.OutZone, &x.Proto, + &x.MatchSrcCIDR, &x.MatchDstCIDR, &x.MatchDPortStart, &x.MatchDPortEnd, + &x.TargetAddr, &x.TargetPortStart, &x.TargetPortEnd, + &x.Comment, &x.CreatedAt, &x.UpdatedAt, + ); err != nil { + return nil, err + } + return &x, nil +} diff --git a/internal/services/firewall/rules.go b/internal/services/firewall/rules.go new file mode 100644 index 0000000..fddafd3 --- /dev/null +++ b/internal/services/firewall/rules.go @@ -0,0 +1,139 @@ +package firewall + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "git.netcell-it.de/projekte/edgeguard-native/internal/models" +) + +var ErrRuleNotFound = errors.New("firewall rule not found") + +type RulesRepo struct { + Pool *pgxpool.Pool +} + +func NewRulesRepo(pool *pgxpool.Pool) *RulesRepo { return &RulesRepo{Pool: pool} } + +const ruleBaseSelect = ` +SELECT id, name, priority, enabled, action, + src_zone, src_address_object_id, src_address_group_id, src_cidr, + dst_zone, dst_address_object_id, dst_address_group_id, dst_cidr, + service_object_id, service_group_id, + log, comment, created_at, updated_at +FROM firewall_rules +` + +func (r *RulesRepo) List(ctx context.Context) ([]models.FirewallRule, error) { + rows, err := r.Pool.Query(ctx, ruleBaseSelect+" ORDER BY priority DESC, id ASC") + if err != nil { + return nil, err + } + defer rows.Close() + out := make([]models.FirewallRule, 0, 16) + for rows.Next() { + x, err := scanRule(rows) + if err != nil { + return nil, err + } + out = append(out, *x) + } + return out, rows.Err() +} + +func (r *RulesRepo) Get(ctx context.Context, id int64) (*models.FirewallRule, error) { + row := r.Pool.QueryRow(ctx, ruleBaseSelect+" WHERE id = $1", id) + x, err := scanRule(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrRuleNotFound + } + return nil, err + } + return x, nil +} + +func (r *RulesRepo) Create(ctx context.Context, x models.FirewallRule) (*models.FirewallRule, error) { + row := r.Pool.QueryRow(ctx, ` +INSERT INTO firewall_rules ( + name, priority, enabled, action, + src_zone, src_address_object_id, src_address_group_id, src_cidr, + dst_zone, dst_address_object_id, dst_address_group_id, dst_cidr, + service_object_id, service_group_id, + log, comment +) VALUES ( + $1, $2, $3, $4, + $5, $6, $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16 +) +RETURNING id, name, priority, enabled, action, + src_zone, src_address_object_id, src_address_group_id, src_cidr, + dst_zone, dst_address_object_id, dst_address_group_id, dst_cidr, + service_object_id, service_group_id, + log, comment, created_at, updated_at`, + x.Name, x.Priority, x.Enabled, x.Action, + x.SrcZone, x.SrcAddressObjectID, x.SrcAddressGroupID, x.SrcCIDR, + x.DstZone, x.DstAddressObjectID, x.DstAddressGroupID, x.DstCIDR, + x.ServiceObjectID, x.ServiceGroupID, + x.Log, x.Comment) + return scanRule(row) +} + +func (r *RulesRepo) Update(ctx context.Context, id int64, x models.FirewallRule) (*models.FirewallRule, error) { + row := r.Pool.QueryRow(ctx, ` +UPDATE firewall_rules SET + name = $1, priority = $2, enabled = $3, action = $4, + src_zone = $5, src_address_object_id = $6, src_address_group_id = $7, src_cidr = $8, + dst_zone = $9, dst_address_object_id = $10, dst_address_group_id = $11, dst_cidr = $12, + service_object_id = $13, service_group_id = $14, + log = $15, comment = $16, updated_at = NOW() +WHERE id = $17 +RETURNING id, name, priority, enabled, action, + src_zone, src_address_object_id, src_address_group_id, src_cidr, + dst_zone, dst_address_object_id, dst_address_group_id, dst_cidr, + service_object_id, service_group_id, + log, comment, created_at, updated_at`, + x.Name, x.Priority, x.Enabled, x.Action, + x.SrcZone, x.SrcAddressObjectID, x.SrcAddressGroupID, x.SrcCIDR, + x.DstZone, x.DstAddressObjectID, x.DstAddressGroupID, x.DstCIDR, + x.ServiceObjectID, x.ServiceGroupID, + x.Log, x.Comment, id) + out, err := scanRule(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrRuleNotFound + } + return nil, err + } + return out, nil +} + +func (r *RulesRepo) Delete(ctx context.Context, id int64) error { + tag, err := r.Pool.Exec(ctx, `DELETE FROM firewall_rules WHERE id = $1`, id) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return ErrRuleNotFound + } + return nil +} + +func scanRule(row interface{ Scan(...any) error }) (*models.FirewallRule, error) { + var x models.FirewallRule + if err := row.Scan( + &x.ID, &x.Name, &x.Priority, &x.Enabled, &x.Action, + &x.SrcZone, &x.SrcAddressObjectID, &x.SrcAddressGroupID, &x.SrcCIDR, + &x.DstZone, &x.DstAddressObjectID, &x.DstAddressGroupID, &x.DstCIDR, + &x.ServiceObjectID, &x.ServiceGroupID, + &x.Log, &x.Comment, &x.CreatedAt, &x.UpdatedAt, + ); err != nil { + return nil, err + } + return &x, nil +} diff --git a/internal/services/firewall/servicegroups.go b/internal/services/firewall/servicegroups.go new file mode 100644 index 0000000..75c0dbb --- /dev/null +++ b/internal/services/firewall/servicegroups.go @@ -0,0 +1,175 @@ +package firewall + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "git.netcell-it.de/projekte/edgeguard-native/internal/models" +) + +var ErrServiceGroupNotFound = errors.New("service group not found") + +type ServiceGroupsRepo struct { + Pool *pgxpool.Pool +} + +func NewServiceGroupsRepo(pool *pgxpool.Pool) *ServiceGroupsRepo { + return &ServiceGroupsRepo{Pool: pool} +} + +const svcGrpBaseSelect = ` +SELECT id, name, description, created_at, updated_at +FROM firewall_service_groups +` + +func (r *ServiceGroupsRepo) List(ctx context.Context) ([]models.FirewallServiceGroup, error) { + rows, err := r.Pool.Query(ctx, svcGrpBaseSelect+" ORDER BY name ASC") + if err != nil { + return nil, err + } + defer rows.Close() + groups := []models.FirewallServiceGroup{} + byID := map[int64]int{} + for rows.Next() { + g, err := scanSvcGrp(rows) + if err != nil { + return nil, err + } + byID[g.ID] = len(groups) + groups = append(groups, *g) + } + if err := rows.Err(); err != nil { + return nil, err + } + mRows, err := r.Pool.Query(ctx, `SELECT group_id, service_id FROM firewall_service_group_members`) + if err != nil { + return nil, err + } + defer mRows.Close() + for mRows.Next() { + var gid, sid int64 + if err := mRows.Scan(&gid, &sid); err != nil { + return nil, err + } + if idx, ok := byID[gid]; ok { + groups[idx].MemberIDs = append(groups[idx].MemberIDs, sid) + } + } + return groups, mRows.Err() +} + +func (r *ServiceGroupsRepo) Get(ctx context.Context, id int64) (*models.FirewallServiceGroup, error) { + row := r.Pool.QueryRow(ctx, svcGrpBaseSelect+" WHERE id = $1", id) + g, err := scanSvcGrp(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrServiceGroupNotFound + } + return nil, err + } + mRows, err := r.Pool.Query(ctx, + `SELECT service_id FROM firewall_service_group_members WHERE group_id = $1 ORDER BY service_id`, id) + if err != nil { + return nil, err + } + defer mRows.Close() + for mRows.Next() { + var sid int64 + if err := mRows.Scan(&sid); err != nil { + return nil, err + } + g.MemberIDs = append(g.MemberIDs, sid) + } + return g, mRows.Err() +} + +func (r *ServiceGroupsRepo) Create(ctx context.Context, g models.FirewallServiceGroup) (*models.FirewallServiceGroup, error) { + tx, err := r.Pool.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback(ctx) + + row := tx.QueryRow(ctx, ` +INSERT INTO firewall_service_groups (name, description) VALUES ($1, $2) +RETURNING id, name, description, created_at, updated_at`, + g.Name, g.Description) + out, err := scanSvcGrp(row) + if err != nil { + return nil, err + } + if len(g.MemberIDs) > 0 { + if err := insertSvcGrpMembers(ctx, tx, out.ID, g.MemberIDs); err != nil { + return nil, err + } + out.MemberIDs = append([]int64{}, g.MemberIDs...) + } + return out, tx.Commit(ctx) +} + +func (r *ServiceGroupsRepo) Update(ctx context.Context, id int64, g models.FirewallServiceGroup) (*models.FirewallServiceGroup, error) { + tx, err := r.Pool.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback(ctx) + + row := tx.QueryRow(ctx, ` +UPDATE firewall_service_groups SET name = $1, description = $2, updated_at = NOW() +WHERE id = $3 +RETURNING id, name, description, created_at, updated_at`, + g.Name, g.Description, id) + out, err := scanSvcGrp(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrServiceGroupNotFound + } + return nil, err + } + if _, err := tx.Exec(ctx, `DELETE FROM firewall_service_group_members WHERE group_id = $1`, id); err != nil { + return nil, err + } + if len(g.MemberIDs) > 0 { + if err := insertSvcGrpMembers(ctx, tx, id, g.MemberIDs); err != nil { + return nil, err + } + out.MemberIDs = append([]int64{}, g.MemberIDs...) + } + return out, tx.Commit(ctx) +} + +func (r *ServiceGroupsRepo) Delete(ctx context.Context, id int64) error { + tag, err := r.Pool.Exec(ctx, `DELETE FROM firewall_service_groups WHERE id = $1`, id) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return ErrServiceGroupNotFound + } + return nil +} + +func insertSvcGrpMembers(ctx context.Context, tx pgx.Tx, gid int64, members []int64) error { + for _, sid := range members { + if _, err := tx.Exec(ctx, + `INSERT INTO firewall_service_group_members (group_id, service_id) VALUES ($1, $2) ON CONFLICT DO NOTHING`, + gid, sid); err != nil { + return err + } + } + return nil +} + +func scanSvcGrp(row interface{ Scan(...any) error }) (*models.FirewallServiceGroup, error) { + var g models.FirewallServiceGroup + if err := row.Scan( + &g.ID, &g.Name, &g.Description, + &g.CreatedAt, &g.UpdatedAt, + ); err != nil { + return nil, err + } + return &g, nil +} diff --git a/internal/services/firewall/services.go b/internal/services/firewall/services.go new file mode 100644 index 0000000..52d58be --- /dev/null +++ b/internal/services/firewall/services.go @@ -0,0 +1,113 @@ +package firewall + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "git.netcell-it.de/projekte/edgeguard-native/internal/models" +) + +var ErrServiceNotFound = errors.New("service not found") + +type ServicesRepo struct { + Pool *pgxpool.Pool +} + +func NewServicesRepo(pool *pgxpool.Pool) *ServicesRepo { return &ServicesRepo{Pool: pool} } + +const svcBaseSelect = ` +SELECT id, name, proto, port_start, port_end, builtin, description, + created_at, updated_at +FROM firewall_services +` + +func (r *ServicesRepo) List(ctx context.Context) ([]models.FirewallService, error) { + rows, err := r.Pool.Query(ctx, svcBaseSelect+" ORDER BY name ASC") + if err != nil { + return nil, err + } + defer rows.Close() + out := make([]models.FirewallService, 0, 16) + for rows.Next() { + s, err := scanService(rows) + if err != nil { + return nil, err + } + out = append(out, *s) + } + return out, rows.Err() +} + +func (r *ServicesRepo) Get(ctx context.Context, id int64) (*models.FirewallService, error) { + row := r.Pool.QueryRow(ctx, svcBaseSelect+" WHERE id = $1", id) + s, err := scanService(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrServiceNotFound + } + return nil, err + } + return s, nil +} + +func (r *ServicesRepo) Create(ctx context.Context, s models.FirewallService) (*models.FirewallService, error) { + row := r.Pool.QueryRow(ctx, ` +INSERT INTO firewall_services (name, proto, port_start, port_end, builtin, description) +VALUES ($1, $2, $3, $4, FALSE, $5) +RETURNING id, name, proto, port_start, port_end, builtin, description, created_at, updated_at`, + s.Name, s.Proto, s.PortStart, s.PortEnd, s.Description) + return scanService(row) +} + +func (r *ServicesRepo) Update(ctx context.Context, id int64, s models.FirewallService) (*models.FirewallService, error) { + // Forbid editing builtin services — they're guaranteed by the + // migration set; users that want a tweak can clone with a new name. + cur, err := r.Get(ctx, id) + if err != nil { + return nil, err + } + if cur.Builtin { + return nil, errors.New("builtin service cannot be edited — clone it under a new name") + } + row := r.Pool.QueryRow(ctx, ` +UPDATE firewall_services SET + name = $1, proto = $2, port_start = $3, port_end = $4, + description = $5, updated_at = NOW() +WHERE id = $6 +RETURNING id, name, proto, port_start, port_end, builtin, description, created_at, updated_at`, + s.Name, s.Proto, s.PortStart, s.PortEnd, s.Description, id) + return scanService(row) +} + +func (r *ServicesRepo) Delete(ctx context.Context, id int64) error { + cur, err := r.Get(ctx, id) + if err != nil { + return err + } + if cur.Builtin { + return errors.New("builtin service cannot be deleted") + } + tag, err := r.Pool.Exec(ctx, `DELETE FROM firewall_services WHERE id = $1`, id) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return ErrServiceNotFound + } + return nil +} + +func scanService(row interface{ Scan(...any) error }) (*models.FirewallService, error) { + var s models.FirewallService + if err := row.Scan( + &s.ID, &s.Name, &s.Proto, &s.PortStart, &s.PortEnd, + &s.Builtin, &s.Description, + &s.CreatedAt, &s.UpdatedAt, + ); err != nil { + return nil, err + } + return &s, nil +}