diff --git a/README.md b/README.md index df5a0d87..b73e4ca2 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ as the Identity Provider. export CALLBACK_URL export CLIENT_ID export CLIENT_SECRET +export HOSTED_DOMAIN ``` ## License diff --git a/main.go b/main.go index ddd739dc..6a63fec5 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "os" + "strings" "github.com/utilitywarehouse/go-operational/op" ) @@ -17,6 +18,7 @@ var ( clientID = os.Getenv("CLIENT_ID") clientSecret = os.Getenv("CLIENT_SECRET") callbackURL = os.Getenv("CALLBACK_URL") + expectedHostedDomain = os.Getenv("HOSTED_DOMAIN") ) const oauthURL = "https://accounts.google.com/o/oauth2/auth?redirect_uri=%s&response_type=code&client_id=%s&scope=openid+email+profile&approval_prompt=force&access_type=offline" @@ -34,6 +36,10 @@ type UserInfo struct { Email string `json:"email"` } +type HostedDomain struct { + HostedDomain string `json:"hd"` +} + type TokenResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` @@ -99,6 +105,34 @@ func getUserEmail(accessToken string) (string, error) { return ui.Email, nil } +func getHostedDomain(accessToken string) (string, error) { + uri, _ := url.Parse(userInfoURL) + q := uri.Query() + q.Set("alt", "json") + q.Set("access_token", accessToken) + uri.RawQuery = q.Encode() + resp, err := http.Get(uri.String()) + if err != nil { + return "", err + } + defer func() { + io.Copy(ioutil.Discard, resp.Body) + resp.Body.Close() + }() + if resp.StatusCode != 200 { + return "", fmt.Errorf("Got: %d calling %s", resp.StatusCode, tokenURL) + } + if err != nil { + return "", err + } + hd := &HostedDomain{} + err = json.NewDecoder(resp.Body).Decode(hd) + if err != nil { + return "", err + } + return hd.HostedDomain, nil +} + func googleRedirect() http.Handler { redirectURL := fmt.Sprintf(oauthURL, callbackURL, clientID) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -122,6 +156,18 @@ func googleCallback() http.Handler { w.WriteHeader(http.StatusInternalServerError) } + hostedDomain, err := getHostedDomain(tokResponse.AccessToken) + if err != nil { + log.Printf("Error getting user hosted domain: %s\n", err) + w.WriteHeader(http.StatusInternalServerError) + } + + if ! strings.EqualFold(hostedDomain, expectedHostedDomain) { + log.Printf("Error hosted domain does not match (was %s instead of %s)\n", hostedDomain, expectedHostedDomain) + http.Error(w, "Forbidden", 403) + return + } + kubectlCMD := fmt.Sprintf(kubectlCMDTemplate, email, clientID, clientSecret, tokResponse.IdToken, idpIssuerURL, tokResponse.RefreshToken) w.WriteHeader(http.StatusOK)