Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,50 @@ type Rule struct {
UIDRange *RuleUIDRange
Protocol uint8
Type uint8
L3mdev uint8
}

func (r Rule) Equal(x Rule) bool {
return r.Table == x.Table &&
((r.Src == nil && x.Src == nil) ||
(r.Src != nil && x.Src != nil && r.Src.String() == x.Src.String())) &&
((r.Dst == nil && x.Dst == nil) ||
(r.Dst != nil && x.Dst != nil && r.Dst.String() == x.Dst.String())) &&
r.OifName == x.OifName &&
r.Priority == x.Priority &&
r.Family == x.Family &&
r.IifName == x.IifName &&
r.Invert == x.Invert &&
r.Tos == x.Tos &&
(r.Type == x.Type ||
(r.Type == 0 && x.Type == 1 || r.Type == 1 && x.Type == 0)) && // 1 is unix.RTN_UNICAST
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not so sure about this one. Without it, tests that create a rule, then read it back and compare with the struct used for creating the rule, need to be updated because currently Type is left unset in the struct and reads back as UNICAST. Maybe that's better? It seems though that when deleting a rule again, we want to have the Type cleared in the struct. Saw some test failures when keeping Type set to unicast.

r.IPProto == x.IPProto &&
r.Protocol == x.Protocol &&
r.Mark == x.Mark &&
// For non-zero marks, mask defaults to 0xFFFFFFFF if not set. So if either mask is nil
// while the other is 0xFFFFFFFF when mark is non-zero, treat the masks as identical.
// See kernel source: https://github.com/torvalds/linux/blob/v6.15/net/core/fib_rules.c#L624
(ptrEqual(r.Mask, x.Mask) || (r.Mark != 0 &&
(r.Mask == nil && *x.Mask == 0xFFFFFFFF || x.Mask == nil && *r.Mask == 0xFFFFFFFF))) &&
r.TunID == x.TunID &&
r.Goto == x.Goto &&
r.Flow == x.Flow &&
r.SuppressIfgroup == x.SuppressIfgroup &&
r.SuppressPrefixlen == x.SuppressPrefixlen &&
(r.Dport == x.Dport || (r.Dport != nil && x.Dport != nil && r.Dport.Equal(*x.Dport))) &&
(r.Sport == x.Sport || (r.Sport != nil && x.Sport != nil && r.Sport.Equal(*x.Sport))) &&
(r.UIDRange == x.UIDRange || (r.UIDRange != nil && x.UIDRange != nil && r.UIDRange.Equal(*x.UIDRange))) &&
r.L3mdev == x.L3mdev
}

func ptrEqual(a, b *uint32) bool {
if a == b {
return true
}
if (a == nil) || (b == nil) {
return false
}
return *a == *b
}

func (r Rule) String() string {
Expand Down Expand Up @@ -70,6 +114,10 @@ type RulePortRange struct {
End uint16
}

func (r RulePortRange) Equal(x RulePortRange) bool {
return r.Start == x.Start && r.End == x.End
}

// NewRuleUIDRange creates rule uid range.
func NewRuleUIDRange(start, end uint32) *RuleUIDRange {
return &RuleUIDRange{Start: start, End: end}
Expand All @@ -80,3 +128,7 @@ type RuleUIDRange struct {
Start uint32
End uint32
}

func (r RuleUIDRange) Equal(x RuleUIDRange) bool {
return r.Start == x.Start && r.End == x.End
}
17 changes: 7 additions & 10 deletions rule_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
req.AddData(nl.NewRtAttr(nl.FRA_PROTOCOL, nl.Uint8Attr(rule.Protocol)))
}

if rule.L3mdev > 0 {
req.AddData(nl.NewRtAttr(nl.FRA_L3MDEV, nl.Uint8Attr(rule.L3mdev)))
}

_, err := req.Execute(unix.NETLINK_ROUTE, 0)
return err
}
Expand Down Expand Up @@ -239,6 +243,7 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) (
rule.Invert = msg.Flags&FibRuleInvert > 0
rule.Family = int(msg.Family)
rule.Tos = uint(msg.Tos)
rule.Type = msg.Type

for j := range attrs {
switch attrs[j].Attr.Type {
Expand Down Expand Up @@ -291,6 +296,8 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) (
rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8]))
case nl.FRA_PROTOCOL:
rule.Protocol = uint8(attrs[j].Value[0])
case nl.FRA_L3MDEV:
rule.L3mdev = uint8(attrs[j].Value[0])
}
}

Expand Down Expand Up @@ -336,16 +343,6 @@ func (pr *RuleUIDRange) toRtAttrData() []byte {
return bytes.Join(b, []byte{})
}

func ptrEqual(a, b *uint32) bool {
if a == b {
return true
}
if (a == nil) || (b == nil) {
return false
}
return *a == *b
}

func (r Rule) typeString() string {
switch r.Type {
case unix.RTN_UNSPEC: // zero
Expand Down
130 changes: 109 additions & 21 deletions rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package netlink

import (
"net"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -54,7 +55,7 @@ func TestRuleAddDel(t *testing.T) {
// find this rule
found := ruleExists(rules, *rule)
if !found {
t.Fatal("Rule has diffrent options than one added")
t.Fatal("Rule has different options than one added")
}

if err := RuleDel(rule); err != nil {
Expand Down Expand Up @@ -600,7 +601,7 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) {
t.Errorf("Expected len: %d, got: %d", len(wantRules), len(rules))
} else {
for i := range wantRules {
if !ruleEquals(wantRules[i], rules[i]) {
if !wantRules[i].Equal(rules[i]) {
t.Errorf("Rules mismatch, want %v, got %v", wantRules[i], rules[i])
}
}
Expand Down Expand Up @@ -666,30 +667,117 @@ func TestRuleString(t *testing.T) {

func ruleExists(rules []Rule, rule Rule) bool {
for i := range rules {
if ruleEquals(rules[i], rule) {
if rules[i].Equal(rule) {
return true
}
}

return false
}

func ruleEquals(a, b Rule) bool {
return a.Table == b.Table &&
((a.Src == nil && b.Src == nil) ||
(a.Src != nil && b.Src != nil && a.Src.String() == b.Src.String())) &&
((a.Dst == nil && b.Dst == nil) ||
(a.Dst != nil && b.Dst != nil && a.Dst.String() == b.Dst.String())) &&
a.OifName == b.OifName &&
a.Priority == b.Priority &&
a.Family == b.Family &&
a.IifName == b.IifName &&
a.Invert == b.Invert &&
a.Tos == b.Tos &&
a.Type == b.Type &&
a.IPProto == b.IPProto &&
a.Protocol == b.Protocol &&
a.Mark == b.Mark &&
(ptrEqual(a.Mask, b.Mask) || (a.Mark != 0 &&
(a.Mask == nil && *b.Mask == 0xFFFFFFFF || b.Mask == nil && *a.Mask == 0xFFFFFFFF)))
func TestRuleEqual(t *testing.T) {
cases := []Rule{
{Priority: 1000},
{Family: FAMILY_V6},
{Table: 10},
{Mark: 1},
{Mask: &[]uint32{0x1}[0]},
{Tos: 1},
{TunID: 3},
{Goto: 10},
{Src: &net.IPNet{IP: net.IPv4(172, 16, 0, 1), Mask: net.CIDRMask(16, 32)}},
{Dst: &net.IPNet{IP: net.IPv4(172, 16, 1, 1), Mask: net.CIDRMask(24, 32)}},
{Flow: 3},
{IifName: "IifName"},
{OifName: "OifName"},
{SuppressIfgroup: 7},
{SuppressPrefixlen: 16},
{Invert: true},
{Dport: &RulePortRange{Start: 10, End: 20}},
{Sport: &RulePortRange{Start: 1, End: 2}},
{IPProto: unix.IPPROTO_TCP},
{UIDRange: &RuleUIDRange{Start: 3, End: 5}},
{Protocol: FAMILY_V6},
{Type: unix.RTN_UNREACHABLE},
{L3mdev: 1},
}
for i1 := range cases {
for i2 := range cases {
got := cases[i1].Equal(cases[i2])
expected := i1 == i2
if got != expected {
t.Errorf("Equal(%q,%q) == %s but expected %s",
cases[i1], cases[i2],
strconv.FormatBool(got),
strconv.FormatBool(expected))
}
}
}
}

func TestRuleEqualTypeUnspecifiedEqualsUnicast(t *testing.T) {
a := Rule{Type: unix.RTN_UNSPEC}
b := Rule{Type: unix.RTN_UNICAST}
if !a.Equal(b) || !b.Equal(a) {
t.Errorf("Rules are expected to be equal")
}
}

func TestRuleEqualMaskMark(t *testing.T) {
a := Rule{Mark: 1, Mask: nil}
b := Rule{Mark: 1, Mask: &[]uint32{0xFFFFFFFF}[0]}
if !a.Equal(b) || !b.Equal(a) {
t.Errorf("Rules are expected to be equal")
}

b = Rule{Mark: 2, Mask: &[]uint32{0xFFFFFFFF}[0]}
if a.Equal(b) || b.Equal(a) {
t.Errorf("Rules are not expected to be equal")
}

a = Rule{Mark: 0, Mask: nil}
b = Rule{Mark: 0, Mask: &[]uint32{0xFFFFFFFF}[0]}
if a.Equal(b) || b.Equal(a) {
t.Errorf("Rules are not expected to be equal")
}
}

func TestRulePortRangeEqual(t *testing.T) {
cases := []RulePortRange{
{Start: 10, End: 10},
{Start: 10, End: 22},
{Start: 11, End: 22},
}
for i1 := range cases {
for i2 := range cases {
got := cases[i1].Equal(cases[i2])
expected := i1 == i2
if got != expected {
t.Errorf("Equal(%q,%q) == %s but expected %s",
cases[i1], cases[i2],
strconv.FormatBool(got),
strconv.FormatBool(expected))
}
}
}
}

func TestRuleUIDRangeEqual(t *testing.T) {
cases := []RuleUIDRange{
{Start: 10, End: 10},
{Start: 10, End: 22},
{Start: 11, End: 22},
}
for i1 := range cases {
for i2 := range cases {
got := cases[i1].Equal(cases[i2])
expected := i1 == i2
if got != expected {
t.Errorf("Equal(%q,%q) == %s but expected %s",
cases[i1], cases[i2],
strconv.FormatBool(got),
strconv.FormatBool(expected))
}
}
}
}
Loading