Skip to content

Commit 0c5719a

Browse files
antoniomikapires
authored andcommitted
protocol: fix header timeout when the user supplied deadline(s)
1 parent 094c0b6 commit 0c5719a

File tree

3 files changed

+276
-62
lines changed

3 files changed

+276
-62
lines changed

protocol.go

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@ import (
88
"time"
99
)
1010

11+
var DEFAULT_TIMEOUT = 200 * time.Millisecond
12+
1113
// Listener is used to wrap an underlying listener,
1214
// whose connections may be using the HAProxy Proxy Protocol.
1315
// If the connection is using the protocol, the RemoteAddr() will return
14-
// the correct client address.
16+
// the correct client address. ReadHeaderTimeout will be applied to all
17+
// connections in order to prevent blocking operations. If no ReadHeaderTimeout
18+
// is set, a default of 200ms will be used. This can be disabled by setting the
19+
// timeout to < 0.
1520
type Listener struct {
1621
Listener net.Listener
1722
Policy PolicyFunc
@@ -21,7 +26,8 @@ type Listener struct {
2126

2227
// Conn is used to wrap and underlying connection which
2328
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
24-
// return the address of the client instead of the proxy address.
29+
// return the address of the client instead of the proxy address. Each connection
30+
// will have its own readHeaderTimeout and readDeadline set by the Accept() call.
2531
type Conn struct {
2632
bufReader *bufio.Reader
2733
conn net.Conn
@@ -30,6 +36,8 @@ type Conn struct {
3036
ProxyHeaderPolicy Policy
3137
Validate Validator
3238
readErr error
39+
readHeaderTimeout time.Duration
40+
readDeadline time.Time
3341
}
3442

3543
// Validator receives a header and decides whether it is a valid one
@@ -53,12 +61,6 @@ func (p *Listener) Accept() (net.Conn, error) {
5361
return nil, err
5462
}
5563

56-
if d := p.ReadHeaderTimeout; d != 0 {
57-
// The deadline will be reset after parsing the header.
58-
// Otherwise, future p.conn.Read() will timeout.
59-
conn.SetReadDeadline(time.Now().Add(d))
60-
}
61-
6264
proxyHeaderPolicy := USE
6365
if p.Policy != nil {
6466
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
@@ -74,6 +76,15 @@ func (p *Listener) Accept() (net.Conn, error) {
7476
WithPolicy(proxyHeaderPolicy),
7577
ValidateHeader(p.ValidateHeader),
7678
)
79+
80+
// If the ReadHeaderTimeout for the listener is 0, set a default of 200ms
81+
if p.ReadHeaderTimeout == 0 {
82+
p.ReadHeaderTimeout = DEFAULT_TIMEOUT
83+
}
84+
85+
// Set the readHeaderTimeout of the new conn to the value of the listener
86+
newConn.readHeaderTimeout = p.ReadHeaderTimeout
87+
7788
return newConn, nil
7889
}
7990

@@ -108,7 +119,6 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
108119
func (p *Conn) Read(b []byte) (int, error) {
109120
p.once.Do(func() {
110121
p.readErr = p.readHeader()
111-
p.conn.SetReadDeadline(time.Time{})
112122
})
113123
if p.readErr != nil {
114124
return 0, p.readErr
@@ -201,11 +211,16 @@ func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) {
201211

202212
// SetDeadline wraps original conn.SetDeadline
203213
func (p *Conn) SetDeadline(t time.Time) error {
214+
p.readDeadline = t
204215
return p.conn.SetDeadline(t)
205216
}
206217

207218
// SetReadDeadline wraps original conn.SetReadDeadline
208219
func (p *Conn) SetReadDeadline(t time.Time) error {
220+
// Set a local var that tells us the desired deadline. This is
221+
// needed in order to reset the read deadline to the one that is
222+
// desired by the user, rather than an empty deadline.
223+
p.readDeadline = t
209224
return p.conn.SetReadDeadline(t)
210225
}
211226

@@ -215,7 +230,28 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
215230
}
216231

217232
func (p *Conn) readHeader() error {
233+
// If the connection's readHeaderTimeout is more than 0,
234+
// push our deadline back to now plus the timeout. This should only
235+
// run on the connection, as we don't want to override the previous
236+
// read deadline the user may have used.
237+
if p.readHeaderTimeout > 0 {
238+
p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout))
239+
}
240+
218241
header, err := Read(p.bufReader)
242+
243+
// If the connection's readHeaderTimeout is more than 0, undo the change to the
244+
// deadline that we made above. Because we retain the readDeadline as part of our
245+
// SetReadDeadline override, we know the user's desired deadline so we use that.
246+
// Therefore, we check whether the error is a net.Timeout and if it is, we decide
247+
// the proxy proto does not exist and set the error accordingly.
248+
if p.readHeaderTimeout > 0 {
249+
p.conn.SetReadDeadline(p.readDeadline)
250+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
251+
err = ErrNoProxyProtocol
252+
}
253+
}
254+
219255
// For the purpose of this wrapper shamefully stolen from armon/go-proxyproto
220256
// let's act as if there was no error when PROXY protocol is not present.
221257
if err == ErrNoProxyProtocol {

0 commit comments

Comments
 (0)