diff --git a/connector.go b/connector.go index fce77970..dca595c2 100644 --- a/connector.go +++ b/connector.go @@ -260,6 +260,12 @@ func WithAuthenticator(authr auth.Authenticator) ConnOption { func WithTransport(t http.RoundTripper) ConnOption { return func(c *config.Config) { c.Transport = t + + if c.CloudFetchConfig.HTTPClient == nil { + c.CloudFetchConfig.HTTPClient = &http.Client{ + Transport: t, + } + } } } diff --git a/connector_test.go b/connector_test.go index 57554b98..bba5db1f 100644 --- a/connector_test.go +++ b/connector_test.go @@ -48,6 +48,7 @@ func TestNewConnector(t *testing.T) { MaxFilesInMemory: 10, MinTimeToExpiry: 0 * time.Second, CloudFetchSpeedThresholdMbps: 0.1, + HTTPClient: &http.Client{Transport: roundTripper}, } expectedUserConfig := config.UserConfig{ Host: host, @@ -246,6 +247,25 @@ func TestNewConnector(t *testing.T) { require.True(t, ok) assert.False(t, coni.cfg.EnableMetricViewMetadata) }) + + t.Run("Connector test WithTransport sets HTTPClient in CloudFetchConfig", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + customTransport := &http.Transport{MaxIdleConns: 10} + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + WithTransport(customTransport), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.NotNil(t, coni.cfg.CloudFetchConfig.HTTPClient) + assert.Equal(t, customTransport, coni.cfg.CloudFetchConfig.HTTPClient.Transport) + }) } type mockRoundTripper struct{} diff --git a/internal/config/config.go b/internal/config/config.go index 67437a9c..e13cb98f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -479,6 +479,7 @@ type CloudFetchConfig struct { MaxFilesInMemory int MinTimeToExpiry time.Duration CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1) + HTTPClient *http.Client } func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig { diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index d26d8a4a..e12ea4e6 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -34,12 +34,18 @@ func NewCloudIPCStreamIterator( startRowOffset int64, cfg *config.Config, ) (IPCStreamIterator, dbsqlerr.DBError) { + httpClient := http.DefaultClient + if cfg.UserConfig.CloudFetchConfig.HTTPClient != nil { + httpClient = cfg.UserConfig.CloudFetchConfig.HTTPClient + } + bi := &cloudIPCStreamIterator{ ctx: ctx, cfg: cfg, startRowOffset: startRowOffset, pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), downloadTasks: NewQueue[cloudFetchDownloadTask](), + httpClient: httpClient, } for _, link := range files { @@ -140,6 +146,7 @@ type cloudIPCStreamIterator struct { startRowOffset int64 pendingLinks Queue[cli_service.TSparkArrowResultLink] downloadTasks Queue[cloudFetchDownloadTask] + httpClient *http.Client } var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil) @@ -162,6 +169,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { resultChan: make(chan cloudFetchDownloadTaskResult), minTimeToExpiry: bi.cfg.MinTimeToExpiry, speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps, + httpClient: bi.httpClient, } task.Run() bi.downloadTasks.Enqueue(task) @@ -210,6 +218,7 @@ type cloudFetchDownloadTask struct { link *cli_service.TSparkArrowResultLink resultChan chan cloudFetchDownloadTaskResult speedThresholdMbps float64 + httpClient *http.Client } func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) { @@ -252,7 +261,7 @@ func (cft *cloudFetchDownloadTask) Run() { cft.link.StartRowOffset, cft.link.RowCount, ) - data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps) + data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient) if err != nil { cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err} return @@ -300,6 +309,7 @@ func fetchBatchBytes( link *cli_service.TSparkArrowResultLink, minTimeToExpiry time.Duration, speedThresholdMbps float64, + httpClient *http.Client, ) (io.ReadCloser, error) { if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { return nil, errors.New(dbsqlerr.ErrLinkExpired) @@ -318,8 +328,7 @@ func fetchBatchBytes( } startTime := time.Now() - client := http.DefaultClient - res, err := client.Do(req) + res, err := httpClient.Do(req) if err != nil { return nil, err } diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index b018eb6d..99538bbc 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -253,6 +253,85 @@ func TestCloudFetchIterator(t *testing.T) { assert.NotNil(t, err3) assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound)) }) + + t.Run("should use custom HTTPClient when provided", func(t *testing.T) { + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + customHTTPClient := &http.Client{ + Transport: &http.Transport{MaxIdleConns: 10}, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.UserConfig.CloudFetchConfig.HTTPClient = customHTTPClient + + bi, err := NewCloudBatchIterator( + context.Background(), + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + cfg, + ) + assert.Nil(t, err) + + cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator) + assert.Equal(t, customHTTPClient, cbi.httpClient) + + // Verify fetch works + sab, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab) + }) + + t.Run("should fallback to http.DefaultClient when HTTPClient is nil", func(t *testing.T) { + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + // Explicitly set HTTPClient to nil to verify fallback behavior + cfg.UserConfig.CloudFetchConfig.HTTPClient = nil + + bi, err := NewCloudBatchIterator( + context.Background(), + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + cfg, + ) + assert.Nil(t, err) + + cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator) + assert.Equal(t, http.DefaultClient, cbi.httpClient) + + // Verify fetch works with default client + sab, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab) + }) } func generateArrowRecord() arrow.Record {