Skip to content

Commit 6633133

Browse files
authored
Add Config.RedirectURLHostname (#34)
1 parent 5e33be6 commit 6633133

File tree

3 files changed

+86
-28
lines changed

3 files changed

+86
-28
lines changed

e2e_test/e2e_test.go

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,46 @@ func TestGetToken(t *testing.T) {
3939
t.Errorf("scope wants %s but %s", w, r.Scope)
4040
return fmt.Sprintf("%s?error=invalid_scope", r.RedirectURI)
4141
}
42+
redirectURIPrefix := "http://localhost:"
43+
if !strings.HasPrefix(r.RedirectURI, redirectURIPrefix) {
44+
t.Errorf("redirect_uri wants prefix %s but was %s", redirectURIPrefix, r.RedirectURI)
45+
return fmt.Sprintf("%s?error=invalid_redirect_uri", r.RedirectURI)
46+
}
47+
return fmt.Sprintf("%s?state=%s&code=%s", r.RedirectURI, r.State, "AUTH_CODE")
48+
},
49+
NewTokenResponse: func(r authserver.TokenRequest) (int, string) {
50+
if w := "AUTH_CODE"; r.Code != w {
51+
t.Errorf("code wants %s but %s", w, r.Code)
52+
return 400, invalidGrantResponse
53+
}
54+
return 200, validTokenResponse
55+
},
56+
}
57+
successfulTest(t, cfg, h)
58+
})
59+
60+
t.Run("RedirectURLHostname", func(t *testing.T) {
61+
cfg := oauth2cli.Config{
62+
OAuth2Config: oauth2.Config{
63+
ClientID: "YOUR_CLIENT_ID",
64+
ClientSecret: "YOUR_CLIENT_SECRET",
65+
Scopes: []string{"email", "profile"},
66+
},
67+
RedirectURLHostname: "127.0.0.1",
68+
LocalServerMiddleware: loggingMiddleware(t),
69+
}
70+
h := &authserver.Handler{
71+
T: t,
72+
NewAuthorizationResponse: func(r authserver.AuthorizationRequest) string {
73+
if w := "email profile"; r.Scope != w {
74+
t.Errorf("scope wants %s but %s", w, r.Scope)
75+
return fmt.Sprintf("%s?error=invalid_scope", r.RedirectURI)
76+
}
77+
redirectURIPrefix := "http://127.0.0.1:"
78+
if !strings.HasPrefix(r.RedirectURI, redirectURIPrefix) {
79+
t.Errorf("redirect_uri wants prefix %s but was %s", redirectURIPrefix, r.RedirectURI)
80+
return fmt.Sprintf("%s?error=invalid_redirect_uri", r.RedirectURI)
81+
}
4282
return fmt.Sprintf("%s?state=%s&code=%s", r.RedirectURI, r.State, "AUTH_CODE")
4383
},
4484
NewTokenResponse: func(r authserver.TokenRequest) (int, string) {
@@ -70,8 +110,9 @@ func TestGetToken(t *testing.T) {
70110
t.Errorf("scope wants %s but %s", w, r.Scope)
71111
return fmt.Sprintf("%s?error=invalid_scope", r.RedirectURI)
72112
}
73-
if !strings.HasPrefix(r.RedirectURI, "https://") {
74-
t.Errorf("redirect_uri must start with https:// when using TLS config %s", r.RedirectURI)
113+
redirectURIPrefix := "https://localhost:"
114+
if !strings.HasPrefix(r.RedirectURI, redirectURIPrefix) {
115+
t.Errorf("redirect_uri wants prefix %s but was %s", redirectURIPrefix, r.RedirectURI)
75116
return fmt.Sprintf("%s?error=invalid_redirect_uri", r.RedirectURI)
76117
}
77118
return fmt.Sprintf("%s?state=%s&code=%s", r.RedirectURI, r.State, "AUTH_CODE")

oauth2cli.go

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ type Config struct {
2121
// OAuth2 config.
2222
// RedirectURL will be automatically set to the local server.
2323
OAuth2Config oauth2.Config
24+
// Hostname of the redirect URL.
25+
// You can set this if your provider does not accept localhost.
26+
// Default to localhost.
27+
RedirectURLHostname string
2428
// Options for an authorization request.
2529
// You can set oauth2.AccessTypeOffline and the PKCE options here.
2630
AuthCodeOptions []oauth2.AuthCodeOption
@@ -67,6 +71,30 @@ type Config struct {
6771
LocalServerPort []int
6872
}
6973

74+
func (c *Config) validateAndSetDefaults() error {
75+
if (c.LocalServerCertFile != "" && c.LocalServerKeyFile == "") ||
76+
(c.LocalServerCertFile == "" && c.LocalServerKeyFile != "") {
77+
return fmt.Errorf("both LocalServerCertFile and LocalServerKeyFile must be set")
78+
}
79+
if c.RedirectURLHostname == "" {
80+
c.RedirectURLHostname = "localhost"
81+
}
82+
if c.State == "" {
83+
s, err := oauth2params.NewState()
84+
if err != nil {
85+
return fmt.Errorf("could not generate a state parameter: %w", err)
86+
}
87+
c.State = s
88+
}
89+
if c.LocalServerMiddleware == nil {
90+
c.LocalServerMiddleware = noopMiddleware
91+
}
92+
if c.LocalServerSuccessHTML == "" {
93+
c.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML
94+
}
95+
return nil
96+
}
97+
7098
func (c *Config) populateDeprecatedFields() {
7199
if len(c.LocalServerPort) > 0 {
72100
address := c.LocalServerAddress
@@ -92,18 +120,8 @@ func (c *Config) populateDeprecatedFields() {
92120
// 6. Return the code.
93121
//
94122
func GetToken(ctx context.Context, config Config) (*oauth2.Token, error) {
95-
if config.State == "" {
96-
s, err := oauth2params.NewState()
97-
if err != nil {
98-
return nil, fmt.Errorf("could not generate a state parameter: %w", err)
99-
}
100-
config.State = s
101-
}
102-
if config.LocalServerMiddleware == nil {
103-
config.LocalServerMiddleware = noopMiddleware
104-
}
105-
if config.LocalServerSuccessHTML == "" {
106-
config.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML
123+
if err := config.validateAndSetDefaults(); err != nil {
124+
return nil, fmt.Errorf("invalid config: %w", err)
107125
}
108126
config.populateDeprecatedFields()
109127
code, err := receiveCodeViaLocalServer(ctx, &config)

server.go

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"net"
78
"net/http"
89

910
"github.com/int128/listener"
@@ -16,17 +17,7 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
1617
return "", fmt.Errorf("could not start a local server: %w", err)
1718
}
1819
defer l.Close()
19-
20-
switch {
21-
case c.LocalServerCertFile == "" && c.LocalServerKeyFile == "":
22-
case c.LocalServerCertFile != "" && c.LocalServerKeyFile != "":
23-
l.URL.Scheme = "https"
24-
default:
25-
return "", fmt.Errorf("both LocalServerCertFile and LocalServerKeyFile must be set")
26-
}
27-
if c.OAuth2Config.RedirectURL == "" {
28-
c.OAuth2Config.RedirectURL = l.URL.String()
29-
}
20+
c.OAuth2Config.RedirectURL = computeRedirectURL(l, c)
3021

3122
respCh := make(chan *authorizationResponse)
3223
server := http.Server{
@@ -72,7 +63,7 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
7263
return nil
7364
})
7465
if c.LocalServerReadyChan != nil {
75-
c.LocalServerReadyChan <- l.URL.String()
66+
c.LocalServerReadyChan <- c.OAuth2Config.RedirectURL
7667
}
7768

7869
if err := eg.Wait(); err != nil {
@@ -84,6 +75,14 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
8475
return resp.code, resp.err
8576
}
8677

78+
func computeRedirectURL(l net.Listener, c *Config) string {
79+
hostPort := fmt.Sprintf("%s:%d", c.RedirectURLHostname, l.Addr().(*net.TCPAddr).Port)
80+
if c.LocalServerCertFile != "" {
81+
return "https://" + hostPort
82+
}
83+
return "http://" + hostPort
84+
}
85+
8786
type authorizationResponse struct {
8887
code string // non-empty if a valid code is received
8988
err error // non-nil if an error is received or any error occurs
@@ -109,8 +108,8 @@ func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
109108
}
110109

111110
func (h *localServerHandler) handleIndex(w http.ResponseWriter, r *http.Request) {
112-
url := h.config.OAuth2Config.AuthCodeURL(h.config.State, h.config.AuthCodeOptions...)
113-
http.Redirect(w, r, url, 302)
111+
authCodeURL := h.config.OAuth2Config.AuthCodeURL(h.config.State, h.config.AuthCodeOptions...)
112+
http.Redirect(w, r, authCodeURL, 302)
114113
}
115114

116115
func (h *localServerHandler) handleCodeResponse(w http.ResponseWriter, r *http.Request) *authorizationResponse {

0 commit comments

Comments
 (0)