diff --git a/rule.go b/rule.go index 9d74c7cd..536f74a8 100644 --- a/rule.go +++ b/rule.go @@ -29,6 +29,50 @@ type Rule struct { UIDRange *RuleUIDRange Protocol uint8 Type uint8 + L3mdev uint8 +} + +func (r Rule) Equal(x Rule) bool { + return r.Table == x.Table && + ((r.Src == nil && x.Src == nil) || + (r.Src != nil && x.Src != nil && r.Src.String() == x.Src.String())) && + ((r.Dst == nil && x.Dst == nil) || + (r.Dst != nil && x.Dst != nil && r.Dst.String() == x.Dst.String())) && + r.OifName == x.OifName && + r.Priority == x.Priority && + r.Family == x.Family && + r.IifName == x.IifName && + r.Invert == x.Invert && + r.Tos == x.Tos && + (r.Type == x.Type || + (r.Type == 0 && x.Type == 1 || r.Type == 1 && x.Type == 0)) && // 1 is unix.RTN_UNICAST + r.IPProto == x.IPProto && + r.Protocol == x.Protocol && + r.Mark == x.Mark && + // For non-zero marks, mask defaults to 0xFFFFFFFF if not set. So if either mask is nil + // while the other is 0xFFFFFFFF when mark is non-zero, treat the masks as identical. + // See kernel source: https://github.com/torvalds/linux/blob/v6.15/net/core/fib_rules.c#L624 + (ptrEqual(r.Mask, x.Mask) || (r.Mark != 0 && + (r.Mask == nil && *x.Mask == 0xFFFFFFFF || x.Mask == nil && *r.Mask == 0xFFFFFFFF))) && + r.TunID == x.TunID && + r.Goto == x.Goto && + r.Flow == x.Flow && + r.SuppressIfgroup == x.SuppressIfgroup && + r.SuppressPrefixlen == x.SuppressPrefixlen && + (r.Dport == x.Dport || (r.Dport != nil && x.Dport != nil && r.Dport.Equal(*x.Dport))) && + (r.Sport == x.Sport || (r.Sport != nil && x.Sport != nil && r.Sport.Equal(*x.Sport))) && + (r.UIDRange == x.UIDRange || (r.UIDRange != nil && x.UIDRange != nil && r.UIDRange.Equal(*x.UIDRange))) && + r.L3mdev == x.L3mdev +} + +func ptrEqual(a, b *uint32) bool { + if a == b { + return true + } + if (a == nil) || (b == nil) { + return false + } + return *a == *b } func (r Rule) String() string { @@ -70,6 +114,10 @@ type RulePortRange struct { End uint16 } +func (r RulePortRange) Equal(x RulePortRange) bool { + return r.Start == x.Start && r.End == x.End +} + // NewRuleUIDRange creates rule uid range. func NewRuleUIDRange(start, end uint32) *RuleUIDRange { return &RuleUIDRange{Start: start, End: end} @@ -80,3 +128,7 @@ type RuleUIDRange struct { Start uint32 End uint32 } + +func (r RuleUIDRange) Equal(x RuleUIDRange) bool { + return r.Start == x.Start && r.End == x.End +} diff --git a/rule_linux.go b/rule_linux.go index dba99147..1fcfc3b7 100644 --- a/rule_linux.go +++ b/rule_linux.go @@ -178,6 +178,10 @@ func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error { req.AddData(nl.NewRtAttr(nl.FRA_PROTOCOL, nl.Uint8Attr(rule.Protocol))) } + if rule.L3mdev > 0 { + req.AddData(nl.NewRtAttr(nl.FRA_L3MDEV, nl.Uint8Attr(rule.L3mdev))) + } + _, err := req.Execute(unix.NETLINK_ROUTE, 0) return err } @@ -239,6 +243,7 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ( rule.Invert = msg.Flags&FibRuleInvert > 0 rule.Family = int(msg.Family) rule.Tos = uint(msg.Tos) + rule.Type = msg.Type for j := range attrs { switch attrs[j].Attr.Type { @@ -291,6 +296,8 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ( rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8])) case nl.FRA_PROTOCOL: rule.Protocol = uint8(attrs[j].Value[0]) + case nl.FRA_L3MDEV: + rule.L3mdev = uint8(attrs[j].Value[0]) } } @@ -336,16 +343,6 @@ func (pr *RuleUIDRange) toRtAttrData() []byte { return bytes.Join(b, []byte{}) } -func ptrEqual(a, b *uint32) bool { - if a == b { - return true - } - if (a == nil) || (b == nil) { - return false - } - return *a == *b -} - func (r Rule) typeString() string { switch r.Type { case unix.RTN_UNSPEC: // zero diff --git a/rule_test.go b/rule_test.go index 4420e5b5..29490cb8 100644 --- a/rule_test.go +++ b/rule_test.go @@ -5,6 +5,7 @@ package netlink import ( "net" + "strconv" "testing" "time" @@ -54,7 +55,7 @@ func TestRuleAddDel(t *testing.T) { // find this rule found := ruleExists(rules, *rule) if !found { - t.Fatal("Rule has diffrent options than one added") + t.Fatal("Rule has different options than one added") } if err := RuleDel(rule); err != nil { @@ -600,7 +601,7 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) { t.Errorf("Expected len: %d, got: %d", len(wantRules), len(rules)) } else { for i := range wantRules { - if !ruleEquals(wantRules[i], rules[i]) { + if !wantRules[i].Equal(rules[i]) { t.Errorf("Rules mismatch, want %v, got %v", wantRules[i], rules[i]) } } @@ -666,7 +667,7 @@ func TestRuleString(t *testing.T) { func ruleExists(rules []Rule, rule Rule) bool { for i := range rules { - if ruleEquals(rules[i], rule) { + if rules[i].Equal(rule) { return true } } @@ -674,22 +675,109 @@ func ruleExists(rules []Rule, rule Rule) bool { return false } -func ruleEquals(a, b Rule) bool { - return a.Table == b.Table && - ((a.Src == nil && b.Src == nil) || - (a.Src != nil && b.Src != nil && a.Src.String() == b.Src.String())) && - ((a.Dst == nil && b.Dst == nil) || - (a.Dst != nil && b.Dst != nil && a.Dst.String() == b.Dst.String())) && - a.OifName == b.OifName && - a.Priority == b.Priority && - a.Family == b.Family && - a.IifName == b.IifName && - a.Invert == b.Invert && - a.Tos == b.Tos && - a.Type == b.Type && - a.IPProto == b.IPProto && - a.Protocol == b.Protocol && - a.Mark == b.Mark && - (ptrEqual(a.Mask, b.Mask) || (a.Mark != 0 && - (a.Mask == nil && *b.Mask == 0xFFFFFFFF || b.Mask == nil && *a.Mask == 0xFFFFFFFF))) +func TestRuleEqual(t *testing.T) { + cases := []Rule{ + {Priority: 1000}, + {Family: FAMILY_V6}, + {Table: 10}, + {Mark: 1}, + {Mask: &[]uint32{0x1}[0]}, + {Tos: 1}, + {TunID: 3}, + {Goto: 10}, + {Src: &net.IPNet{IP: net.IPv4(172, 16, 0, 1), Mask: net.CIDRMask(16, 32)}}, + {Dst: &net.IPNet{IP: net.IPv4(172, 16, 1, 1), Mask: net.CIDRMask(24, 32)}}, + {Flow: 3}, + {IifName: "IifName"}, + {OifName: "OifName"}, + {SuppressIfgroup: 7}, + {SuppressPrefixlen: 16}, + {Invert: true}, + {Dport: &RulePortRange{Start: 10, End: 20}}, + {Sport: &RulePortRange{Start: 1, End: 2}}, + {IPProto: unix.IPPROTO_TCP}, + {UIDRange: &RuleUIDRange{Start: 3, End: 5}}, + {Protocol: FAMILY_V6}, + {Type: unix.RTN_UNREACHABLE}, + {L3mdev: 1}, + } + for i1 := range cases { + for i2 := range cases { + got := cases[i1].Equal(cases[i2]) + expected := i1 == i2 + if got != expected { + t.Errorf("Equal(%q,%q) == %s but expected %s", + cases[i1], cases[i2], + strconv.FormatBool(got), + strconv.FormatBool(expected)) + } + } + } +} + +func TestRuleEqualTypeUnspecifiedEqualsUnicast(t *testing.T) { + a := Rule{Type: unix.RTN_UNSPEC} + b := Rule{Type: unix.RTN_UNICAST} + if !a.Equal(b) || !b.Equal(a) { + t.Errorf("Rules are expected to be equal") + } +} + +func TestRuleEqualMaskMark(t *testing.T) { + a := Rule{Mark: 1, Mask: nil} + b := Rule{Mark: 1, Mask: &[]uint32{0xFFFFFFFF}[0]} + if !a.Equal(b) || !b.Equal(a) { + t.Errorf("Rules are expected to be equal") + } + + b = Rule{Mark: 2, Mask: &[]uint32{0xFFFFFFFF}[0]} + if a.Equal(b) || b.Equal(a) { + t.Errorf("Rules are not expected to be equal") + } + + a = Rule{Mark: 0, Mask: nil} + b = Rule{Mark: 0, Mask: &[]uint32{0xFFFFFFFF}[0]} + if a.Equal(b) || b.Equal(a) { + t.Errorf("Rules are not expected to be equal") + } +} + +func TestRulePortRangeEqual(t *testing.T) { + cases := []RulePortRange{ + {Start: 10, End: 10}, + {Start: 10, End: 22}, + {Start: 11, End: 22}, + } + for i1 := range cases { + for i2 := range cases { + got := cases[i1].Equal(cases[i2]) + expected := i1 == i2 + if got != expected { + t.Errorf("Equal(%q,%q) == %s but expected %s", + cases[i1], cases[i2], + strconv.FormatBool(got), + strconv.FormatBool(expected)) + } + } + } +} + +func TestRuleUIDRangeEqual(t *testing.T) { + cases := []RuleUIDRange{ + {Start: 10, End: 10}, + {Start: 10, End: 22}, + {Start: 11, End: 22}, + } + for i1 := range cases { + for i2 := range cases { + got := cases[i1].Equal(cases[i2]) + expected := i1 == i2 + if got != expected { + t.Errorf("Equal(%q,%q) == %s but expected %s", + cases[i1], cases[i2], + strconv.FormatBool(got), + strconv.FormatBool(expected)) + } + } + } }