From e96f95891a24482e52184b2a51839766661e9c5b Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Fri, 18 Nov 2022 16:25:34 -0800 Subject: [PATCH 1/7] checkpoint Signed-off-by: Andre Furlan --- db.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ driver.go | 38 -------------------------------------- 2 files changed, 49 insertions(+), 38 deletions(-) create mode 100644 db.go diff --git a/db.go b/db.go new file mode 100644 index 00000000..9d37052c --- /dev/null +++ b/db.go @@ -0,0 +1,49 @@ +package dbsql + +import ( + "context" + "database/sql" + "database/sql/driver" +) + +type DatabricksDB interface { + QueryContextAsync(ctx context.Context, query string, args ...any) (rows *sql.Rows, queryId string, err error) +} + +type databricksDB struct { + db *sql.DB +} + +func OpenDB(c driver.Connector) DatabricksDB { + db := sql.OpenDB(c) + return &databricksDB{db} +} + +func (db *databricksDB) QueryContextAsync(ctx context.Context, query string, args ...any) (rows *sql.Rows, queryId string, err error) { + return nil, "", nil +} + +func (db *databricksDB) ExecContextAsync(ctx context.Context, query string, args ...any) (result sql.Result, queryId string) { + //go do something + return nil, "" +} + +func (db *databricksDB) CancelQuery(ctx context.Context, queryId string) error { + //go do something + return nil +} + +func (db *databricksDB) GetQueryStatus(ctx context.Context, queryId string) error { + //go do something + return nil +} + +func (db *databricksDB) FetchRows(ctx context.Context, queryId string) (rows *sql.Rows, err error) { + //go do something + return nil, nil +} + +func (db *databricksDB) FetchResult(ctx context.Context, queryId string) (rows sql.Result, err error) { + //go do something + return nil, nil +} diff --git a/driver.go b/driver.go index 80e54ef8..68d4eead 100644 --- a/driver.go +++ b/driver.go @@ -40,41 +40,3 @@ func (d *databricksDriver) OpenConnector(dsn string) (driver.Connector, error) { var _ driver.Driver = (*databricksDriver)(nil) var _ 
driver.DriverContext = (*databricksDriver)(nil) - -// type databricksDB struct { -// *sql.DB -// } - -// func OpenDB(c driver.Connector) *databricksDB { -// db := sql.OpenDB(c) -// return &databricksDB{db} -// } - -// func (db *databricksDB) QueryContextAsync(ctx context.Context, query string, args ...any) (rows *sql.Rows, queryId string, err error) { -// return nil, "", nil -// } - -// func (db *databricksDB) ExecContextAsync(ctx context.Context, query string, args ...any) (result sql.Result, queryId string) { -// //go do something -// return nil, "" -// } - -// func (db *databricksDB) CancelQuery(ctx context.Context, queryId string) error { -// //go do something -// return nil -// } - -// func (db *databricksDB) GetQueryStatus(ctx context.Context, queryId string) error { -// //go do something -// return nil -// } - -// func (db *databricksDB) FetchRows(ctx context.Context, queryId string) (rows *sql.Rows, err error) { -// //go do something -// return nil, nil -// } - -// func (db *databricksDB) FetchResult(ctx context.Context, queryId string) (rows sql.Result, err error) { -// //go do something -// return nil, nil -// } From 8a441d331b59af730662656e308b6aa5e4138b0d Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Mon, 21 Nov 2022 18:19:58 -0800 Subject: [PATCH 2/7] checkpoint Signed-off-by: Andre Furlan --- examples/asyncWorkflow/main.go | 89 ++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 examples/asyncWorkflow/main.go diff --git a/examples/asyncWorkflow/main.go b/examples/asyncWorkflow/main.go new file mode 100644 index 00000000..76b1e1cc --- /dev/null +++ b/examples/asyncWorkflow/main.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "fmt" + "os" + "strconv" + "time" + + dbsql "github.com/databricks/databricks-sql-go" + dbsqlctx "github.com/databricks/databricks-sql-go/driverctx" + dbsqllog "github.com/databricks/databricks-sql-go/logger" + "github.com/joho/godotenv" +) + +func main() { + // use this package to 
set up logging. By default logging level is `warn`. If you want to disable logging, use `disabled` + if err := dbsqllog.SetLogLevel("debug"); err != nil { + panic(err) + } + // sets the logging output. By default it will use os.Stderr. If running in terminal, it will use ConsoleWriter to make it pretty + // dbsqllog.SetLogOutput(os.Stdout) + + // this is just to make it easy to load all variables + if err := godotenv.Load(); err != nil { + panic(err) + } + port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT")) + if err != nil { + panic(err) + } + + // programmatically initializes the connector + // another way is to use a DSN. In this case the equivalent DSN would be: + // "token:@hostname:port/http_path?catalog=hive_metastore&schema=default&timeout=60&maxRows=10&timezone=America/Sao_Paulo&ANSI_MODE=true" + connector, err := dbsql.NewConnector( + // minimum configuration + dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")), + dbsql.WithPort(port), + dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")), + dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")), + //optional configuration + dbsql.WithSessionParams(map[string]string{"timezone": "America/Sao_Paulo", "ansi_mode": "true"}), + dbsql.WithUserAgentEntry("workflow-example"), + dbsql.WithInitialNamespace("hive_metastore", "default"), + dbsql.WithTimeout(time.Minute), // defaults to no timeout. Global timeout. Any query will be canceled if taking more than this time. + dbsql.WithMaxRows(10), // defaults to 10000 + ) + if err != nil { + // This will not be a connection error, but a DSN parse error or + // another initialization error. + panic(err) + + } + // Opening a driver typically will not attempt to connect to the database. 
+ db := dbsql.OpenDB(connector) + // make sure to close it later + defer db.Close() + + // the "github.com/databricks/databricks-sql-go/driverctx" has some functions to help set the context for the driver + ogCtx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "workflow-example") + + for _, v := range []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"} { + i := v + // go func() { + _, exec, err := db.QueryContext(ogCtx, fmt.Sprintf("select %s", i)) + if err != nil { + panic(err) + } + rs, err := db.GetExecutionResult(ogCtx, exec) + if err != nil { + panic(err) + } + fmt.Println(rs) + // }() + } + // timezones are also supported + // var curTimestamp time.Time + // var curDate time.Time + // var curTimezone string + // if err := db.QueryRowContext(ogCtx, `select current_date(), current_timestamp(), current_timezone()`).Scan(&curDate, &curTimestamp, &curTimezone); err != nil { + // panic(err) + // } else { + // // this will print now at timezone America/Sao_Paulo is: 2022-11-16 20:25:15.282 -0300 -03 + // fmt.Printf("current timestamp at timezone %s is: %s\n", curTimezone, curTimestamp) + // fmt.Printf("current date at timezone %s is: %s\n", curTimezone, curDate) + // } + +} From a14ef8c8c108ce017f895b270146a028163df5e0 Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Mon, 21 Nov 2022 19:03:15 -0800 Subject: [PATCH 3/7] checkpoint Signed-off-by: Andre Furlan --- connection.go | 256 ++++++++++++++++++++++----------- connection_test.go | 60 ++++---- connector.go | 9 +- db.go | 92 +++++++++--- db_test.go | 1 + driverctx/ctx.go | 1 + examples/asyncWorkflow/main.go | 20 ++- internal/client/client.go | 19 +++ internal/config/config.go | 4 +- statement.go | 2 +- 10 files changed, 327 insertions(+), 137 deletions(-) create mode 100644 db_test.go diff --git a/connection.go b/connection.go index 29a15a08..ea32d3de 100644 --- a/connection.go +++ b/connection.go @@ -14,24 +14,25 @@ import ( "github.com/pkg/errors" ) -type conn struct { - id string - cfg 
*config.Config - client cli_service.TCLIService - session *cli_service.TOpenSessionResp +type Conn struct { + id string + cfg *config.Config + client cli_service.TCLIService + session *cli_service.TOpenSessionResp + execution *Execution } // The driver does not really implement prepared statements. -func (c *conn) Prepare(query string) (driver.Stmt, error) { +func (c *Conn) Prepare(query string) (driver.Stmt, error) { return &stmt{conn: c, query: query}, nil } // The driver does not really implement prepared statements. -func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { return &stmt{conn: c, query: query}, nil } -func (c *conn) Close() error { +func (c *Conn) Close() error { log := logger.WithContext(c.id, "", "") ctx := driverctx.NewContextWithConnId(context.Background(), c.id) sentinel := sentinel.Sentinel{ @@ -50,16 +51,16 @@ func (c *conn) Close() error { } // Not supported in Databricks -func (c *conn) Begin() (driver.Tx, error) { +func (c *Conn) Begin() (driver.Tx, error) { return nil, errors.New(ErrTransactionsNotSupported) } // Not supported in Databricks -func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { return nil, errors.New(ErrTransactionsNotSupported) } -func (c *conn) Ping(ctx context.Context) error { +func (c *Conn) Ping(ctx context.Context) error { log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") ctx = driverctx.NewContextWithConnId(ctx, c.id) ctx1, cancel := context.WithTimeout(ctx, 15*time.Second) @@ -73,12 +74,13 @@ func (c *conn) Ping(ctx context.Context) error { } // Implementation of SessionResetter -func (c *conn) ResetSession(ctx context.Context) error { +func (c *Conn) ResetSession(ctx context.Context) error { // For now our session does not have any important 
state to reset before re-use + c.execution = nil return nil } -func (c *conn) IsValid() bool { +func (c *Conn) IsValid() bool { return c.session.GetStatus().StatusCode == cli_service.TStatusCode_SUCCESS_STATUS } @@ -87,7 +89,7 @@ func (c *conn) IsValid() bool { // // ExecContext honors the context timeout and return when it is canceled. // Statement ExecContext is the same as connection ExecContext -func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") msg, start := logger.Track("ExecContext") ctx = driverctx.NewContextWithConnId(ctx, c.id) @@ -115,7 +117,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name // // QueryContext honors the context timeout and return when it is canceled. // Statement QueryContext is the same as connection QueryContext -func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { +func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { corrId := driverctx.CorrelationIdFromContext(ctx) log := logger.WithContext(c.id, corrId, "") msg, start := log.Track("QueryContext") @@ -124,42 +126,78 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam if len(args) > 0 { return nil, errors.New(ErrParametersNotSupported) } - // first we try to get the results synchronously. 
- // at any point in time that the context is done we must cancel and return - exStmtResp, _, err := c.runQuery(ctx, query, args) + if query == "" && c.execution != nil { + opHandle := &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: client.DecodeGuid(c.execution.Id), + Secret: c.execution.Secret, + }, + HasResultSet: c.execution.HasResultSet, + OperationType: cli_service.TOperationType_EXECUTE_STATEMENT, + } + rows := rows{ + connId: c.id, + correlationId: corrId, + client: c.client, + opHandle: opHandle, + pageSize: int64(c.cfg.MaxRows), + location: c.cfg.Location, + } + return &rows, nil + } else { - if exStmtResp != nil && exStmtResp.OperationHandle != nil { - log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) - } - defer log.Duration(msg, start) + // first we try to get the results synchronously. + // at any point in time that the context is done we must cancel and return + exStmtResp, opStatus, err := c.runQuery(ctx, query, args) - if err != nil { - log.Err(err).Msgf("databricks: failed to run query: query %s", query) - return nil, wrapErrf(err, "failed to run query") - } - // hold on to the operation handle - opHandle := exStmtResp.OperationHandle + execId := "" + execStatus := "UNKNOWN" - rows := rows{ - connId: c.id, - correlationId: corrId, - client: c.client, - opHandle: opHandle, - pageSize: int64(c.cfg.MaxRows), - location: c.cfg.Location, - } + if opStatus != nil { + execStatus = opStatus.GetOperationState().String() + } + // hold on to the operation handle + opHandle := exStmtResp.OperationHandle - if exStmtResp.DirectResults != nil { - // return results - rows.fetchResults = exStmtResp.DirectResults.ResultSet - rows.fetchResultsMetadata = exStmtResp.DirectResults.ResultSetMetadata + execId = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID) + log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), 
execId) + + defer log.Duration(msg, start) + + if err != nil { + log.Err(err).Msgf("databricks: failed to run query: query %s", query) + return nil, wrapErrf(err, "failed to run query") + } + rows := rows{ + connId: c.id, + correlationId: corrId, + client: c.client, + opHandle: opHandle, + pageSize: int64(c.cfg.MaxRows), + location: c.cfg.Location, + } + + if exStmtResp.DirectResults != nil { + // return results + rows.fetchResults = exStmtResp.DirectResults.ResultSet + rows.fetchResultsMetadata = exStmtResp.DirectResults.ResultSetMetadata + } + + execPtr := execFromContext(ctx) + *execPtr = Execution{ + Id: execId, + Status: execStatus, + Secret: opHandle.OperationId.Secret, + HasResultSet: opHandle.HasResultSet, + } + return &rows, nil } - return &rows, nil } -func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) { +func (c *Conn) runQuery(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) { + log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") // first we try to get the results synchronously. 
// at any point in time that the context is done we must cancel and return @@ -190,6 +228,38 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa return exStmtResp, opStatus, errors.New(opStatus.GetDisplayMessage()) // live states case cli_service.TOperationState_INITIALIZED_STATE, cli_service.TOperationState_PENDING_STATE, cli_service.TOperationState_RUNNING_STATE: + if !c.cfg.RunAsync { + + statusResp, err := c.pollOperation(ctx, opHandle) + if err != nil { + return exStmtResp, statusResp, err + } + switch statusResp.GetOperationState() { + // terminal states + // good + case cli_service.TOperationState_FINISHED_STATE: + // return handle to fetch results later + return exStmtResp, opStatus, nil + // bad + case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE: + logBadQueryState(log, statusResp) + return exStmtResp, opStatus, errors.New(statusResp.GetDisplayMessage()) + // live states + default: + logBadQueryState(log, statusResp) + return exStmtResp, opStatus, errors.New("invalid operation state. This should not have happened") + } + } else { + return exStmtResp, opStatus, nil + } + // weird states + default: + logBadQueryState(log, opStatus) + return exStmtResp, opStatus, errors.New("invalid operation state. 
This should not have happened") + } + + } else { + if !c.cfg.RunAsync { statusResp, err := c.pollOperation(ctx, opHandle) if err != nil { return exStmtResp, statusResp, err @@ -199,41 +269,18 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa // good case cli_service.TOperationState_FINISHED_STATE: // return handle to fetch results later - return exStmtResp, opStatus, nil + return exStmtResp, statusResp, nil // bad case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE: logBadQueryState(log, statusResp) - return exStmtResp, opStatus, errors.New(statusResp.GetDisplayMessage()) + return exStmtResp, statusResp, errors.New(statusResp.GetDisplayMessage()) // live states default: logBadQueryState(log, statusResp) - return exStmtResp, opStatus, errors.New("invalid operation state. This should not have happened") + return exStmtResp, statusResp, errors.New("invalid operation state. This should not have happened") } - // weird states - default: - logBadQueryState(log, opStatus) - return exStmtResp, opStatus, errors.New("invalid operation state. 
This should not have happened") - } - - } else { - statusResp, err := c.pollOperation(ctx, opHandle) - if err != nil { - return exStmtResp, statusResp, err - } - switch statusResp.GetOperationState() { - // terminal states - // good - case cli_service.TOperationState_FINISHED_STATE: - // return handle to fetch results later - return exStmtResp, statusResp, nil - // bad - case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE: - logBadQueryState(log, statusResp) - return exStmtResp, statusResp, errors.New(statusResp.GetDisplayMessage()) - // live states - default: - logBadQueryState(log, statusResp) - return exStmtResp, statusResp, errors.New("invalid operation state. This should not have happened") + } else { + return exStmtResp, nil, nil } } } @@ -243,7 +290,7 @@ func logBadQueryState(log *logger.DBSQLLogger, opStatus *cli_service.TGetOperati log.Error().Msg(opStatus.GetErrorMessage()) } -func (c *conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) { +func (c *Conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) log := logger.WithContext(c.id, corrId, "") sentinel := sentinel.Sentinel{ @@ -251,7 +298,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver req := cli_service.TExecuteStatementReq{ SessionHandle: c.session.SessionHandle, Statement: query, - RunAsync: c.cfg.RunAsync, + RunAsync: true, QueryTimeout: int64(c.cfg.QueryTimeout / time.Second), // this is specific for databricks. 
It shortcuts server roundtrips GetDirectResults: &cli_service.TSparkGetDirectResults{ @@ -281,7 +328,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver return exStmtResp, err } -func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) { +func (c *Conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) var statusResp *cli_service.TGetOperationStatusResp @@ -333,11 +380,58 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati return statusResp, nil } -var _ driver.Conn = (*conn)(nil) -var _ driver.Pinger = (*conn)(nil) -var _ driver.SessionResetter = (*conn)(nil) -var _ driver.Validator = (*conn)(nil) -var _ driver.ExecerContext = (*conn)(nil) -var _ driver.QueryerContext = (*conn)(nil) -var _ driver.ConnPrepareContext = (*conn)(nil) -var _ driver.ConnBeginTx = (*conn)(nil) +func (c *Conn) cancelOperation(ctx context.Context, execution Execution) error { + req := cli_service.TCancelOperationReq{ + OperationHandle: &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: client.DecodeGuid(execution.Id), + Secret: execution.Secret, + }, + OperationType: cli_service.TOperationType_EXECUTE_STATEMENT, + HasResultSet: execution.HasResultSet, + }, + } + _, err := c.client.CancelOperation(ctx, &req) + return err +} + +func (c *Conn) getOperationStatus(ctx context.Context, execution Execution) (Execution, error) { + statusResp, err := c.client.GetOperationStatus(ctx, &cli_service.TGetOperationStatusReq{ + OperationHandle: &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: client.DecodeGuid(execution.Id), + Secret: execution.Secret, + }, + OperationType: 
cli_service.TOperationType_EXECUTE_STATEMENT, + HasResultSet: execution.HasResultSet, + }, + }) + if err != nil { + return execution, err + } + exRet := Execution{ + Status: statusResp.GetOperationState().String(), + Id: execution.Id, + Secret: execution.Secret, + HasResultSet: execution.HasResultSet, + } + return exRet, nil +} + +func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { + ex, ok := nv.Value.(Execution) + if ok { + c.execution = &ex + return driver.ErrRemoveArgument + } + return nil +} + +var _ driver.Conn = (*Conn)(nil) +var _ driver.Pinger = (*Conn)(nil) +var _ driver.SessionResetter = (*Conn)(nil) +var _ driver.Validator = (*Conn)(nil) +var _ driver.ExecerContext = (*Conn)(nil) +var _ driver.QueryerContext = (*Conn)(nil) +var _ driver.ConnPrepareContext = (*Conn)(nil) +var _ driver.ConnBeginTx = (*Conn)(nil) diff --git a/connection_test.go b/connection_test.go index ecb4b71e..45f8f585 100644 --- a/connection_test.go +++ b/connection_test.go @@ -26,7 +26,7 @@ func TestConn_executeStatement(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -76,7 +76,7 @@ func TestConn_executeStatement(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -102,7 +102,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -132,7 +132,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: 
config.WithDefaults(), @@ -162,7 +162,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -192,7 +192,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -222,7 +222,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -254,7 +254,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -296,7 +296,7 @@ func TestConn_pollOperation(t *testing.T) { FnGetOperationStatus: getOperationStatus, FnCancelOperation: cancelOperation, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -342,7 +342,7 @@ func TestConn_pollOperation(t *testing.T) { } cfg := config.WithDefaults() cfg.PollInterval = 100 * time.Millisecond - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: cfg, @@ -388,7 +388,7 @@ func TestConn_pollOperation(t *testing.T) { } cfg := config.WithDefaults() cfg.PollInterval = 100 * time.Millisecond - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: cfg, @@ -423,7 +423,7 @@ func TestConn_runQuery(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), 
client: testClient, cfg: config.WithDefaults(), @@ -465,7 +465,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -509,7 +509,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -553,7 +553,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -605,7 +605,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -657,7 +657,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -709,7 +709,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -761,7 +761,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -782,7 +782,7 @@ func TestConn_ExecContext(t *testing.T) { var executeStatementCount int testClient := &client.TestClient{} - testConn := 
&conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -817,7 +817,7 @@ func TestConn_ExecContext(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -860,7 +860,7 @@ func TestConn_ExecContext(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -881,7 +881,7 @@ func TestConn_QueryContext(t *testing.T) { var executeStatementCount int testClient := &client.TestClient{} - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -916,7 +916,7 @@ func TestConn_QueryContext(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -959,7 +959,7 @@ func TestConn_QueryContext(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -994,7 +994,7 @@ func TestConn_Ping(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -1037,7 +1037,7 @@ func TestConn_Ping(t *testing.T) { FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -1051,7 +1051,7 @@ func TestConn_Ping(t *testing.T) { func TestConn_Begin(t *testing.T) { t.Run("Begin not supported", func(t *testing.T) { - testConn := 
&conn{ + testConn := &Conn{ session: getTestSession(), client: &client.TestClient{}, cfg: config.WithDefaults(), @@ -1064,7 +1064,7 @@ func TestConn_Begin(t *testing.T) { func TestConn_BeginTx(t *testing.T) { t.Run("BeginTx not supported", func(t *testing.T) { - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: &client.TestClient{}, cfg: config.WithDefaults(), @@ -1077,7 +1077,7 @@ func TestConn_BeginTx(t *testing.T) { func TestConn_ResetSession(t *testing.T) { t.Run("ResetSession not currently supported", func(t *testing.T) { - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: &client.TestClient{}, cfg: config.WithDefaults(), diff --git a/connector.go b/connector.go index 74f78d8a..e8419039 100644 --- a/connector.go +++ b/connector.go @@ -59,7 +59,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, errors.New("databricks: invalid open session response") } - conn := &conn{ + conn := &Conn{ id: client.SprintGuid(session.SessionHandle.GetSessionId().GUID), cfg: c.cfg, client: tclient, @@ -182,3 +182,10 @@ func WithSessionParams(params map[string]string) connOption { c.SessionParams = params } } + +// WithRunAsync +func withRunAsync() connOption { + return func(c *config.Config) { + c.RunAsync = true + } +} diff --git a/db.go b/db.go index 9d37052c..096e498a 100644 --- a/db.go +++ b/db.go @@ -4,46 +4,100 @@ import ( "context" "database/sql" "database/sql/driver" + + "github.com/databricks/databricks-sql-go/driverctx" + "github.com/pkg/errors" ) type DatabricksDB interface { - QueryContextAsync(ctx context.Context, query string, args ...any) (rows *sql.Rows, queryId string, err error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, Execution, error) + CancelExecution(ctx context.Context, exec Execution) error + GetExecutionRows(ctx context.Context, exec Execution) (*sql.Rows, error) + CheckExecution(ctx context.Context, exec Execution) (Execution, error) + 
Close() error } type databricksDB struct { - db *sql.DB + sqldb *sql.DB } func OpenDB(c driver.Connector) DatabricksDB { + cnnr := c.(*connector) + cnnr.cfg.RunAsync = true db := sql.OpenDB(c) return &databricksDB{db} } -func (db *databricksDB) QueryContextAsync(ctx context.Context, query string, args ...any) (rows *sql.Rows, queryId string, err error) { - return nil, "", nil +func (db *databricksDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, Execution, error) { + exec := Execution{} + ctx2 := newContextWithExec(ctx, &exec) + ret, err := db.sqldb.QueryContext(ctx2, query, args...) + return ret, exec, err +} + +func (db *databricksDB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, string, error) { + // db.sqldb.ExecContext() + return nil, "", errors.New(ErrNotImplemented) +} + +func (db *databricksDB) Close() error { + return db.sqldb.Close() +} + +func (db *databricksDB) CancelExecution(ctx context.Context, exec Execution) error { + con, err := db.sqldb.Conn(ctx) + if err != nil { + return err + } + return con.Raw(func(driverConn any) error { + dbsqlcon, ok := driverConn.(*Conn) + if !ok { + return errors.New("invalid connection type") + } + return dbsqlcon.cancelOperation(ctx, exec) + }) +} + +func (db *databricksDB) CheckExecution(ctx context.Context, exec Execution) (Execution, error) { + con, err := db.sqldb.Conn(ctx) + if err != nil { + return exec, err + } + exRet := exec + err = con.Raw(func(driverConn any) error { + dbsqlcon, ok := driverConn.(*Conn) + if !ok { + return errors.New("invalid connection type") + } + exRet, err = dbsqlcon.getOperationStatus(ctx, exec) + return err + }) + return exRet, err } -func (db *databricksDB) ExecContextAsync(ctx context.Context, query string, args ...any) (result sql.Result, queryId string) { - //go do something - return nil, "" +func (db *databricksDB) GetExecutionRows(ctx context.Context, exec Execution) (*sql.Rows, error) { + return db.sqldb.QueryContext(ctx, "", 
exec) } -func (db *databricksDB) CancelQuery(ctx context.Context, queryId string) error { - //go do something - return nil +func (db *databricksDB) GetExecutionResult(ctx context.Context, exec Execution) (sql.Result, error) { + return db.sqldb.ExecContext(ctx, "", exec) } -func (db *databricksDB) GetQueryStatus(ctx context.Context, queryId string) error { - //go do something - return nil +type Execution struct { + Status string + Id string + Secret []byte + HasResultSet bool } -func (db *databricksDB) FetchRows(ctx context.Context, queryId string) (rows *sql.Rows, err error) { - //go do something - return nil, nil +func newContextWithExec(ctx context.Context, exec *Execution) context.Context { + return context.WithValue(ctx, driverctx.ExecutionContextKey, exec) } -func (db *databricksDB) FetchResult(ctx context.Context, queryId string) (rows sql.Result, err error) { - //go do something - return nil, nil +func execFromContext(ctx context.Context) *Execution { + execId, ok := ctx.Value(driverctx.ExecutionContextKey).(*Execution) + if !ok { + return nil + } + return execId } diff --git a/db_test.go b/db_test.go new file mode 100644 index 00000000..08a359d0 --- /dev/null +++ b/db_test.go @@ -0,0 +1 @@ +package dbsql diff --git a/driverctx/ctx.go b/driverctx/ctx.go index 21397f5f..60975a66 100644 --- a/driverctx/ctx.go +++ b/driverctx/ctx.go @@ -11,6 +11,7 @@ type contextKey int const ( CorrelationIdContextKey contextKey = iota ConnIdContextKey + ExecutionContextKey ) // NewContextWithCorrelationId creates a new context with correlationId value. Used by Logger to populate field corrId. 
diff --git a/examples/asyncWorkflow/main.go b/examples/asyncWorkflow/main.go index 76b1e1cc..cc138038 100644 --- a/examples/asyncWorkflow/main.go +++ b/examples/asyncWorkflow/main.go @@ -67,12 +67,26 @@ func main() { if err != nil { panic(err) } - rs, err := db.GetExecutionResult(ogCtx, exec) + ex, err := db.CheckExecution(ogCtx, exec) if err != nil { panic(err) } - fmt.Println(rs) - // }() + fmt.Println(ex.Status) + + rs, err := db.GetExecutionRows(ogCtx, exec) + if err != nil { + panic(err) + } + var res string + for rs.Next() { + err := rs.Scan(&res) + if err != nil { + fmt.Println(err) + rs.Close() + return + } + fmt.Println(res) + } } // timezones are also supported // var curTimestamp time.Time diff --git a/internal/client/client.go b/internal/client/client.go index e76441b7..271d87db 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -3,6 +3,7 @@ package client import ( "compress/zlib" "context" + "encoding/hex" "encoding/json" "fmt" "net/http" @@ -271,3 +272,21 @@ func SprintGuid(bts []byte) string { logger.Warn().Msgf("GUID not valid: %x", bts) return fmt.Sprintf("%x", bts) } + +func DecodeGuid(str string) []byte { + if len(str) == 36 { + bts, err := hex.DecodeString(str[0:8] + str[9:13] + str[14:18] + str[19:23] + str[24:36]) + if err != nil { + logger.Warn().Msgf("GUID not valid: %s", str) + return []byte{} + } + return bts + } + logger.Warn().Msgf("GUID not valid: %s", str) + bts, err := hex.DecodeString(str) + if err != nil { + logger.Warn().Msgf("GUID not valid: %s", str) + return []byte{} + } + return bts +} diff --git a/internal/config/config.go b/internal/config/config.go index 19c47a55..15d81027 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,7 +20,7 @@ type Config struct { TLSConfig *tls.Config // nil disables TLS Authenticator string //TODO for oauth - RunAsync bool // TODO + RunAsync bool PollInterval time.Duration ConnectTimeout time.Duration // max time to open session ClientTimeout 
time.Duration // max time the http request can last @@ -134,7 +134,7 @@ func WithDefaults() *Config { UserConfig: UserConfig{}.WithDefaults(), TLSConfig: &tls.Config{MinVersion: tls.VersionTLS12}, Authenticator: "", - RunAsync: true, + RunAsync: false, PollInterval: 1 * time.Second, ConnectTimeout: 60 * time.Second, ClientTimeout: 900 * time.Second, diff --git a/statement.go b/statement.go index 940649a2..7ade39a9 100644 --- a/statement.go +++ b/statement.go @@ -7,7 +7,7 @@ import ( ) type stmt struct { - conn *conn + conn *Conn query string } From 3682d07a4c9ab51056648e022f49a6444173b6b4 Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Mon, 21 Nov 2022 21:00:27 -0800 Subject: [PATCH 4/7] addind ExecutionStatus Signed-off-by: Andre Furlan --- connection.go | 127 ++++++++++++++++++++++++++++---------------------- connector.go | 7 --- db.go | 66 +++++++++++++++++--------- 3 files changed, 114 insertions(+), 86 deletions(-) diff --git a/connection.go b/connection.go index ea32d3de..b09b002c 100644 --- a/connection.go +++ b/connection.go @@ -15,11 +15,11 @@ import ( ) type Conn struct { - id string - cfg *config.Config - client cli_service.TCLIService - session *cli_service.TOpenSessionResp - execution *Execution + id string + cfg *config.Config + client cli_service.TCLIService + session *cli_service.TOpenSessionResp + exc *Execution } // The driver does not really implement prepared statements. 
@@ -76,7 +76,7 @@ func (c *Conn) Ping(ctx context.Context) error { // Implementation of SessionResetter func (c *Conn) ResetSession(ctx context.Context) error { // For now our session does not have any important state to reset before re-use - c.execution = nil + c.exc = nil return nil } @@ -126,15 +126,8 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam if len(args) > 0 { return nil, errors.New(ErrParametersNotSupported) } - if query == "" && c.execution != nil { - opHandle := &cli_service.TOperationHandle{ - OperationId: &cli_service.THandleIdentifier{ - GUID: client.DecodeGuid(c.execution.Id), - Secret: c.execution.Secret, - }, - HasResultSet: c.execution.HasResultSet, - OperationType: cli_service.TOperationType_EXECUTE_STATEMENT, - } + if query == "" && c.exc != nil { + opHandle := toOperationHandle(c.exc) rows := rows{ connId: c.id, correlationId: corrId, @@ -150,17 +143,17 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam // at any point in time that the context is done we must cancel and return exStmtResp, opStatus, err := c.runQuery(ctx, query, args) - execId := "" - execStatus := "UNKNOWN" + excId := "" + excStatus := ExecutionUnknown if opStatus != nil { - execStatus = opStatus.GetOperationState().String() + excStatus = toExecutionStatus(opStatus.GetOperationState()) } // hold on to the operation handle opHandle := exStmtResp.OperationHandle - execId = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID) - log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), execId) + excId = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID) + log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), excId) defer log.Duration(msg, start) @@ -184,10 +177,10 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam rows.fetchResultsMetadata = exStmtResp.DirectResults.ResultSetMetadata } - execPtr := execFromContext(ctx) - 
*execPtr = Execution{ - Id: execId, - Status: execStatus, + excPtr := excFromContext(ctx) + *excPtr = Execution{ + Id: excId, + Status: excStatus, Secret: opHandle.OperationId.Secret, HasResultSet: opHandle.HasResultSet, } @@ -228,8 +221,9 @@ func (c *Conn) runQuery(ctx context.Context, query string, args []driver.NamedVa return exStmtResp, opStatus, errors.New(opStatus.GetDisplayMessage()) // live states case cli_service.TOperationState_INITIALIZED_STATE, cli_service.TOperationState_PENDING_STATE, cli_service.TOperationState_RUNNING_STATE: - if !c.cfg.RunAsync { - + if c.cfg.RunAsync { + return exStmtResp, opStatus, nil + } else { statusResp, err := c.pollOperation(ctx, opHandle) if err != nil { return exStmtResp, statusResp, err @@ -249,8 +243,6 @@ func (c *Conn) runQuery(ctx context.Context, query string, args []driver.NamedVa logBadQueryState(log, statusResp) return exStmtResp, opStatus, errors.New("invalid operation state. This should not have happened") } - } else { - return exStmtResp, opStatus, nil } // weird states default: @@ -259,7 +251,9 @@ func (c *Conn) runQuery(ctx context.Context, query string, args []driver.NamedVa } } else { - if !c.cfg.RunAsync { + if c.cfg.RunAsync { + return exStmtResp, nil, nil + } else { statusResp, err := c.pollOperation(ctx, opHandle) if err != nil { return exStmtResp, statusResp, err @@ -279,8 +273,6 @@ func (c *Conn) runQuery(ctx context.Context, query string, args []driver.NamedVa logBadQueryState(log, statusResp) return exStmtResp, statusResp, errors.New("invalid operation state. This should not have happened") } - } else { - return exStmtResp, nil, nil } } } @@ -300,7 +292,7 @@ func (c *Conn) executeStatement(ctx context.Context, query string, args []driver Statement: query, RunAsync: true, QueryTimeout: int64(c.cfg.QueryTimeout / time.Second), - // this is specific for databricks. It shortcuts server roundtrips + // this is specific for databricks. 
It shortcuts server round-trips GetDirectResults: &cli_service.TSparkGetDirectResults{ MaxRows: int64(c.cfg.MaxRows), }, @@ -345,7 +337,7 @@ func (c *Conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati OperationHandle: opHandle, }) if statusResp != nil && statusResp.OperationState != nil { - log.Debug().Msgf("databricks: status %s", statusResp.GetOperationState().String()) + log.Debug().Msgf("databricks: status %s", toExecutionStatus(statusResp.GetOperationState())) } return func() bool { // which other states? @@ -380,40 +372,26 @@ func (c *Conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati return statusResp, nil } -func (c *Conn) cancelOperation(ctx context.Context, execution Execution) error { +func (c *Conn) cancelOperation(ctx context.Context, exc Execution) error { req := cli_service.TCancelOperationReq{ - OperationHandle: &cli_service.TOperationHandle{ - OperationId: &cli_service.THandleIdentifier{ - GUID: client.DecodeGuid(execution.Id), - Secret: execution.Secret, - }, - OperationType: cli_service.TOperationType_EXECUTE_STATEMENT, - HasResultSet: execution.HasResultSet, - }, + OperationHandle: toOperationHandle(&exc), } _, err := c.client.CancelOperation(ctx, &req) return err } -func (c *Conn) getOperationStatus(ctx context.Context, execution Execution) (Execution, error) { +func (c *Conn) getOperationStatus(ctx context.Context, exc Execution) (Execution, error) { statusResp, err := c.client.GetOperationStatus(ctx, &cli_service.TGetOperationStatusReq{ - OperationHandle: &cli_service.TOperationHandle{ - OperationId: &cli_service.THandleIdentifier{ - GUID: client.DecodeGuid(execution.Id), - Secret: execution.Secret, - }, - OperationType: cli_service.TOperationType_EXECUTE_STATEMENT, - HasResultSet: execution.HasResultSet, - }, + OperationHandle: toOperationHandle(&exc), }) if err != nil { - return execution, err + return exc, err } exRet := Execution{ - Status: statusResp.GetOperationState().String(), - Id: 
execution.Id, - Secret: execution.Secret, - HasResultSet: execution.HasResultSet, + Status: toExecutionStatus(statusResp.GetOperationState()), + Id: exc.Id, + Secret: exc.Secret, + HasResultSet: exc.HasResultSet, } return exRet, nil } @@ -421,12 +399,47 @@ func (c *Conn) getOperationStatus(ctx context.Context, execution Execution) (Exe func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { ex, ok := nv.Value.(Execution) if ok { - c.execution = &ex + c.exc = &ex return driver.ErrRemoveArgument } return nil } +func toExecutionStatus(state cli_service.TOperationState) ExecutionStatus { + switch state { + + case cli_service.TOperationState_INITIALIZED_STATE: + return ExecutionInitialized + case cli_service.TOperationState_RUNNING_STATE: + return ExecutionRunning + case cli_service.TOperationState_FINISHED_STATE: + return ExecutionFinished + case cli_service.TOperationState_CANCELED_STATE: + return ExecutionCanceled + case cli_service.TOperationState_CLOSED_STATE: + return ExecutionClosed + case cli_service.TOperationState_ERROR_STATE: + return ExecutionError + case cli_service.TOperationState_PENDING_STATE: + return ExecutionPending + case cli_service.TOperationState_TIMEDOUT_STATE: + return ExecutionTimedOut + default: + return ExecutionUnknown + } +} + +func toOperationHandle(ex *Execution) *cli_service.TOperationHandle { + return &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: client.DecodeGuid(ex.Id), + Secret: ex.Secret, + }, + OperationType: cli_service.TOperationType_EXECUTE_STATEMENT, + HasResultSet: ex.HasResultSet, + } +} + var _ driver.Conn = (*Conn)(nil) var _ driver.Pinger = (*Conn)(nil) var _ driver.SessionResetter = (*Conn)(nil) diff --git a/connector.go b/connector.go index e8419039..b615263c 100644 --- a/connector.go +++ b/connector.go @@ -182,10 +182,3 @@ func WithSessionParams(params map[string]string) connOption { c.SessionParams = params } } - -// WithRunAsync -func withRunAsync() connOption { - return 
func(c *config.Config) { - c.RunAsync = true - } -} diff --git a/db.go b/db.go index 096e498a..8be78ab5 100644 --- a/db.go +++ b/db.go @@ -11,9 +11,9 @@ import ( type DatabricksDB interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, Execution, error) - CancelExecution(ctx context.Context, exec Execution) error - GetExecutionRows(ctx context.Context, exec Execution) (*sql.Rows, error) - CheckExecution(ctx context.Context, exec Execution) (Execution, error) + CancelExecution(ctx context.Context, exc Execution) error + GetExecutionRows(ctx context.Context, exc Execution) (*sql.Rows, error) + CheckExecution(ctx context.Context, exc Execution) (Execution, error) Close() error } @@ -29,10 +29,10 @@ func OpenDB(c driver.Connector) DatabricksDB { } func (db *databricksDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, Execution, error) { - exec := Execution{} - ctx2 := newContextWithExec(ctx, &exec) + exc := Execution{} + ctx2 := newContextWithExecution(ctx, &exc) ret, err := db.sqldb.QueryContext(ctx2, query, args...) 
- return ret, exec, err + return ret, exc, err } func (db *databricksDB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, string, error) { @@ -44,7 +44,7 @@ func (db *databricksDB) Close() error { return db.sqldb.Close() } -func (db *databricksDB) CancelExecution(ctx context.Context, exec Execution) error { +func (db *databricksDB) CancelExecution(ctx context.Context, exc Execution) error { con, err := db.sqldb.Conn(ctx) if err != nil { return err @@ -54,50 +54,72 @@ func (db *databricksDB) CancelExecution(ctx context.Context, exec Execution) err if !ok { return errors.New("invalid connection type") } - return dbsqlcon.cancelOperation(ctx, exec) + return dbsqlcon.cancelOperation(ctx, exc) }) } -func (db *databricksDB) CheckExecution(ctx context.Context, exec Execution) (Execution, error) { +func (db *databricksDB) CheckExecution(ctx context.Context, exc Execution) (Execution, error) { con, err := db.sqldb.Conn(ctx) if err != nil { - return exec, err + return exc, err } - exRet := exec + exRet := exc err = con.Raw(func(driverConn any) error { dbsqlcon, ok := driverConn.(*Conn) if !ok { return errors.New("invalid connection type") } - exRet, err = dbsqlcon.getOperationStatus(ctx, exec) + exRet, err = dbsqlcon.getOperationStatus(ctx, exc) return err }) return exRet, err } -func (db *databricksDB) GetExecutionRows(ctx context.Context, exec Execution) (*sql.Rows, error) { - return db.sqldb.QueryContext(ctx, "", exec) +func (db *databricksDB) GetExecutionRows(ctx context.Context, exc Execution) (*sql.Rows, error) { + return db.sqldb.QueryContext(ctx, "", exc) } -func (db *databricksDB) GetExecutionResult(ctx context.Context, exec Execution) (sql.Result, error) { - return db.sqldb.ExecContext(ctx, "", exec) +func (db *databricksDB) GetExecutionResult(ctx context.Context, exc Execution) (sql.Result, error) { + return db.sqldb.ExecContext(ctx, "", exc) } type Execution struct { - Status string + Status ExecutionStatus Id string Secret []byte 
HasResultSet bool } -func newContextWithExec(ctx context.Context, exec *Execution) context.Context { - return context.WithValue(ctx, driverctx.ExecutionContextKey, exec) +type ExecutionStatus string + +const ( + // live state Initialized + ExecutionInitialized ExecutionStatus = "Initialized" + // live state Running + ExecutionRunning ExecutionStatus = "Running" + // terminal state Finished + ExecutionFinished ExecutionStatus = "Finished" + // terminal state Canceled + ExecutionCanceled ExecutionStatus = "Canceled" + // terminal state Closed + ExecutionClosed ExecutionStatus = "Closed" + // terminal state Error + ExecutionError ExecutionStatus = "Error" + ExecutionUnknown ExecutionStatus = "Unknown" + // live state Pending + ExecutionPending ExecutionStatus = "Pending" + // terminal state TimedOut + ExecutionTimedOut ExecutionStatus = "TimedOut" +) + +func newContextWithExecution(ctx context.Context, exc *Execution) context.Context { + return context.WithValue(ctx, driverctx.ExecutionContextKey, exc) } -func execFromContext(ctx context.Context) *Execution { - execId, ok := ctx.Value(driverctx.ExecutionContextKey).(*Execution) +func excFromContext(ctx context.Context) *Execution { + excId, ok := ctx.Value(driverctx.ExecutionContextKey).(*Execution) if !ok { return nil } - return execId + return excId } From 420742f1b77bbc0f35a26728bacecd04a0f41b0b Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Mon, 21 Nov 2022 22:04:02 -0800 Subject: [PATCH 5/7] experimenting with pool = 1 Signed-off-by: Andre Furlan --- connection.go | 24 ++++++++++++----- db.go | 19 +++++++++++++ examples/asyncWorkflow/main.go | 49 ++++++++++++++++++++++++---------- 3 files changed, 72 insertions(+), 20 deletions(-) diff --git a/connection.go b/connection.go index b09b002c..f6009929 100644 --- a/connection.go +++ b/connection.go @@ -96,6 +96,10 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name if len(args) > 0 { return nil, errors.New(ErrParametersNotSupported) 
} + if query == "" && c.exc != nil { + //TODO + return nil, errors.New(ErrNotImplemented) + } exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args) if exStmtResp != nil && exStmtResp.OperationHandle != nil { @@ -175,15 +179,21 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam // return results rows.fetchResults = exStmtResp.DirectResults.ResultSet rows.fetchResultsMetadata = exStmtResp.DirectResults.ResultSetMetadata + } else { + // set to closed so clients won't ask for it + excStatus = ExecutionClosed } + if c.cfg.RunAsync { - excPtr := excFromContext(ctx) - *excPtr = Execution{ - Id: excId, - Status: excStatus, - Secret: opHandle.OperationId.Secret, - HasResultSet: opHandle.HasResultSet, + excPtr := excFromContext(ctx) + *excPtr = Execution{ + Id: excId, + Status: excStatus, + Secret: opHandle.OperationId.Secret, + HasResultSet: opHandle.HasResultSet, + } } + return &rows, nil } @@ -373,6 +383,7 @@ func (c *Conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati } func (c *Conn) cancelOperation(ctx context.Context, exc Execution) error { + // TODO wrap in Sentinel req := cli_service.TCancelOperationReq{ OperationHandle: toOperationHandle(&exc), } @@ -381,6 +392,7 @@ func (c *Conn) cancelOperation(ctx context.Context, exc Execution) error { } func (c *Conn) getOperationStatus(ctx context.Context, exc Execution) (Execution, error) { + // TODO wrap in Sentinel statusResp, err := c.client.GetOperationStatus(ctx, &cli_service.TGetOperationStatusReq{ OperationHandle: toOperationHandle(&exc), }) diff --git a/db.go b/db.go index 8be78ab5..324ebd64 100644 --- a/db.go +++ b/db.go @@ -15,6 +15,8 @@ type DatabricksDB interface { GetExecutionRows(ctx context.Context, exc Execution) (*sql.Rows, error) CheckExecution(ctx context.Context, exc Execution) (Execution, error) Close() error + Stats() sql.DBStats + SetMaxOpenConns(n int) } type databricksDB struct { @@ -44,6 +46,14 @@ func (db *databricksDB) Close() error { 
return db.sqldb.Close() } +func (db *databricksDB) Stats() sql.DBStats { + return db.sqldb.Stats() +} + +func (db *databricksDB) SetMaxOpenConns(n int) { + db.sqldb.SetMaxOpenConns(n) +} + func (db *databricksDB) CancelExecution(ctx context.Context, exc Execution) error { con, err := db.sqldb.Conn(ctx) if err != nil { @@ -112,6 +122,15 @@ const ( ExecutionTimedOut ExecutionStatus = "TimedOut" ) +func (e ExecutionStatus) Terminal() bool { + switch e { + case ExecutionInitialized, ExecutionPending, ExecutionRunning: + return false + default: + return true + } +} + func newContextWithExecution(ctx context.Context, exc *Execution) context.Context { return context.WithValue(ctx, driverctx.ExecutionContextKey, exc) } diff --git a/examples/asyncWorkflow/main.go b/examples/asyncWorkflow/main.go index cc138038..e1491c1d 100644 --- a/examples/asyncWorkflow/main.go +++ b/examples/asyncWorkflow/main.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "log" "os" "strconv" "time" @@ -40,7 +41,6 @@ func main() { dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")), dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")), //optional configuration - dbsql.WithSessionParams(map[string]string{"timezone": "America/Sao_Paulo", "ansi_mode": "true"}), dbsql.WithUserAgentEntry("workflow-example"), dbsql.WithInitialNamespace("hive_metastore", "default"), dbsql.WithTimeout(time.Minute), // defaults to no timeout. Global timeout. Any query will be canceled if taking more than this time. 
@@ -57,27 +57,42 @@ func main() { // make sure to close it later defer db.Close() + db.SetMaxOpenConns(1) + // the "github.com/databricks/databricks-sql-go/driverctx" has some functions to help set the context for the driver - ogCtx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "workflow-example") + ogCtx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "asyncWorkflow-example") - for _, v := range []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"} { - i := v - // go func() { - _, exec, err := db.QueryContext(ogCtx, fmt.Sprintf("select %s", i)) - if err != nil { - panic(err) - } - ex, err := db.CheckExecution(ogCtx, exec) - if err != nil { - panic(err) + // for _, v := range []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"} { + // i := v + // go func() { + // _, exc, err := db.QueryContext(ogCtx, fmt.Sprintf("select %s", i)) + rs, exc, err := db.QueryContext(ogCtx, `SELECT id FROM RANGE(100000000) ORDER BY RANDOM() + 2 asc`) + if err != nil { + panic(err) + } + for { + if exc.Status.Terminal() { + break + } else { + // TODO: how to prevent the connection being locked when rows has no data?? 
+ exc, err = db.CheckExecution(ogCtx, exc) + if err != nil { + log.Fatal(err) + } } - fmt.Println(ex.Status) + fmt.Println(db.Stats()) + time.Sleep(time.Second) + } - rs, err := db.GetExecutionRows(ogCtx, exec) + fmt.Println(exc.Status) + if exc.Status == dbsql.ExecutionFinished { + rs, err = db.GetExecutionRows(ogCtx, exc) if err != nil { panic(err) } + defer rs.Close() var res string + i := 0 for rs.Next() { err := rs.Scan(&res) if err != nil { @@ -86,8 +101,14 @@ func main() { return } fmt.Println(res) + if i < 10 { + i++ + } else { + rs.Close() + } } } + // } // timezones are also supported // var curTimestamp time.Time // var curDate time.Time From a415d46f7306e11a1c7e5a3a44196c5f8673975d Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Mon, 21 Nov 2022 23:11:22 -0800 Subject: [PATCH 6/7] got it to work Signed-off-by: Andre Furlan --- examples/asyncWorkflow/main.go | 35 +++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/examples/asyncWorkflow/main.go b/examples/asyncWorkflow/main.go index e1491c1d..d2844bb7 100644 --- a/examples/asyncWorkflow/main.go +++ b/examples/asyncWorkflow/main.go @@ -68,13 +68,31 @@ func main() { // _, exc, err := db.QueryContext(ogCtx, fmt.Sprintf("select %s", i)) rs, exc, err := db.QueryContext(ogCtx, `SELECT id FROM RANGE(100000000) ORDER BY RANDOM() + 2 asc`) if err != nil { - panic(err) + log.Fatal(err) + } + defer rs.Close() + if exc.Status == dbsql.ExecutionFinished { + var res string + i := 0 + for rs.Next() { + err := rs.Scan(&res) + if err != nil { + fmt.Println(err) + rs.Close() + return + } + fmt.Println(res) + if i < 10 { + i++ + } else { + return + } + } } for { if exc.Status.Terminal() { break } else { - // TODO: how to prevent the connection being locked when rows has no data?? 
exc, err = db.CheckExecution(ogCtx, exc) if err != nil { log.Fatal(err) @@ -84,29 +102,28 @@ func main() { time.Sleep(time.Second) } - fmt.Println(exc.Status) if exc.Status == dbsql.ExecutionFinished { rs, err = db.GetExecutionRows(ogCtx, exc) if err != nil { - panic(err) + log.Fatal(err) } - defer rs.Close() var res string i := 0 for rs.Next() { err := rs.Scan(&res) if err != nil { fmt.Println(err) - rs.Close() - return + break } fmt.Println(res) - if i < 10 { + if i < 12 { i++ } else { - rs.Close() + return } } + } else { + fmt.Println(exc.Status) } // } // timezones are also supported From 60131e06501ddc15a370236abc14dd70051113f9 Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Tue, 22 Nov 2022 11:12:22 -0800 Subject: [PATCH 7/7] got it to work better Signed-off-by: Andre Furlan --- connection.go | 29 ++++++++++++++++++----------- db.go | 9 +++++++-- examples/asyncWorkflow/main.go | 14 ++++++++++++-- rows.go | 26 +++++++++++++++++++------- 4 files changed, 56 insertions(+), 22 deletions(-) diff --git a/connection.go b/connection.go index f6009929..100dfe31 100644 --- a/connection.go +++ b/connection.go @@ -161,6 +161,17 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam defer log.Duration(msg, start) + if c.cfg.RunAsync { + + excPtr := excFromContext(ctx) + *excPtr = Execution{ + Id: excId, + Status: excStatus, + Secret: opHandle.OperationId.Secret, + HasResultSet: opHandle.HasResultSet, + } + } + if err != nil { log.Err(err).Msgf("databricks: failed to run query: query %s", query) return nil, wrapErrf(err, "failed to run query") @@ -179,19 +190,15 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam // return results rows.fetchResults = exStmtResp.DirectResults.ResultSet rows.fetchResultsMetadata = exStmtResp.DirectResults.ResultSetMetadata - } else { - // set to closed so clients won't ask for it - excStatus = ExecutionClosed } - if c.cfg.RunAsync { - excPtr := excFromContext(ctx) - *excPtr = 
Execution{ - Id: excId, - Status: excStatus, - Secret: opHandle.OperationId.Secret, - HasResultSet: opHandle.HasResultSet, - } + // if the direct results has all rows, the operation will be deleted, so + // set it to closed so clients won't ask for it + // excStatus = ExecutionClosed + + if c.cfg.RunAsync && excStatus != ExecutionFinished { + rows.opHandle = nil + } return &rows, nil diff --git a/db.go b/db.go index 324ebd64..555e3460 100644 --- a/db.go +++ b/db.go @@ -33,8 +33,11 @@ func OpenDB(c driver.Connector) DatabricksDB { func (db *databricksDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, Execution, error) { exc := Execution{} ctx2 := newContextWithExecution(ctx, &exc) - ret, err := db.sqldb.QueryContext(ctx2, query, args...) - return ret, exc, err + rs, err := db.sqldb.QueryContext(ctx2, query, args...) + if exc.Status != ExecutionFinished && rs != nil { + rs.Close() + } + return rs, exc, err } func (db *databricksDB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, string, error) { @@ -59,6 +62,7 @@ func (db *databricksDB) CancelExecution(ctx context.Context, exc Execution) erro if err != nil { return err } + defer con.Close() return con.Raw(func(driverConn any) error { dbsqlcon, ok := driverConn.(*Conn) if !ok { @@ -73,6 +77,7 @@ func (db *databricksDB) CheckExecution(ctx context.Context, exc Execution) (Exec if err != nil { return exc, err } + defer con.Close() exRet := exc err = con.Raw(func(driverConn any) error { dbsqlcon, ok := driverConn.(*Conn) diff --git a/examples/asyncWorkflow/main.go b/examples/asyncWorkflow/main.go index d2844bb7..ddb960c5 100644 --- a/examples/asyncWorkflow/main.go +++ b/examples/asyncWorkflow/main.go @@ -57,7 +57,7 @@ func main() { // make sure to close it later defer db.Close() - db.SetMaxOpenConns(1) + db.SetMaxOpenConns(2) // the "github.com/databricks/databricks-sql-go/driverctx" has some functions to help set the context for the driver ogCtx := 
dbsqlctx.NewContextWithCorrelationId(context.Background(), "asyncWorkflow-example") @@ -66,12 +66,22 @@ func main() { // i := v // go func() { // _, exc, err := db.QueryContext(ogCtx, fmt.Sprintf("select %s", i)) - rs, exc, err := db.QueryContext(ogCtx, `SELECT id FROM RANGE(100000000) ORDER BY RANDOM() + 2 asc`) + rs, exc, err := db.QueryContext(ogCtx, `SELECT id FROM RANGE(100) ORDER BY RANDOM() + 2 asc`) if err != nil { log.Fatal(err) } defer rs.Close() + // can't do this. If direct results is done, the operation is gone + exc, err = db.CheckExecution(ogCtx, exc) + if err != nil { + log.Fatal(err) + } + if exc.Status == dbsql.ExecutionFinished { + rs, err = db.GetExecutionRows(ogCtx, exc) + if err != nil { + log.Fatal(err) + } var res string i := 0 for rs.Next() { diff --git a/rows.go b/rows.go index 5c9da4ca..a8d2e671 100644 --- a/rows.go +++ b/rows.go @@ -51,6 +51,10 @@ func (r *rows) Columns() []string { return []string{} } + if r.opHandle == nil { + return []string{} + } + resultMetadata, err := r.getResultMetadata() if err != nil { return []string{} @@ -76,15 +80,17 @@ func (r *rows) Close() error { if err != nil { return err } + if r.opHandle != nil { - req := cli_service.TCloseOperationReq{ - OperationHandle: r.opHandle, - } - ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId) + req := cli_service.TCloseOperationReq{ + OperationHandle: r.opHandle, + } + ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId) - _, err1 := r.client.CloseOperation(ctx, &req) - if err1 != nil { - return err1 + _, err1 := r.client.CloseOperation(ctx, &req) + if err1 != nil { + return err1 + } } return nil } @@ -103,6 +109,9 @@ func (r *rows) Next(dest []driver.Value) error { if err != nil { return err } + if r.opHandle == nil { + return io.EOF + } // if the next row is not in the current result page // fetch the containing page @@ 
-334,6 +343,9 @@ func (r *rows) getResultMetadata() (*cli_service.TGetResultSetMetadataResp, erro if err != nil { return nil, err } + if r.opHandle == nil { + return nil, errors.New("metadata not available") + } req := cli_service.TGetResultSetMetadataReq{ OperationHandle: r.opHandle,