6 changes: 6 additions & 0 deletions connector.go
@@ -260,6 +260,12 @@ func WithAuthenticator(authr auth.Authenticator) ConnOption {
func WithTransport(t http.RoundTripper) ConnOption {
return func(c *config.Config) {
c.Transport = t

if c.CloudFetchConfig.HTTPClient == nil {
c.CloudFetchConfig.HTTPClient = &http.Client{
Transport: t,
}
}
}
}
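
For context, a minimal usage sketch of the new behavior; the hostname, token, and HTTP path values are placeholders, not taken from this change. A transport passed to WithTransport now also backs cloud fetch downloads, unless a dedicated CloudFetch HTTP client was already configured.

package main

import (
	"database/sql"
	"log"
	"net/http"

	dbsql "github.com/databricks/databricks-sql-go"
)

func main() {
	// Caller-supplied transport; settings are illustrative only.
	customTransport := &http.Transport{MaxIdleConns: 10}

	// With this change, the same transport is also used for cloud fetch
	// downloads when no explicit CloudFetch HTTP client has been set.
	connector, err := dbsql.NewConnector(
		dbsql.WithServerHostname("example-host"),      // placeholder
		dbsql.WithAccessToken("example-token"),        // placeholder
		dbsql.WithHTTPPath("/sql/1.0/warehouses/abc"), // placeholder
		dbsql.WithTransport(customTransport),
	)
	if err != nil {
		log.Fatal(err)
	}

	db := sql.OpenDB(connector)
	defer db.Close()
}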

20 changes: 20 additions & 0 deletions connector_test.go
@@ -48,6 +48,7 @@ func TestNewConnector(t *testing.T) {
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
CloudFetchSpeedThresholdMbps: 0.1,
HTTPClient: &http.Client{Transport: roundTripper},
}
expectedUserConfig := config.UserConfig{
Host: host,
@@ -246,6 +247,25 @@ func TestNewConnector(t *testing.T) {
require.True(t, ok)
assert.False(t, coni.cfg.EnableMetricViewMetadata)
})

t.Run("Connector test WithTransport sets HTTPClient in CloudFetchConfig", func(t *testing.T) {
host := "databricks-host"
accessToken := "token"
httpPath := "http-path"
customTransport := &http.Transport{MaxIdleConns: 10}
con, err := NewConnector(
WithServerHostname(host),
WithAccessToken(accessToken),
WithHTTPPath(httpPath),
WithTransport(customTransport),
)
assert.Nil(t, err)

coni, ok := con.(*connector)
require.True(t, ok)
assert.NotNil(t, coni.cfg.CloudFetchConfig.HTTPClient)
assert.Equal(t, customTransport, coni.cfg.CloudFetchConfig.HTTPClient.Transport)
})
}

type mockRoundTripper struct{}
1 change: 1 addition & 0 deletions internal/config/config.go
@@ -479,6 +479,7 @@ type CloudFetchConfig struct {
MaxFilesInMemory int
MinTimeToExpiry time.Duration
CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1)
HTTPClient *http.Client
}

func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {
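
For illustration, a sketch of the kind of client a caller might want cloud fetch downloads to use; the proxy address and timeout values are placeholders. Inside the driver this field is reached through the embedded UserConfig (as the tests below do), while external callers set it indirectly via WithTransport.

package config_test

import (
	"net/http"
	"net/url"
	"testing"
	"time"

	"github.com/databricks/databricks-sql-go/internal/config"
)

// Illustrative only: the proxy address and timeouts are placeholders.
func TestCloudFetchHTTPClientSketch(t *testing.T) {
	proxyURL, err := url.Parse("http://proxy.example:8080") // placeholder proxy
	if err != nil {
		t.Fatal(err)
	}

	cloudFetchClient := &http.Client{
		Timeout: 5 * time.Minute, // generous timeout for large result files
		Transport: &http.Transport{
			Proxy:           http.ProxyURL(proxyURL),
			MaxIdleConns:    64,
			IdleConnTimeout: 90 * time.Second,
		},
	}

	cfg := config.WithDefaults()
	cfg.CloudFetchConfig.HTTPClient = cloudFetchClient

	if cfg.CloudFetchConfig.HTTPClient == nil {
		t.Fatal("expected HTTPClient to be set")
	}
}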
15 changes: 12 additions & 3 deletions internal/rows/arrowbased/batchloader.go
@@ -34,12 +34,18 @@ func NewCloudIPCStreamIterator(
startRowOffset int64,
cfg *config.Config,
) (IPCStreamIterator, dbsqlerr.DBError) {
httpClient := http.DefaultClient
if cfg.UserConfig.CloudFetchConfig.HTTPClient != nil {
httpClient = cfg.UserConfig.CloudFetchConfig.HTTPClient
}

bi := &cloudIPCStreamIterator{
ctx: ctx,
cfg: cfg,
startRowOffset: startRowOffset,
pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](),
downloadTasks: NewQueue[cloudFetchDownloadTask](),
httpClient: httpClient,
}

for _, link := range files {
@@ -140,6 +146,7 @@ type cloudIPCStreamIterator struct {
startRowOffset int64
pendingLinks Queue[cli_service.TSparkArrowResultLink]
downloadTasks Queue[cloudFetchDownloadTask]
httpClient *http.Client
}

var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
@@ -162,6 +169,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
resultChan: make(chan cloudFetchDownloadTaskResult),
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
httpClient: bi.httpClient,
}
task.Run()
bi.downloadTasks.Enqueue(task)
@@ -210,6 +218,7 @@ type cloudFetchDownloadTask struct {
link *cli_service.TSparkArrowResultLink
resultChan chan cloudFetchDownloadTaskResult
speedThresholdMbps float64
httpClient *http.Client
}

func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) {
@@ -252,7 +261,7 @@ func (cft *cloudFetchDownloadTask) Run() {
cft.link.StartRowOffset,
cft.link.RowCount,
)
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps)
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient)
if err != nil {
cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
return
@@ -300,6 +309,7 @@ func fetchBatchBytes(
link *cli_service.TSparkArrowResultLink,
minTimeToExpiry time.Duration,
speedThresholdMbps float64,
httpClient *http.Client,
) (io.ReadCloser, error) {
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
return nil, errors.New(dbsqlerr.ErrLinkExpired)
@@ -318,8 +328,7 @@ }
}

startTime := time.Now()
client := http.DefaultClient
res, err := client.Do(req)
res, err := httpClient.Do(req)
if err != nil {
return nil, err
}
79 changes: 79 additions & 0 deletions internal/rows/arrowbased/batchloader_test.go
@@ -253,6 +253,85 @@ func TestCloudFetchIterator(t *testing.T) {
assert.NotNil(t, err3)
assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound))
})

t.Run("should use custom HTTPClient when provided", func(t *testing.T) {
handler = func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
if err != nil {
panic(err)
}
}

startRowOffset := int64(100)
customHTTPClient := &http.Client{
Transport: &http.Transport{MaxIdleConns: 10},
}

cfg := config.WithDefaults()
cfg.UseLz4Compression = false
cfg.MaxDownloadThreads = 1
cfg.UserConfig.CloudFetchConfig.HTTPClient = customHTTPClient

bi, err := NewCloudBatchIterator(
context.Background(),
[]*cli_service.TSparkArrowResultLink{{
FileLink: server.URL,
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
StartRowOffset: startRowOffset,
RowCount: 1,
}},
startRowOffset,
cfg,
)
assert.Nil(t, err)

cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator)
assert.Equal(t, customHTTPClient, cbi.httpClient)

// Verify fetch works
sab, nextErr := bi.Next()
assert.Nil(t, nextErr)
assert.NotNil(t, sab)
})

t.Run("should fallback to http.DefaultClient when HTTPClient is nil", func(t *testing.T) {
handler = func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
if err != nil {
panic(err)
}
}

startRowOffset := int64(100)
cfg := config.WithDefaults()
cfg.UseLz4Compression = false
cfg.MaxDownloadThreads = 1
// Explicitly set HTTPClient to nil to verify fallback behavior
cfg.UserConfig.CloudFetchConfig.HTTPClient = nil

bi, err := NewCloudBatchIterator(
context.Background(),
[]*cli_service.TSparkArrowResultLink{{
FileLink: server.URL,
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
StartRowOffset: startRowOffset,
RowCount: 1,
}},
startRowOffset,
cfg,
)
assert.Nil(t, err)

cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator)
assert.Equal(t, http.DefaultClient, cbi.httpClient)

// Verify fetch works with default client
sab, nextErr := bi.Next()
assert.Nil(t, nextErr)
assert.NotNil(t, sab)
})
}

func generateArrowRecord() arrow.Record {