Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ as the Identity Provider.
export CALLBACK_URL
export CLIENT_ID
export CLIENT_SECRET
export HOSTED_DOMAIN
```

## License
Expand Down
46 changes: 46 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"os"
"strings"

"github.com/utilitywarehouse/go-operational/op"
)
Expand All @@ -17,6 +18,7 @@ var (
clientID = os.Getenv("CLIENT_ID")
clientSecret = os.Getenv("CLIENT_SECRET")
callbackURL = os.Getenv("CALLBACK_URL")
expectedHostedDomain = os.Getenv("HOSTED_DOMAIN")
)

const oauthURL = "https://accounts.google.com/o/oauth2/auth?redirect_uri=%s&response_type=code&client_id=%s&scope=openid+email+profile&approval_prompt=force&access_type=offline"
Expand All @@ -34,6 +36,10 @@ type UserInfo struct {
Email string `json:"email"`
}

type HostedDomain struct {
HostedDomain string `json:"hd"`
}

type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
Expand Down Expand Up @@ -99,6 +105,34 @@ func getUserEmail(accessToken string) (string, error) {
return ui.Email, nil
}

func getHostedDomain(accessToken string) (string, error) {
uri, _ := url.Parse(userInfoURL)
q := uri.Query()
q.Set("alt", "json")
q.Set("access_token", accessToken)
uri.RawQuery = q.Encode()
resp, err := http.Get(uri.String())
if err != nil {
return "", err
}
defer func() {
io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
}()
if resp.StatusCode != 200 {
return "", fmt.Errorf("Got: %d calling %s", resp.StatusCode, tokenURL)
}
if err != nil {
return "", err
}
hd := &HostedDomain{}
err = json.NewDecoder(resp.Body).Decode(hd)
if err != nil {
return "", err
}
return hd.HostedDomain, nil
}

func googleRedirect() http.Handler {
redirectURL := fmt.Sprintf(oauthURL, callbackURL, clientID)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -122,6 +156,18 @@ func googleCallback() http.Handler {
w.WriteHeader(http.StatusInternalServerError)
}

hostedDomain, err := getHostedDomain(tokResponse.AccessToken)
if err != nil {
log.Printf("Error getting user hosted domain: %s\n", err)
w.WriteHeader(http.StatusInternalServerError)
}

if ! strings.EqualFold(hostedDomain, expectedHostedDomain) {
log.Printf("Error hosted domain does not match (was %s instead of %s)\n", hostedDomain, expectedHostedDomain)
http.Error(w, "Forbidden", 403)
return
}

kubectlCMD := fmt.Sprintf(kubectlCMDTemplate, email, clientID, clientSecret, tokResponse.IdToken, idpIssuerURL, tokResponse.RefreshToken)

w.WriteHeader(http.StatusOK)
Expand Down