From 456596e4ec63ef104df97fb2b2888d39ce68d937 Mon Sep 17 00:00:00 2001 From: multimo Date: Mon, 24 Nov 2025 11:29:14 +0100 Subject: [PATCH 1/7] Cloudfetch: Allow configuration of httpclient for cloudfetch --- connector.go | 7 ++ connector_test.go | 36 ++++++++ internal/config/config.go | 1 + internal/rows/arrowbased/batchloader.go | 14 ++- internal/rows/arrowbased/batchloader_test.go | 97 ++++++++++++++++++++ 5 files changed, 152 insertions(+), 3 deletions(-) diff --git a/connector.go b/connector.go index fce77970..74038681 100644 --- a/connector.go +++ b/connector.go @@ -270,6 +270,13 @@ func WithCloudFetch(useCloudFetch bool) ConnOption { } } +// WithCloudFetchHTTPClient allows a custom http client to be used for cloud fetch. Default is http.DefaultClient. +func WithCloudFetchHTTPClient(httpClient *http.Client) ConnOption { + return func(c *config.Config) { + c.UserConfig.CloudFetchConfig.HTTPClient = httpClient + } +} + // WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10. func WithMaxDownloadThreads(numThreads int) ConnOption { return func(c *config.Config) { diff --git a/connector_test.go b/connector_test.go index 57554b98..eadfc827 100644 --- a/connector_test.go +++ b/connector_test.go @@ -246,6 +246,42 @@ func TestNewConnector(t *testing.T) { require.True(t, ok) assert.False(t, coni.cfg.EnableMetricViewMetadata) }) + + t.Run("Connector test WithCloudFetchHTTPClient sets custom client", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + customClient := &http.Client{Timeout: 5 * time.Second} + + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + WithCloudFetchHTTPClient(customClient), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.Equal(t, customClient, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient) + }) + + t.Run("Connector test WithCloudFetchHTTPClient with nil client is accepted", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.Nil(t, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient) + }) } type mockRoundTripper struct{} diff --git a/internal/config/config.go b/internal/config/config.go index 67437a9c..e13cb98f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -479,6 +479,7 @@ type CloudFetchConfig struct { MaxFilesInMemory int MinTimeToExpiry time.Duration CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1) + HTTPClient *http.Client } func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig { diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index d26d8a4a..4d718b9e 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -40,6 +40,7 @@ func NewCloudIPCStreamIterator( startRowOffset: startRowOffset, pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), downloadTasks: NewQueue[cloudFetchDownloadTask](), + httpClient: cfg.UserConfig.CloudFetchConfig.HTTPClient, } for _, link := range files { @@ -140,6 +141,7 @@ type cloudIPCStreamIterator struct { startRowOffset int64 pendingLinks Queue[cli_service.TSparkArrowResultLink] downloadTasks Queue[cloudFetchDownloadTask] + httpClient *http.Client } var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil) @@ -162,6 +164,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { resultChan: make(chan cloudFetchDownloadTaskResult), minTimeToExpiry: bi.cfg.MinTimeToExpiry, speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps, + httpClient: bi.httpClient, } task.Run() bi.downloadTasks.Enqueue(task) @@ -210,6 +213,7 @@ type cloudFetchDownloadTask struct { link *cli_service.TSparkArrowResultLink resultChan chan cloudFetchDownloadTaskResult speedThresholdMbps float64 + httpClient *http.Client } func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) { @@ -252,7 +256,7 @@ func (cft *cloudFetchDownloadTask) Run() { cft.link.StartRowOffset, cft.link.RowCount, ) - data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps) + data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient) if err != nil { cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err} return @@ -300,6 +304,7 @@ func fetchBatchBytes( link *cli_service.TSparkArrowResultLink, minTimeToExpiry time.Duration, speedThresholdMbps float64, + httpClient *http.Client, ) (io.ReadCloser, error) { if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { return nil, errors.New(dbsqlerr.ErrLinkExpired) @@ -317,9 +322,12 @@ func fetchBatchBytes( } } + if httpClient == nil { + httpClient = http.DefaultClient + } + startTime := time.Now() - client := http.DefaultClient - res, err := client.Do(req) + res, err := httpClient.Do(req) if err != nil { return nil, err } diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index b018eb6d..c30e0e0b 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -253,6 +253,103 @@ func TestCloudFetchIterator(t *testing.T) { assert.NotNil(t, err3) assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound)) }) + + t.Run("should use custom HTTPClient when provided", func(t *testing.T) { + customClient := &http.Client{Timeout: 5 * time.Second} + requestCount := 0 + + handler = func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.UserConfig.CloudFetchConfig.HTTPClient = customClient + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + assert.Nil(t, err) + + // Verify custom client is passed through the iterator chain + wrapper, ok := bi.(*batchIterator) + assert.True(t, ok) + cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) + assert.True(t, ok) + assert.Equal(t, customClient, cbi.httpClient) + + // Fetch should work with custom client + sab1, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab1) + assert.Greater(t, requestCount, 0) // Verify request was made + }) + + t.Run("should use http.DefaultClient when HTTPClient is nil", func(t *testing.T) { + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + // HTTPClient is nil by default + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + assert.Nil(t, err) + + // Verify nil client is passed through + wrapper, ok := bi.(*batchIterator) + assert.True(t, ok) + cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) + assert.True(t, ok) + assert.Nil(t, cbi.httpClient) + + // Fetch should work (falls back to http.DefaultClient) + sab1, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab1) + }) } func generateArrowRecord() arrow.Record { From 026ac42dfaafffb54bf0477129b664532348823d Mon Sep 17 00:00:00 2001 From: multimo Date: Tue, 25 Nov 2025 14:13:02 +0100 Subject: [PATCH 2/7] rename withHttpClient --- connector.go | 4 ++-- connector_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/connector.go b/connector.go index 74038681..ed21b71f 100644 --- a/connector.go +++ b/connector.go @@ -270,8 +270,8 @@ func WithCloudFetch(useCloudFetch bool) ConnOption { } } -// WithCloudFetchHTTPClient allows a custom http client to be used for cloud fetch. Default is http.DefaultClient. -func WithCloudFetchHTTPClient(httpClient *http.Client) ConnOption { +// WithHTTPClient allows a custom http client to be used for cloud fetch. Default is http.DefaultClient. +func WithHTTPClient(httpClient *http.Client) ConnOption { return func(c *config.Config) { c.UserConfig.CloudFetchConfig.HTTPClient = httpClient } diff --git a/connector_test.go b/connector_test.go index eadfc827..a5e8632f 100644 --- a/connector_test.go +++ b/connector_test.go @@ -257,7 +257,7 @@ func TestNewConnector(t *testing.T) { WithServerHostname(host), WithAccessToken(accessToken), WithHTTPPath(httpPath), - WithCloudFetchHTTPClient(customClient), + WithHTTPClient(customClient), ) assert.Nil(t, err) From 8a47b9f859d61c955446675fa58ab9819335b118 Mon Sep 17 00:00:00 2001 From: multimo Date: Tue, 9 Dec 2025 15:16:16 +0100 Subject: [PATCH 3/7] Swap implementation to use existing transport setting --- connector.go | 7 ---- connector_test.go | 36 -------------------- internal/config/config.go | 2 +- internal/rows/arrowbased/batchloader.go | 19 ++++++----- internal/rows/arrowbased/batchloader_test.go | 23 +++++++------ 5 files changed, 25 insertions(+), 62 deletions(-) diff --git a/connector.go b/connector.go index ed21b71f..fce77970 100644 --- a/connector.go +++ b/connector.go @@ -270,13 +270,6 @@ func WithCloudFetch(useCloudFetch bool) ConnOption { } } -// WithHTTPClient allows a custom http client to be used for cloud fetch. Default is http.DefaultClient. -func WithHTTPClient(httpClient *http.Client) ConnOption { - return func(c *config.Config) { - c.UserConfig.CloudFetchConfig.HTTPClient = httpClient - } -} - // WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10. func WithMaxDownloadThreads(numThreads int) ConnOption { return func(c *config.Config) { diff --git a/connector_test.go b/connector_test.go index a5e8632f..57554b98 100644 --- a/connector_test.go +++ b/connector_test.go @@ -246,42 +246,6 @@ func TestNewConnector(t *testing.T) { require.True(t, ok) assert.False(t, coni.cfg.EnableMetricViewMetadata) }) - - t.Run("Connector test WithCloudFetchHTTPClient sets custom client", func(t *testing.T) { - host := "databricks-host" - accessToken := "token" - httpPath := "http-path" - customClient := &http.Client{Timeout: 5 * time.Second} - - con, err := NewConnector( - WithServerHostname(host), - WithAccessToken(accessToken), - WithHTTPPath(httpPath), - WithHTTPClient(customClient), - ) - assert.Nil(t, err) - - coni, ok := con.(*connector) - require.True(t, ok) - assert.Equal(t, customClient, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient) - }) - - t.Run("Connector test WithCloudFetchHTTPClient with nil client is accepted", func(t *testing.T) { - host := "databricks-host" - accessToken := "token" - httpPath := "http-path" - - con, err := NewConnector( - WithServerHostname(host), - WithAccessToken(accessToken), - WithHTTPPath(httpPath), - ) - assert.Nil(t, err) - - coni, ok := con.(*connector) - require.True(t, ok) - assert.Nil(t, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient) - }) } type mockRoundTripper struct{} diff --git a/internal/config/config.go b/internal/config/config.go index e13cb98f..6956ab37 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -479,7 +479,7 @@ type CloudFetchConfig struct { MaxFilesInMemory int MinTimeToExpiry time.Duration CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1) - HTTPClient *http.Client + Transport http.RoundTripper } func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig { diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 4d718b9e..545fe9c0 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -40,7 +40,7 @@ func NewCloudIPCStreamIterator( startRowOffset: startRowOffset, pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), downloadTasks: NewQueue[cloudFetchDownloadTask](), - httpClient: cfg.UserConfig.CloudFetchConfig.HTTPClient, + transport: cfg.UserConfig.Transport, } for _, link := range files { @@ -141,7 +141,7 @@ type cloudIPCStreamIterator struct { startRowOffset int64 pendingLinks Queue[cli_service.TSparkArrowResultLink] downloadTasks Queue[cloudFetchDownloadTask] - httpClient *http.Client + transport http.RoundTripper } var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil) @@ -164,7 +164,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { resultChan: make(chan cloudFetchDownloadTaskResult), minTimeToExpiry: bi.cfg.MinTimeToExpiry, speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps, - httpClient: bi.httpClient, + transport: bi.transport, } task.Run() bi.downloadTasks.Enqueue(task) @@ -213,7 +213,7 @@ type cloudFetchDownloadTask struct { link *cli_service.TSparkArrowResultLink resultChan chan cloudFetchDownloadTaskResult speedThresholdMbps float64 - httpClient *http.Client + transport http.RoundTripper } func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) { @@ -256,7 +256,7 @@ func (cft *cloudFetchDownloadTask) Run() { cft.link.StartRowOffset, cft.link.RowCount, ) - data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient) + data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.transport) if err != nil { cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err} return @@ -304,7 +304,7 @@ func fetchBatchBytes( link *cli_service.TSparkArrowResultLink, minTimeToExpiry time.Duration, speedThresholdMbps float64, - httpClient *http.Client, + transport http.RoundTripper, ) (io.ReadCloser, error) { if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { return nil, errors.New(dbsqlerr.ErrLinkExpired) @@ -322,8 +322,11 @@ func fetchBatchBytes( } } - if httpClient == nil { - httpClient = http.DefaultClient + httpClient := http.DefaultClient + if transport != nil { + httpClient = &http.Client{ + Transport: transport, + } } startTime := time.Now() diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index c30e0e0b..e5fb12bc 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -254,8 +254,11 @@ func TestCloudFetchIterator(t *testing.T) { assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound)) }) - t.Run("should use custom HTTPClient when provided", func(t *testing.T) { - customClient := &http.Client{Timeout: 5 * time.Second} + t.Run("should use custom Transport when provided", func(t *testing.T) { + customTransport := &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 5, + } requestCount := 0 handler = func(w http.ResponseWriter, r *http.Request) { @@ -281,7 +284,7 @@ func TestCloudFetchIterator(t *testing.T) { cfg := config.WithDefaults() cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 - cfg.UserConfig.CloudFetchConfig.HTTPClient = customClient + cfg.UserConfig.Transport = customTransport bi, err := NewCloudBatchIterator( context.Background(), @@ -291,21 +294,21 @@ func TestCloudFetchIterator(t *testing.T) { ) assert.Nil(t, err) - // Verify custom client is passed through the iterator chain + // Verify custom transport is passed through the iterator chain wrapper, ok := bi.(*batchIterator) assert.True(t, ok) cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) assert.True(t, ok) - assert.Equal(t, customClient, cbi.httpClient) + assert.Equal(t, customTransport, cbi.transport) - // Fetch should work with custom client + // Fetch should work with custom transport sab1, nextErr := bi.Next() assert.Nil(t, nextErr) assert.NotNil(t, sab1) assert.Greater(t, requestCount, 0) // Verify request was made }) - t.Run("should use http.DefaultClient when HTTPClient is nil", func(t *testing.T) { + t.Run("should use http.DefaultClient when Transport is nil", func(t *testing.T) { handler = func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) @@ -328,7 +331,7 @@ func TestCloudFetchIterator(t *testing.T) { cfg := config.WithDefaults() cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 - // HTTPClient is nil by default + // Transport is nil by default bi, err := NewCloudBatchIterator( context.Background(), @@ -338,12 +341,12 @@ func TestCloudFetchIterator(t *testing.T) { ) assert.Nil(t, err) - // Verify nil client is passed through + // Verify nil transport is passed through wrapper, ok := bi.(*batchIterator) assert.True(t, ok) cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) assert.True(t, ok) - assert.Nil(t, cbi.httpClient) + assert.Nil(t, cbi.transport) // Fetch should work (falls back to http.DefaultClient) sab1, nextErr := bi.Next() From 934d2b50b4efc9b3f97c2d6d5a64ef11e32ee868 Mon Sep 17 00:00:00 2001 From: multimo Date: Tue, 9 Dec 2025 15:29:25 +0100 Subject: [PATCH 4/7] simplify tests --- internal/rows/arrowbased/batchloader_test.go | 79 ++++++-------------- 1 file changed, 24 insertions(+), 55 deletions(-) diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index e5fb12bc..70636fa8 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -255,31 +255,13 @@ func TestCloudFetchIterator(t *testing.T) { }) t.Run("should use custom Transport when provided", func(t *testing.T) { - customTransport := &http.Transport{ - MaxIdleConns: 10, - MaxIdleConnsPerHost: 5, - } - requestCount := 0 - handler = func(w http.ResponseWriter, r *http.Request) { - requestCount++ w.WriteHeader(http.StatusOK) - _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) - if err != nil { - panic(err) - } + w.Write(generateMockArrowBytes(generateArrowRecord())) } startRowOffset := int64(100) - - links := []*cli_service.TSparkArrowResultLink{ - { - FileLink: server.URL, - ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), - StartRowOffset: startRowOffset, - RowCount: 1, - }, - } + customTransport := &http.Transport{MaxIdleConns: 10} cfg := config.WithDefaults() cfg.UseLz4Compression = false @@ -288,70 +270,57 @@ func TestCloudFetchIterator(t *testing.T) { bi, err := NewCloudBatchIterator( context.Background(), - links, + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, startRowOffset, cfg, ) assert.Nil(t, err) - // Verify custom transport is passed through the iterator chain - wrapper, ok := bi.(*batchIterator) - assert.True(t, ok) - cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) - assert.True(t, ok) + cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator) assert.Equal(t, customTransport, cbi.transport) - // Fetch should work with custom transport - sab1, nextErr := bi.Next() + // Verify fetch works + sab, nextErr := bi.Next() assert.Nil(t, nextErr) - assert.NotNil(t, sab1) - assert.Greater(t, requestCount, 0) // Verify request was made + assert.NotNil(t, sab) }) - t.Run("should use http.DefaultClient when Transport is nil", func(t *testing.T) { + t.Run("should fallback to http.DefaultClient when Transport is nil", func(t *testing.T) { handler = func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) - if err != nil { - panic(err) - } + w.Write(generateMockArrowBytes(generateArrowRecord())) } startRowOffset := int64(100) - - links := []*cli_service.TSparkArrowResultLink{ - { - FileLink: server.URL, - ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), - StartRowOffset: startRowOffset, - RowCount: 1, - }, - } - cfg := config.WithDefaults() cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 - // Transport is nil by default bi, err := NewCloudBatchIterator( context.Background(), - links, + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, startRowOffset, cfg, ) assert.Nil(t, err) - // Verify nil transport is passed through - wrapper, ok := bi.(*batchIterator) - assert.True(t, ok) - cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) - assert.True(t, ok) + cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator) assert.Nil(t, cbi.transport) - // Fetch should work (falls back to http.DefaultClient) - sab1, nextErr := bi.Next() + // Verify fetch works with default client + sab, nextErr := bi.Next() assert.Nil(t, nextErr) - assert.NotNil(t, sab1) + assert.NotNil(t, sab) }) } From 510c08bada03b77059a86549ffb79573cd91037e Mon Sep 17 00:00:00 2001 From: multimo Date: Tue, 16 Dec 2025 16:50:11 +0100 Subject: [PATCH 5/7] move the lifetime of the client to the IPCStreamIterator --- internal/rows/arrowbased/batchloader.go | 27 ++++++++++---------- internal/rows/arrowbased/batchloader_test.go | 5 ++-- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 545fe9c0..011db316 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -34,13 +34,21 @@ func NewCloudIPCStreamIterator( startRowOffset int64, cfg *config.Config, ) (IPCStreamIterator, dbsqlerr.DBError) { + transport := cfg.UserConfig.Transport + httpClient := http.DefaultClient + if transport != nil { + httpClient = &http.Client{ + Transport: transport, + } + } + bi := &cloudIPCStreamIterator{ ctx: ctx, cfg: cfg, startRowOffset: startRowOffset, pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), downloadTasks: NewQueue[cloudFetchDownloadTask](), - transport: cfg.UserConfig.Transport, + httpClient: httpClient, } for _, link := range files { @@ -141,7 +149,7 @@ type cloudIPCStreamIterator struct { startRowOffset int64 pendingLinks Queue[cli_service.TSparkArrowResultLink] downloadTasks Queue[cloudFetchDownloadTask] - transport http.RoundTripper + httpClient *http.Client } var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil) @@ -164,7 +172,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { resultChan: make(chan cloudFetchDownloadTaskResult), minTimeToExpiry: bi.cfg.MinTimeToExpiry, speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps, - transport: bi.transport, + httpClient: bi.httpClient, } task.Run() bi.downloadTasks.Enqueue(task) @@ -213,7 +221,7 @@ type cloudFetchDownloadTask struct { link *cli_service.TSparkArrowResultLink resultChan chan cloudFetchDownloadTaskResult speedThresholdMbps float64 - transport http.RoundTripper + httpClient *http.Client } func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) { @@ -256,7 +264,7 @@ func (cft *cloudFetchDownloadTask) Run() { cft.link.StartRowOffset, cft.link.RowCount, ) - data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.transport) + data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient) if err != nil { cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err} return @@ -304,7 +312,7 @@ func fetchBatchBytes( link *cli_service.TSparkArrowResultLink, minTimeToExpiry time.Duration, speedThresholdMbps float64, - transport http.RoundTripper, + httpClient *http.Client, ) (io.ReadCloser, error) { if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { return nil, errors.New(dbsqlerr.ErrLinkExpired) @@ -322,13 +330,6 @@ func fetchBatchBytes( } } - httpClient := http.DefaultClient - if transport != nil { - httpClient = &http.Client{ - Transport: transport, - } - } - startTime := time.Now() res, err := httpClient.Do(req) if err != nil { diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index 70636fa8..80da8206 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -282,7 +282,8 @@ func TestCloudFetchIterator(t *testing.T) { assert.Nil(t, err) cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator) - assert.Equal(t, customTransport, cbi.transport) + assert.NotNil(t, cbi.httpClient) + assert.Equal(t, customTransport, cbi.httpClient.Transport) // Verify fetch works sab, nextErr := bi.Next() @@ -315,7 +316,7 @@ func TestCloudFetchIterator(t *testing.T) { assert.Nil(t, err) cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator) - assert.Nil(t, cbi.transport) + assert.Equal(t, http.DefaultClient, cbi.httpClient) // Verify fetch works with default client sab, nextErr := bi.Next() From 1d447233ad415fd9ff39ccfdaa991f84e9d45287 Mon Sep 17 00:00:00 2001 From: multimo Date: Tue, 16 Dec 2025 17:05:17 +0100 Subject: [PATCH 6/7] move the http client instance to the cloudfetch config --- connector.go | 6 ++++++ connector_test.go | 20 ++++++++++++++++++++ internal/config/config.go | 2 +- internal/rows/arrowbased/batchloader.go | 7 ++----- internal/rows/arrowbased/batchloader_test.go | 15 +++++++++------ 5 files changed, 38 insertions(+), 12 deletions(-) diff --git a/connector.go b/connector.go index fce77970..dca595c2 100644 --- a/connector.go +++ b/connector.go @@ -260,6 +260,12 @@ func WithAuthenticator(authr auth.Authenticator) ConnOption { func WithTransport(t http.RoundTripper) ConnOption { return func(c *config.Config) { c.Transport = t + + if c.CloudFetchConfig.HTTPClient == nil { + c.CloudFetchConfig.HTTPClient = &http.Client{ + Transport: t, + } + } } } diff --git a/connector_test.go b/connector_test.go index 57554b98..bba5db1f 100644 --- a/connector_test.go +++ b/connector_test.go @@ -48,6 +48,7 @@ func TestNewConnector(t *testing.T) { MaxFilesInMemory: 10, MinTimeToExpiry: 0 * time.Second, CloudFetchSpeedThresholdMbps: 0.1, + HTTPClient: &http.Client{Transport: roundTripper}, } expectedUserConfig := config.UserConfig{ Host: host, @@ -246,6 +247,25 @@ func TestNewConnector(t *testing.T) { require.True(t, ok) assert.False(t, coni.cfg.EnableMetricViewMetadata) }) + + t.Run("Connector test WithTransport sets HTTPClient in CloudFetchConfig", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + customTransport := &http.Transport{MaxIdleConns: 10} + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + WithTransport(customTransport), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.NotNil(t, coni.cfg.CloudFetchConfig.HTTPClient) + assert.Equal(t, customTransport, coni.cfg.CloudFetchConfig.HTTPClient.Transport) + }) } type mockRoundTripper struct{} diff --git a/internal/config/config.go b/internal/config/config.go index 6956ab37..e13cb98f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -479,7 +479,7 @@ type CloudFetchConfig struct { MaxFilesInMemory int MinTimeToExpiry time.Duration CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1) - Transport http.RoundTripper + HTTPClient *http.Client } func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig { diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 011db316..e12ea4e6 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -34,12 +34,9 @@ func NewCloudIPCStreamIterator( startRowOffset int64, cfg *config.Config, ) (IPCStreamIterator, dbsqlerr.DBError) { - transport := cfg.UserConfig.Transport httpClient := http.DefaultClient - if transport != nil { - httpClient = &http.Client{ - Transport: transport, - } + if cfg.UserConfig.CloudFetchConfig.HTTPClient != nil { + httpClient = cfg.UserConfig.CloudFetchConfig.HTTPClient } bi := &cloudIPCStreamIterator{ diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index 80da8206..ef7823a0 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -254,19 +254,21 @@ func TestCloudFetchIterator(t *testing.T) { assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound)) }) - t.Run("should use custom Transport when provided", func(t *testing.T) { + t.Run("should use custom HTTPClient when provided", func(t *testing.T) { handler = func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(generateMockArrowBytes(generateArrowRecord())) } startRowOffset := int64(100) - customTransport := &http.Transport{MaxIdleConns: 10} + customHTTPClient := &http.Client{ + Transport: &http.Transport{MaxIdleConns: 10}, + } cfg := config.WithDefaults() cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 - cfg.UserConfig.Transport = customTransport + cfg.UserConfig.CloudFetchConfig.HTTPClient = customHTTPClient bi, err := NewCloudBatchIterator( context.Background(), @@ -282,8 +284,7 @@ func TestCloudFetchIterator(t *testing.T) { assert.Nil(t, err) cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator) - assert.NotNil(t, cbi.httpClient) - assert.Equal(t, customTransport, cbi.httpClient.Transport) + assert.Equal(t, customHTTPClient, cbi.httpClient) // Verify fetch works sab, nextErr := bi.Next() @@ -291,7 +292,7 @@ func TestCloudFetchIterator(t *testing.T) { assert.NotNil(t, sab) }) - t.Run("should fallback to http.DefaultClient when Transport is nil", func(t *testing.T) { + t.Run("should fallback to http.DefaultClient when HTTPClient is nil", func(t *testing.T) { handler = func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(generateMockArrowBytes(generateArrowRecord())) @@ -301,6 +302,8 @@ func TestCloudFetchIterator(t *testing.T) { cfg := config.WithDefaults() cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 + // Explicitly set HTTPClient to nil to verify fallback behavior + cfg.UserConfig.CloudFetchConfig.HTTPClient = nil bi, err := NewCloudBatchIterator( context.Background(), From 492c0c544b27ba70ca7f0a9116cd9346f85d1bef Mon Sep 17 00:00:00 2001 From: multimo Date: Wed, 17 Dec 2025 10:46:02 +0100 Subject: [PATCH 7/7] lint fix --- internal/rows/arrowbased/batchloader_test.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index ef7823a0..99538bbc 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -257,7 +257,10 @@ func TestCloudFetchIterator(t *testing.T) { t.Run("should use custom HTTPClient when provided", func(t *testing.T) { handler = func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write(generateMockArrowBytes(generateArrowRecord())) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } } startRowOffset := int64(100) @@ -295,7 +298,10 @@ func TestCloudFetchIterator(t *testing.T) { t.Run("should fallback to http.DefaultClient when HTTPClient is nil", func(t *testing.T) { handler = func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write(generateMockArrowBytes(generateArrowRecord())) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } } startRowOffset := int64(100)