Skip to content

Commit 3211d96

Browse files
authored
Merge pull request #64 from bohanyang/fix-v1-command
v1: fix command always LOCAL
2 parents b6f440c + 43ce4ef commit 3211d96

File tree

6 files changed

+85
-68
lines changed

6 files changed

+85
-68
lines changed

header.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header {
5454
}
5555
h := &Header{
5656
Version: version,
57-
Command: PROXY,
57+
Command: LOCAL,
5858
TransportProtocol: UNSPEC,
5959
}
6060
switch sourceAddr := sourceAddr.(type) {
@@ -88,6 +88,7 @@ func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header {
8888
}
8989
}
9090
if h.TransportProtocol != UNSPEC {
91+
h.Command = PROXY
9192
h.SourceAddr = sourceAddr
9293
h.DestinationAddr = destAddr
9394
}
@@ -152,17 +153,15 @@ func (header *Header) EqualsTo(otherHeader *Header) bool {
152153
if otherHeader == nil {
153154
return false
154155
}
155-
if header.Command.IsLocal() {
156-
return true
157-
}
158156
// TLVs only exist for version 2
159-
if header.Version == 0x02 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) {
157+
if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) {
160158
return false
161159
}
162-
if header.Version != otherHeader.Version || header.TransportProtocol != otherHeader.TransportProtocol {
160+
if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol {
163161
return false
164162
}
165-
if header.TransportProtocol == UNSPEC {
163+
// Return early for header with LOCAL command, which contains no address information
164+
if header.Command == LOCAL {
166165
return true
167166
}
168167
return header.SourceAddr.String() == otherHeader.SourceAddr.String() &&

header_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ func TestFormatInvalid(t *testing.T) {
537537
func TestHeaderProxyFromAddrs(t *testing.T) {
538538
unspec := &Header{
539539
Version: 2,
540-
Command: PROXY,
540+
Command: LOCAL,
541541
TransportProtocol: UNSPEC,
542542
}
543543

v1.go

Lines changed: 57 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -33,62 +33,68 @@ func parseVersion1(reader *bufio.Reader) (*Header, error) {
3333
}
3434
// Check full signature.
3535
tokens := strings.Split(line[:len(line)-2], separator)
36-
transportProtocol := UNSPEC // doesn't exist in v1 but fits UNKNOWN.
37-
if len(tokens) > 0 {
38-
// Read address family and protocol
39-
switch tokens[1] {
40-
case "TCP4":
41-
transportProtocol = TCPv4
42-
case "TCP6":
43-
transportProtocol = TCPv6
44-
case "UNKNOWN": // no-op as UNSPEC is set already
45-
default:
46-
return nil, ErrCantReadAddressFamilyAndProtocol
47-
}
48-
49-
// Expect 6 tokens only when UNKNOWN is not present.
50-
if !transportProtocol.IsUnspec() && len(tokens) < 6 {
51-
return nil, ErrCantReadAddressFamilyAndProtocol
52-
}
53-
}
54-
55-
// Allocation only happens when a signature is found.
36+
37+
// Expect at least 2 tokens: "PROXY" and the transport protocol.
38+
if len(tokens) < 2 {
39+
return nil, ErrCantReadAddressFamilyAndProtocol
40+
}
41+
42+
// Read address family and protocol
43+
var transportProtocol AddressFamilyAndProtocol
44+
switch tokens[1] {
45+
case "TCP4":
46+
transportProtocol = TCPv4
47+
case "TCP6":
48+
transportProtocol = TCPv6
49+
case "UNKNOWN":
50+
transportProtocol = UNSPEC // doesn't exist in v1 but fits UNKNOWN
51+
default:
52+
return nil, ErrCantReadAddressFamilyAndProtocol
53+
}
54+
55+
// Expect 6 tokens only when UNKNOWN is not present.
56+
if transportProtocol != UNSPEC && len(tokens) < 6 {
57+
return nil, ErrCantReadAddressFamilyAndProtocol
58+
}
59+
60+
// When a signature is found, allocate a v1 header with Command set to PROXY.
61+
// Command doesn't exist in v1 but set it for other parts of this library
62+
// to rely on it for determining connection details.
5663
header := initVersion1()
57-
// If UNKNOWN is present, set Command to LOCAL.
58-
// Command is not present in v1 but set it for other parts of
59-
// this library to rely on it for determining connection details.
60-
header.Command = LOCAL
6164

6265
// Transport protocol has been processed already.
6366
header.TransportProtocol = transportProtocol
6467

65-
// Only process further if UNKNOWN is not present.
66-
if header.TransportProtocol != UNSPEC {
67-
// Read addresses and ports
68-
sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2])
69-
if err != nil {
70-
return nil, err
71-
}
72-
destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3])
73-
if err != nil {
74-
return nil, err
75-
}
76-
sourcePort, err := parseV1PortNumber(tokens[4])
77-
if err != nil {
78-
return nil, err
79-
}
80-
destPort, err := parseV1PortNumber(tokens[5])
81-
if err != nil {
82-
return nil, err
83-
}
84-
header.SourceAddr = &net.TCPAddr{
85-
IP: sourceIP,
86-
Port: sourcePort,
87-
}
88-
header.DestinationAddr = &net.TCPAddr{
89-
IP: destIP,
90-
Port: destPort,
91-
}
68+
// When UNKNOWN, set the command to LOCAL and return early
69+
if header.TransportProtocol == UNSPEC {
70+
header.Command = LOCAL
71+
return header, nil
72+
}
73+
74+
// Otherwise, continue to read addresses and ports
75+
sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2])
76+
if err != nil {
77+
return nil, err
78+
}
79+
destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3])
80+
if err != nil {
81+
return nil, err
82+
}
83+
sourcePort, err := parseV1PortNumber(tokens[4])
84+
if err != nil {
85+
return nil, err
86+
}
87+
destPort, err := parseV1PortNumber(tokens[5])
88+
if err != nil {
89+
return nil, err
90+
}
91+
header.SourceAddr = &net.TCPAddr{
92+
IP: sourceIP,
93+
Port: sourcePort,
94+
}
95+
header.DestinationAddr = &net.TCPAddr{
96+
IP: destIP,
97+
Port: destPort,
9298
}
9399

94100
return header, nil

v1_test.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ var invalidParseV1Tests = []struct {
4545
reader: newBufioReader([]byte("PROXY " + crlf)),
4646
expectedError: ErrCantReadAddressFamilyAndProtocol,
4747
},
48+
{
49+
desc: "proxy no space crlf",
50+
reader: newBufioReader([]byte("PROXY" + crlf)),
51+
expectedError: ErrCantReadAddressFamilyAndProtocol,
52+
},
4853
{
4954
desc: "proxy something crlf",
5055
reader: newBufioReader([]byte("PROXY SOMETHING" + crlf)),
@@ -114,7 +119,7 @@ var validParseAndWriteV1Tests = []struct {
114119
reader: bufio.NewReader(strings.NewReader(fixtureUnknown)),
115120
expectedHeader: &Header{
116121
Version: 1,
117-
Command: PROXY,
122+
Command: LOCAL,
118123
TransportProtocol: UNSPEC,
119124
SourceAddr: nil,
120125
DestinationAddr: nil,
@@ -125,7 +130,7 @@ var validParseAndWriteV1Tests = []struct {
125130
reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)),
126131
expectedHeader: &Header{
127132
Version: 1,
128-
Command: PROXY,
133+
Command: LOCAL,
129134
TransportProtocol: UNSPEC,
130135
SourceAddr: nil,
131136
DestinationAddr: nil,

v2.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ func (header *Header) formatVersion2() ([]byte, error) {
205205
addrDst = formatUnixName(destAddr.Name)
206206
}
207207

208-
//
209208
if addrSrc == nil || addrDst == nil {
210209
return nil, ErrInvalidAddress
211210
}

version_cmd.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
package proxyproto
22

3-
// ProtocolVersionAndCommand represents proxy protocol version and command.
3+
// ProtocolVersionAndCommand represents the command in proxy protocol v2.
4+
// Command doesn't exist in v1 but it should be set since other parts of
5+
// this library may rely on it for determining connection details.
46
type ProtocolVersionAndCommand byte
57

68
const (
9+
// LOCAL represents the LOCAL command in v2 or UNKNOWN transport in v1,
10+
// in which case no address information is expected.
711
LOCAL ProtocolVersionAndCommand = '\x20'
12+
// PROXY represents the PROXY command in v2 or transport is not UNKNOWN in v1,
13+
// in which case valid local/remote address and port information is expected.
814
PROXY ProtocolVersionAndCommand = '\x21'
915
)
1016

@@ -13,17 +19,19 @@ var supportedCommand = map[ProtocolVersionAndCommand]bool{
1319
PROXY: true,
1420
}
1521

16-
// IsLocal returns true if the protocol version is \x2 and command is LOCAL, false otherwise.
22+
// IsLocal returns true if the command in v2 is LOCAL or the transport in v1 is UNKNOWN,
23+
// i.e. when no address information is expected, false otherwise.
1724
func (pvc ProtocolVersionAndCommand) IsLocal() bool {
18-
return 0x20 == pvc&0xF0 && 0x00 == pvc&0x0F
25+
return LOCAL == pvc
1926
}
2027

21-
// IsProxy returns true if the protocol version is \x2 and command is PROXY, false otherwise.
28+
// IsProxy returns true if the command in v2 is PROXY or the transport in v1 is not UNKNOWN,
29+
// i.e. when valid local/remote address and port information is expected, false otherwise.
2230
func (pvc ProtocolVersionAndCommand) IsProxy() bool {
23-
return 0x20 == pvc&0xF0 && 0x01 == pvc&0x0F
31+
return PROXY == pvc
2432
}
2533

26-
// IsUnspec returns true if the protocol version or command is unspecified, false otherwise.
34+
// IsUnspec returns true if the command is unspecified, false otherwise.
2735
func (pvc ProtocolVersionAndCommand) IsUnspec() bool {
2836
return !(pvc.IsLocal() || pvc.IsProxy())
2937
}

0 commit comments

Comments
 (0)