diff --git a/catalog/rest/auth.go b/catalog/rest/auth.go new file mode 100644 index 000000000..ecf464499 --- /dev/null +++ b/catalog/rest/auth.go @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package rest + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" +) + +// AuthManager is an interface for providing custom authorization headers. +type AuthManager interface { + // AuthHeader returns the key and value for the authorization header. + AuthHeader() (string, string, error) +} + +type oauthTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + RefreshToken string `json:"refresh_token"` +} + +type oauthErrorResponse struct { + Err string `json:"error"` + ErrDesc string `json:"error_description"` + ErrURI string `json:"error_uri"` +} + +func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError } +func (o oauthErrorResponse) Error() string { + msg := o.Err + if o.ErrDesc != "" { + msg += ": " + o.ErrDesc + } + + if o.ErrURI != "" { + msg += " (" + o.ErrURI + ")" + } + + return msg +} + +// Oauth2AuthManager is an implementation of the AuthManager interface which +// simply returns the provided token as a bearer token. If a credential +// is provided instead of a static token, it will fetch and refresh the +// token as needed. +type Oauth2AuthManager struct { + Token string + Credential string + + AuthURI *url.URL + Scope string + Client *http.Client +} + +// AuthHeader returns the authorization header with the bearer token. +func (o *Oauth2AuthManager) AuthHeader() (string, string, error) { + if o.Token == "" && o.Credential != "" { + if o.Client == nil { + return "", "", fmt.Errorf("%w: cannot fetch token without http client", ErrRESTError) + } + + tok, err := o.fetchAccessToken() + if err != nil { + return "", "", err + } + o.Token = tok + } + return "Authorization", "Bearer " + o.Token, nil +} + +func (o *Oauth2AuthManager) fetchAccessToken() (string, error) { + clientID, clientSecret, hasID := strings.Cut(o.Credential, ":") + if !hasID { + clientID, clientSecret = "", o.Credential + } + + scope := "catalog" + if o.Scope != "" { + scope = o.Scope + } + data := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {clientID}, + "client_secret": {clientSecret}, + "scope": {scope}, + } + + if o.AuthURI == nil { + return "", fmt.Errorf("%w: missing auth uri for fetching token", ErrRESTError) + } + + rsp, err := o.Client.PostForm(o.AuthURI.String(), data) + if err != nil { + return "", err + } + + if rsp.StatusCode == http.StatusOK { + defer rsp.Body.Close() + dec := json.NewDecoder(rsp.Body) + var tok oauthTokenResponse + if err := dec.Decode(&tok); err != nil { + return "", fmt.Errorf("failed to decode oauth token response: %w", err) + } + + return tok.AccessToken, nil + } + + switch rsp.StatusCode { + case http.StatusUnauthorized, http.StatusBadRequest: + defer rsp.Body.Close() + dec := json.NewDecoder(rsp.Body) + var oauthErr oauthErrorResponse + if err := dec.Decode(&oauthErr); err != nil { + return "", fmt.Errorf("failed to decode oauth error: %w", err) + } + + return "", oauthErr + default: + return "", handleNon200(rsp, nil) + } +} diff --git a/catalog/rest/options.go b/catalog/rest/options.go index c14ec1f23..b2ea4da9e 100644 --- a/catalog/rest/options.go +++ b/catalog/rest/options.go @@ -35,8 +35,12 @@ func WithCredential(cred string) Option { } func WithOAuthToken(token string) Option { + return WithAuthManager(&Oauth2AuthManager{Token: token}) +} + +func WithAuthManager(authManager AuthManager) Option { return func(o *options) { - o.oauthToken = token + o.authManager = authManager } } @@ -122,7 +126,7 @@ type options struct { awsConfigSet bool tlsConfig *tls.Config credential string - oauthToken string + authManager AuthManager warehouseLocation string metadataLocation string enableSigv4 bool diff --git a/catalog/rest/rest.go b/catalog/rest/rest.go index 0c40089bf..3b79a77c3 100644 --- a/catalog/rest/rest.go +++ b/catalog/rest/rest.go @@ -160,34 +160,6 @@ type createTableRequest struct { Props iceberg.Properties `json:"properties,omitempty"` } -type oauthTokenResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - Scope string `json:"scope"` - RefreshToken string `json:"refresh_token"` -} - -type oauthErrorResponse struct { - Err string `json:"error"` - ErrDesc string `json:"error_description"` - ErrURI string `json:"error_uri"` -} - -func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError } -func (o oauthErrorResponse) Error() string { - msg := o.Err - if o.ErrDesc != "" { - msg += ": " + o.ErrDesc - } - - if o.ErrURI != "" { - msg += " (" + o.ErrURI + ")" - } - - return msg -} - type configResponse struct { Defaults iceberg.Properties `json:"defaults"` Overrides iceberg.Properties `json:"overrides"` @@ -380,8 +352,6 @@ func handleNon200(rsp *http.Response, override map[int]error) error { func fromProps(props iceberg.Properties, o *options) { for k, v := range props { switch k { - case keyOauthToken: - o.oauthToken = v case keyWarehouseLocation: o.warehouseLocation = v case keyMetadataLocation: @@ -436,7 +406,6 @@ func toProps(o *options) iceberg.Properties { } setIf(keyOauthCredential, o.credential) - setIf(keyOauthToken, o.oauthToken) setIf(keyWarehouseLocation, o.warehouseLocation) setIf(keyMetadataLocation, o.metadataLocation) if o.enableSigv4 { @@ -510,59 +479,6 @@ func (r *Catalog) init(ctx context.Context, ops *options, uri string) error { return nil } -func (r *Catalog) fetchAccessToken(cl *http.Client, creds string, opts *options) (string, error) { - clientID, clientSecret, hasID := strings.Cut(creds, ":") - if !hasID { - clientID, clientSecret = "", clientID - } - - scope := "catalog" - if opts.scope != "" { - scope = opts.scope - } - data := url.Values{ - "grant_type": {"client_credentials"}, - "client_id": {clientID}, - "client_secret": {clientSecret}, - "scope": {scope}, - } - - uri := opts.authUri - if uri == nil { - uri = r.baseURI.JoinPath("oauth/tokens") - } - - rsp, err := cl.PostForm(uri.String(), data) - if err != nil { - return "", err - } - - if rsp.StatusCode == http.StatusOK { - defer rsp.Body.Close() - dec := json.NewDecoder(rsp.Body) - var tok oauthTokenResponse - if err := dec.Decode(&tok); err != nil { - return "", fmt.Errorf("failed to decode oauth token response: %w", err) - } - - return tok.AccessToken, nil - } - - switch rsp.StatusCode { - case http.StatusUnauthorized, http.StatusBadRequest: - defer rsp.Request.GetBody() - dec := json.NewDecoder(rsp.Body) - var oauthErr oauthErrorResponse - if err := dec.Decode(&oauthErr); err != nil { - return "", fmt.Errorf("failed to decode oauth error: %w", err) - } - - return "", oauthErr - default: - return "", handleNon200(rsp, nil) - } -} - func (r *Catalog) createSession(ctx context.Context, opts *options) (*http.Client, error) { session := &sessionTransport{ defaultHeaders: http.Header{}, @@ -574,16 +490,20 @@ func (r *Catalog) createSession(ctx context.Context, opts *options) (*http.Clien } cl := &http.Client{Transport: session} - token := opts.oauthToken - if token == "" && opts.credential != "" { - var err error - if token, err = r.fetchAccessToken(cl, opts.credential, opts); err != nil { - return nil, fmt.Errorf("auth error: %w", err) - } - } + if opts.credential != "" { + if _, ok := opts.authManager.(*Oauth2AuthManager); !ok { + authURI := opts.authUri + if authURI == nil { + authURI = r.baseURI.JoinPath("oauth/tokens") + } - if token != "" { - session.defaultHeaders.Set(authorizationHeader, bearerPrefix+" "+token) + opts.authManager = &Oauth2AuthManager{ + Credential: opts.credential, + AuthURI: authURI, + Scope: opts.scope, + Client: cl, + } + } } session.defaultHeaders.Set("X-Client-Version", icebergRestSpecVersion) @@ -591,6 +511,14 @@ func (r *Catalog) createSession(ctx context.Context, opts *options) (*http.Clien session.defaultHeaders.Set("User-Agent", "GoIceberg/"+iceberg.Version()) session.defaultHeaders.Set("X-Iceberg-Access-Delegation", "vended-credentials") + if opts.authManager != nil { + k, v, err := opts.authManager.AuthHeader() + if err != nil { + return nil, err + } + session.defaultHeaders.Set(k, v) + } + if opts.enableSigv4 { cfg := opts.awsConfig if !opts.awsConfigSet {