Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions catalog/rest/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package rest

import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
)

// AuthManager is an interface for providing custom authorization headers.
type AuthManager interface {
// AuthHeader returns the key and value for the authorization header.
AuthHeader() (string, string, error)
}

type oauthTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}

type oauthErrorResponse struct {
Err string `json:"error"`
ErrDesc string `json:"error_description"`
ErrURI string `json:"error_uri"`
}

func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError }
func (o oauthErrorResponse) Error() string {
msg := o.Err
if o.ErrDesc != "" {
msg += ": " + o.ErrDesc
}

if o.ErrURI != "" {
msg += " (" + o.ErrURI + ")"
}

return msg
}

// Oauth2AuthManager is an implementation of the AuthManager interface which
// simply returns the provided token as a bearer token. If a credential
// is provided instead of a static token, it will fetch and refresh the
// token as needed.
type Oauth2AuthManager struct {
Token string
Credential string

AuthURI *url.URL
Scope string
Client *http.Client
}

// AuthHeader returns the authorization header with the bearer token.
func (o *Oauth2AuthManager) AuthHeader() (string, string, error) {
if o.Token == "" && o.Credential != "" {
if o.Client == nil {
return "", "", fmt.Errorf("%w: cannot fetch token without http client", ErrRESTError)
}

tok, err := o.fetchAccessToken()
if err != nil {
return "", "", err
}
o.Token = tok
}
return "Authorization", "Bearer " + o.Token, nil

Check failure on line 88 in catalog/rest/auth.go

View workflow job for this annotation

GitHub Actions / windows-latest go1.23.6

return with no blank line before (nlreturn)

Check failure on line 88 in catalog/rest/auth.go

View workflow job for this annotation

GitHub Actions / macos-latest go1.23.6

return with no blank line before (nlreturn)

Check failure on line 88 in catalog/rest/auth.go

View workflow job for this annotation

GitHub Actions / ubuntu-latest go1.23.6

return with no blank line before (nlreturn)
}

func (o *Oauth2AuthManager) fetchAccessToken() (string, error) {
clientID, clientSecret, hasID := strings.Cut(o.Credential, ":")
if !hasID {
clientID, clientSecret = "", o.Credential
}

scope := "catalog"
if o.Scope != "" {
scope = o.Scope
}
data := url.Values{
"grant_type": {"client_credentials"},
"client_id": {clientID},
"client_secret": {clientSecret},
"scope": {scope},
}

if o.AuthURI == nil {
return "", fmt.Errorf("%w: missing auth uri for fetching token", ErrRESTError)
}

rsp, err := o.Client.PostForm(o.AuthURI.String(), data)
if err != nil {
return "", err
}

if rsp.StatusCode == http.StatusOK {
defer rsp.Body.Close()
dec := json.NewDecoder(rsp.Body)
var tok oauthTokenResponse
if err := dec.Decode(&tok); err != nil {
return "", fmt.Errorf("failed to decode oauth token response: %w", err)
}

return tok.AccessToken, nil
}

switch rsp.StatusCode {
case http.StatusUnauthorized, http.StatusBadRequest:
defer rsp.Body.Close()
dec := json.NewDecoder(rsp.Body)
var oauthErr oauthErrorResponse
if err := dec.Decode(&oauthErr); err != nil {
return "", fmt.Errorf("failed to decode oauth error: %w", err)
}

return "", oauthErr
default:
return "", handleNon200(rsp, nil)
}
}
8 changes: 6 additions & 2 deletions catalog/rest/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ func WithCredential(cred string) Option {
}

func WithOAuthToken(token string) Option {
return WithAuthManager(&Oauth2AuthManager{Token: token})
}

func WithAuthManager(authManager AuthManager) Option {
return func(o *options) {
o.oauthToken = token
o.authManager = authManager
}
}

Expand Down Expand Up @@ -122,7 +126,7 @@ type options struct {
awsConfigSet bool
tlsConfig *tls.Config
credential string
oauthToken string
authManager AuthManager
warehouseLocation string
metadataLocation string
enableSigv4 bool
Expand Down
114 changes: 21 additions & 93 deletions catalog/rest/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,34 +160,6 @@ type createTableRequest struct {
Props iceberg.Properties `json:"properties,omitempty"`
}

type oauthTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}

type oauthErrorResponse struct {
Err string `json:"error"`
ErrDesc string `json:"error_description"`
ErrURI string `json:"error_uri"`
}

func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError }
func (o oauthErrorResponse) Error() string {
msg := o.Err
if o.ErrDesc != "" {
msg += ": " + o.ErrDesc
}

if o.ErrURI != "" {
msg += " (" + o.ErrURI + ")"
}

return msg
}

type configResponse struct {
Defaults iceberg.Properties `json:"defaults"`
Overrides iceberg.Properties `json:"overrides"`
Expand Down Expand Up @@ -380,8 +352,6 @@ func handleNon200(rsp *http.Response, override map[int]error) error {
func fromProps(props iceberg.Properties, o *options) {
for k, v := range props {
switch k {
case keyOauthToken:
o.oauthToken = v
case keyWarehouseLocation:
o.warehouseLocation = v
case keyMetadataLocation:
Expand Down Expand Up @@ -436,7 +406,6 @@ func toProps(o *options) iceberg.Properties {
}

setIf(keyOauthCredential, o.credential)
setIf(keyOauthToken, o.oauthToken)
setIf(keyWarehouseLocation, o.warehouseLocation)
setIf(keyMetadataLocation, o.metadataLocation)
if o.enableSigv4 {
Expand Down Expand Up @@ -510,59 +479,6 @@ func (r *Catalog) init(ctx context.Context, ops *options, uri string) error {
return nil
}

func (r *Catalog) fetchAccessToken(cl *http.Client, creds string, opts *options) (string, error) {
clientID, clientSecret, hasID := strings.Cut(creds, ":")
if !hasID {
clientID, clientSecret = "", clientID
}

scope := "catalog"
if opts.scope != "" {
scope = opts.scope
}
data := url.Values{
"grant_type": {"client_credentials"},
"client_id": {clientID},
"client_secret": {clientSecret},
"scope": {scope},
}

uri := opts.authUri
if uri == nil {
uri = r.baseURI.JoinPath("oauth/tokens")
}

rsp, err := cl.PostForm(uri.String(), data)
if err != nil {
return "", err
}

if rsp.StatusCode == http.StatusOK {
defer rsp.Body.Close()
dec := json.NewDecoder(rsp.Body)
var tok oauthTokenResponse
if err := dec.Decode(&tok); err != nil {
return "", fmt.Errorf("failed to decode oauth token response: %w", err)
}

return tok.AccessToken, nil
}

switch rsp.StatusCode {
case http.StatusUnauthorized, http.StatusBadRequest:
defer rsp.Request.GetBody()
dec := json.NewDecoder(rsp.Body)
var oauthErr oauthErrorResponse
if err := dec.Decode(&oauthErr); err != nil {
return "", fmt.Errorf("failed to decode oauth error: %w", err)
}

return "", oauthErr
default:
return "", handleNon200(rsp, nil)
}
}

func (r *Catalog) createSession(ctx context.Context, opts *options) (*http.Client, error) {
session := &sessionTransport{
defaultHeaders: http.Header{},
Expand All @@ -574,23 +490,35 @@ func (r *Catalog) createSession(ctx context.Context, opts *options) (*http.Clien
}
cl := &http.Client{Transport: session}

token := opts.oauthToken
if token == "" && opts.credential != "" {
var err error
if token, err = r.fetchAccessToken(cl, opts.credential, opts); err != nil {
return nil, fmt.Errorf("auth error: %w", err)
}
}
if opts.credential != "" {
if _, ok := opts.authManager.(*Oauth2AuthManager); !ok {
authURI := opts.authUri
if authURI == nil {
authURI = r.baseURI.JoinPath("oauth/tokens")
}

if token != "" {
session.defaultHeaders.Set(authorizationHeader, bearerPrefix+" "+token)
opts.authManager = &Oauth2AuthManager{
Credential: opts.credential,
AuthURI: authURI,
Scope: opts.scope,
Client: cl,
}
}
}

session.defaultHeaders.Set("X-Client-Version", icebergRestSpecVersion)
session.defaultHeaders.Set("Content-Type", "application/json")
session.defaultHeaders.Set("User-Agent", "GoIceberg/"+iceberg.Version())
session.defaultHeaders.Set("X-Iceberg-Access-Delegation", "vended-credentials")

if opts.authManager != nil {
k, v, err := opts.authManager.AuthHeader()
if err != nil {
return nil, err
}
session.defaultHeaders.Set(k, v)
}

if opts.enableSigv4 {
cfg := opts.awsConfig
if !opts.awsConfigSet {
Expand Down
Loading