@@ -8,10 +8,15 @@ import (
88 "time"
99)
1010
11+ var DEFAULT_TIMEOUT = 200 * time .Millisecond
12+
1113// Listener is used to wrap an underlying listener,
1214// whose connections may be using the HAProxy Proxy Protocol.
1315// If the connection is using the protocol, the RemoteAddr() will return
14- // the correct client address.
16+ // the correct client address. ReadHeaderTimeout will be applied to all
17+ // connections in order to prevent blocking operations. If no ReadHeaderTimeout
18+ // is set, a default of 200ms will be used. This can be disabled by setting the
19+ // timeout to < 0.
1520type Listener struct {
1621 Listener net.Listener
1722 Policy PolicyFunc
@@ -21,7 +26,8 @@ type Listener struct {
2126
2227// Conn is used to wrap and underlying connection which
2328// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
24- // return the address of the client instead of the proxy address.
29+ // return the address of the client instead of the proxy address. Each connection
30+ // will have its own readHeaderTimeout and readDeadline set by the Accept() call.
2531type Conn struct {
2632 bufReader * bufio.Reader
2733 conn net.Conn
@@ -30,6 +36,8 @@ type Conn struct {
3036 ProxyHeaderPolicy Policy
3137 Validate Validator
3238 readErr error
39+ readHeaderTimeout time.Duration
40+ readDeadline time.Time
3341}
3442
3543// Validator receives a header and decides whether it is a valid one
@@ -53,12 +61,6 @@ func (p *Listener) Accept() (net.Conn, error) {
5361 return nil , err
5462 }
5563
56- if d := p .ReadHeaderTimeout ; d != 0 {
57- // The deadline will be reset after parsing the header.
58- // Otherwise, future p.conn.Read() will timeout.
59- conn .SetReadDeadline (time .Now ().Add (d ))
60- }
61-
6264 proxyHeaderPolicy := USE
6365 if p .Policy != nil {
6466 proxyHeaderPolicy , err = p .Policy (conn .RemoteAddr ())
@@ -74,6 +76,15 @@ func (p *Listener) Accept() (net.Conn, error) {
7476 WithPolicy (proxyHeaderPolicy ),
7577 ValidateHeader (p .ValidateHeader ),
7678 )
79+
80+ // If the ReadHeaderTimeout for the listener is 0, set a default of 200ms
81+ if p .ReadHeaderTimeout == 0 {
82+ p .ReadHeaderTimeout = DEFAULT_TIMEOUT
83+ }
84+
85+ // Set the readHeaderTimeout of the new conn to the value of the listener
86+ newConn .readHeaderTimeout = p .ReadHeaderTimeout
87+
7788 return newConn , nil
7889}
7990
@@ -108,7 +119,6 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
108119func (p * Conn ) Read (b []byte ) (int , error ) {
109120 p .once .Do (func () {
110121 p .readErr = p .readHeader ()
111- p .conn .SetReadDeadline (time.Time {})
112122 })
113123 if p .readErr != nil {
114124 return 0 , p .readErr
@@ -201,11 +211,16 @@ func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) {
201211
202212// SetDeadline wraps original conn.SetDeadline
203213func (p * Conn ) SetDeadline (t time.Time ) error {
214+ p .readDeadline = t
204215 return p .conn .SetDeadline (t )
205216}
206217
207218// SetReadDeadline wraps original conn.SetReadDeadline
208219func (p * Conn ) SetReadDeadline (t time.Time ) error {
220+ // Set a local var that tells us the desired deadline. This is
221+ // needed in order to reset the read deadline to the one that is
222+ // desired by the user, rather than an empty deadline.
223+ p .readDeadline = t
209224 return p .conn .SetReadDeadline (t )
210225}
211226
@@ -215,7 +230,28 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
215230}
216231
217232func (p * Conn ) readHeader () error {
233+ // If the connection's readHeaderTimeout is more than 0,
234+ // push our deadline back to now plus the timeout. This should only
235+ // run on the connection, as we don't want to override the previous
236+ // read deadline the user may have used.
237+ if p .readHeaderTimeout > 0 {
238+ p .conn .SetReadDeadline (time .Now ().Add (p .readHeaderTimeout ))
239+ }
240+
218241 header , err := Read (p .bufReader )
242+
243+ // If the connection's readHeaderTimeout is more than 0, undo the change to the
244+ // deadline that we made above. Because we retain the readDeadline as part of our
245+ // SetReadDeadline override, we know the user's desired deadline so we use that.
246+ // Therefore, we check whether the error is a net.Timeout and if it is, we decide
247+ // the proxy proto does not exist and set the error accordingly.
248+ if p .readHeaderTimeout > 0 {
249+ p .conn .SetReadDeadline (p .readDeadline )
250+ if netErr , ok := err .(net.Error ); ok && netErr .Timeout () {
251+ err = ErrNoProxyProtocol
252+ }
253+ }
254+
219255 // For the purpose of this wrapper shamefully stolen from armon/go-proxyproto
220256 // let's act as if there was no error when PROXY protocol is not present.
221257 if err == ErrNoProxyProtocol {
0 commit comments