Skip to content

Commit ecfe6ea

Browse files
committed
Merge branch 'main' into spanner-lib-create-pool-and-connection
2 parents c576fab + f5100ce commit ecfe6ea

File tree

10 files changed

+118
-52
lines changed

10 files changed

+118
-52
lines changed

benchmarks/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ require (
1010
cloud.google.com/go v0.122.0
1111
cloud.google.com/go/spanner v1.85.1
1212
github.com/google/uuid v1.6.0
13-
github.com/googleapis/go-sql-spanner v1.17.0
13+
github.com/googleapis/go-sql-spanner v1.18.0
1414
google.golang.org/api v0.249.0
1515
google.golang.org/grpc v1.75.1
1616
google.golang.org/protobuf v1.36.9

checksum_row_iterator.go

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"cloud.google.com/go/spanner"
2525
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
26+
"github.com/googleapis/go-sql-spanner/parser"
2627
"google.golang.org/api/iterator"
2728
"google.golang.org/grpc/codes"
2829
"google.golang.org/grpc/status"
@@ -51,10 +52,11 @@ type checksumRowIterator struct {
5152
*spanner.RowIterator
5253
metadata *sppb.ResultSetMetadata
5354

54-
ctx context.Context
55-
tx *readWriteTransaction
56-
stmt spanner.Statement
57-
options spanner.QueryOptions
55+
ctx context.Context
56+
tx *readWriteTransaction
57+
stmt spanner.Statement
58+
stmtType parser.StatementType
59+
options spanner.QueryOptions
5860
// nc (nextCount) indicates the number of times that next has been called
5961
// on the iterator. Next() will be called the same number of times during
6062
// a retry.
@@ -253,10 +255,5 @@ func (it *checksumRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
253255
}
254256

255257
func (it *checksumRowIterator) ResultSetStats() *sppb.ResultSetStats {
256-
// TODO: The Spanner client library should offer an option to get the full
257-
// ResultSetStats, instead of only the RowCount and QueryPlan.
258-
return &sppb.ResultSetStats{
259-
RowCount: &sppb.ResultSetStats_RowCountExact{RowCountExact: it.RowIterator.RowCount},
260-
QueryPlan: it.RowIterator.QueryPlan,
261-
}
258+
return createResultSetStats(it.RowIterator, it.stmtType)
262259
}

conn.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ type conn struct {
262262
resetForRetry bool
263263
database string
264264

265-
execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator
266-
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error)
265+
execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, bound spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator
266+
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error)
267267
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error)
268268
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (int64, error)
269269

@@ -860,9 +860,9 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
860860
if err != nil {
861861
return nil, err
862862
}
863-
statementType := c.parser.DetectStatementType(query)
863+
statementInfo := c.parser.DetectStatementType(query)
864864
// DDL statements are not supported in QueryContext so use the execContext method for the execution.
865-
if statementType.StatementType == parser.StatementTypeDdl {
865+
if statementInfo.StatementType == parser.StatementTypeDdl {
866866
res, err := c.execContext(ctx, query, execOptions, args)
867867
if err != nil {
868868
return nil, err
@@ -871,10 +871,10 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
871871
}
872872
var iter rowIterator
873873
if c.tx == nil {
874-
if statementType.StatementType == parser.StatementTypeDml {
874+
if statementInfo.StatementType == parser.StatementTypeDml {
875875
// Use a read/write transaction to execute the statement.
876876
var commitResponse *spanner.CommitResponse
877-
iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, execOptions)
877+
iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, statementInfo, execOptions)
878878
if err != nil {
879879
return nil, err
880880
}
@@ -887,13 +887,13 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
887887
// The statement was either detected as being a query, or potentially not recognized at all.
888888
// In that case, just default to using a single-use read-only transaction and let Spanner
889889
// return an error if the statement is not suited for that type of transaction.
890-
iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.ReadOnlyStaleness(), execOptions)}
890+
iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, statementInfo, c.ReadOnlyStaleness(), execOptions), statementInfo.StatementType}
891891
}
892892
} else {
893893
if execOptions.PartitionedQueryOptions.PartitionQuery {
894894
return c.tx.partitionQuery(ctx, stmt, execOptions)
895895
}
896-
iter, err = c.tx.Query(ctx, stmt, execOptions)
896+
iter, err = c.tx.Query(ctx, stmt, statementInfo.StatementType, execOptions)
897897
if err != nil {
898898
return nil, err
899899
}
@@ -1341,7 +1341,7 @@ func (c *conn) Rollback(ctx context.Context) error {
13411341
return c.tx.Rollback()
13421342
}
13431343

1344-
func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
1344+
func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
13451345
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
13461346
}
13471347

@@ -1363,7 +1363,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex
13631363
return r, nil
13641364
}
13651365

1366-
func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) {
1366+
func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) {
13671367
var result *wrappedRowIterator
13681368
options.QueryOptions.LastStatement = true
13691369
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
@@ -1372,6 +1372,7 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s
13721372
if err == iterator.Done {
13731373
result = &wrappedRowIterator{
13741374
RowIterator: it,
1375+
stmtType: statementInfo.StatementType,
13751376
noRows: true,
13761377
}
13771378
} else if err != nil {
@@ -1380,6 +1381,7 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s
13801381
} else {
13811382
result = &wrappedRowIterator{
13821383
RowIterator: it,
1384+
stmtType: statementInfo.StatementType,
13831385
firstRow: row,
13841386
}
13851387
}

driver_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) {
630630
logger: noopLogger,
631631
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
632632
batch: &batch{tp: parser.BatchTypeDdl},
633-
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
633+
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
634634
return &spanner.RowIterator{}
635635
},
636636
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
@@ -670,7 +670,7 @@ func TestConn_NonDmlStatementsInDmlBatch(t *testing.T) {
670670
logger: noopLogger,
671671
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
672672
batch: &batch{tp: parser.BatchTypeDml},
673-
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
673+
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
674674
return &spanner.RowIterator{}
675675
},
676676
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
@@ -761,7 +761,7 @@ func TestConn_GetCommitResponseAfterAutocommitDml(t *testing.T) {
761761
parser: p,
762762
logger: noopLogger,
763763
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
764-
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
764+
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
765765
return &spanner.RowIterator{}
766766
},
767767
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
@@ -800,7 +800,7 @@ func TestConn_GetCommitResponseAfterAutocommitQuery(t *testing.T) {
800800
parser: p,
801801
logger: noopLogger,
802802
state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
803-
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
803+
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
804804
return &spanner.RowIterator{}
805805
},
806806
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {

driver_with_mockserver_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5384,6 +5384,73 @@ func TestReturnResultSetStats(t *testing.T) {
53845384
}
53855385
}
53865386

5387+
func TestReturnResultSetStatsForQuery(t *testing.T) {
5388+
t.Parallel()
5389+
5390+
db, server, teardown := setupTestDBConnection(t)
5391+
defer teardown()
5392+
query := "select id from singers where id=42598"
5393+
resultSet := testutil.CreateSingleColumnInt64ResultSet([]int64{42598}, "id")
5394+
_ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{
5395+
Type: testutil.StatementResultResultSet,
5396+
ResultSet: resultSet,
5397+
})
5398+
5399+
rows, err := db.QueryContext(context.Background(), query, ExecOptions{ReturnResultSetStats: true})
5400+
if err != nil {
5401+
t.Fatal(err)
5402+
}
5403+
defer func() { _ = rows.Close() }()
5404+
5405+
// The first result set should contain the data.
5406+
for want := int64(42598); rows.Next(); want++ {
5407+
cols, err := rows.Columns()
5408+
if err != nil {
5409+
t.Fatal(err)
5410+
}
5411+
if !cmp.Equal(cols, []string{"id"}) {
5412+
t.Fatalf("cols mismatch\nGot: %v\nWant: %v", cols, []string{"id"})
5413+
}
5414+
var got int64
5415+
err = rows.Scan(&got)
5416+
if err != nil {
5417+
t.Fatal(err)
5418+
}
5419+
if got != want {
5420+
t.Fatalf("value mismatch\nGot: %v\nWant: %v", got, want)
5421+
}
5422+
}
5423+
if rows.Err() != nil {
5424+
t.Fatal(rows.Err())
5425+
}
5426+
5427+
// The next result set should contain the stats.
5428+
if !rows.NextResultSet() {
5429+
t.Fatal("missing stats result set")
5430+
}
5431+
5432+
// Get the stats.
5433+
if !rows.Next() {
5434+
t.Fatal("no stats rows")
5435+
}
5436+
var stats *sppb.ResultSetStats
5437+
if err := rows.Scan(&stats); err != nil {
5438+
t.Fatalf("failed to scan stats: %v", err)
5439+
}
5440+
// The stats should not contain any update count.
5441+
if stats.GetRowCount() != nil {
5442+
t.Fatalf("got update count for query")
5443+
}
5444+
if rows.Next() {
5445+
t.Fatal("more rows than expected")
5446+
}
5447+
5448+
// There should be no more result sets.
5449+
if rows.NextResultSet() {
5450+
t.Fatal("more result sets than expected")
5451+
}
5452+
}
5453+
53875454
func TestReturnResultSetMetadataAndStats(t *testing.T) {
53885455
t.Parallel()
53895456

partitioned_query.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"io"
2323

2424
"cloud.google.com/go/spanner"
25+
"github.com/googleapis/go-sql-spanner/parser"
2526
"google.golang.org/grpc/codes"
2627
"google.golang.org/grpc/status"
2728
)
@@ -231,7 +232,7 @@ func (pq *PartitionedQuery) execute(ctx context.Context, index int) (*rows, erro
231232
return nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid partition index: %d", index))
232233
}
233234
spannerIter := pq.tx.Execute(ctx, pq.Partitions[index])
234-
iter := &readOnlyRowIterator{spannerIter}
235+
iter := &readOnlyRowIterator{spannerIter, parser.StatementTypeQuery}
235236
return &rows{it: iter, decodeOption: pq.execOptions.DecodeOption}, nil
236237
}
237238

snippets/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ replace github.com/googleapis/go-sql-spanner => ../
99
require (
1010
cloud.google.com/go/spanner v1.85.1
1111
github.com/docker/docker v28.4.0+incompatible
12-
github.com/googleapis/go-sql-spanner v1.17.0
12+
github.com/googleapis/go-sql-spanner v1.18.0
1313
github.com/testcontainers/testcontainers-go v0.38.0
1414
)
1515

spannerlib/shared/shared_lib_test.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -205,24 +205,14 @@ func TestExecute(t *testing.T) {
205205
t.Fatalf("num rows mismatch\n Got: %v\nWant: %v", g, w)
206206
}
207207

208-
// Get the ResultSetStats. For queries, this is an empty instance.
208+
// Get the ResultSetStats. For queries, this is nil.
209209
mem, code, _, length, data = ResultSetStats(poolId, connId, rowsId)
210210
if g, w := code, int32(0); g != w {
211211
t.Fatalf("ResultSetStats result code mismatch\n Got: %v\nWant: %v", g, w)
212212
}
213-
if length == int32(0) {
214-
t.Fatalf("ResultSetStats length mismatch: %v", length)
215-
}
216-
statsBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes()
217-
stats := &spannerpb.ResultSetStats{}
218-
if err := proto.Unmarshal(statsBytes, stats); err != nil {
219-
t.Fatal(err)
213+
if g, w := length, int32(0); g != w {
214+
t.Fatalf("ResultSetStats length mismatch\n Got: %v\nWant: %v", g, w)
220215
}
221-
// TODO: Enable when this branch is up to date with main
222-
// emptyStats := &spannerpb.ResultSetStats{}
223-
//if g, w := stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) {
224-
// t.Fatalf("ResultSetStats mismatch\n Got: %v\nWant: %v", g, w)
225-
//}
226216
if res := Release(mem); res != 0 {
227217
t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", res, 0)
228218
}

transaction.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ type contextTransaction interface {
4444
Commit() error
4545
Rollback() error
4646
resetForRetry(ctx context.Context) error
47-
Query(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (rowIterator, error)
47+
Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error)
4848
partitionQuery(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (driver.Rows, error)
4949
ExecContext(ctx context.Context, stmt spanner.Statement, statementInfo *parser.StatementInfo, options spanner.QueryOptions) (*result, error)
5050

@@ -67,6 +67,7 @@ var _ rowIterator = &readOnlyRowIterator{}
6767

6868
type readOnlyRowIterator struct {
6969
*spanner.RowIterator
70+
stmtType parser.StatementType
7071
}
7172

7273
func (ri *readOnlyRowIterator) Next() (*spanner.Row, error) {
@@ -82,12 +83,19 @@ func (ri *readOnlyRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
8283
}
8384

8485
func (ri *readOnlyRowIterator) ResultSetStats() *sppb.ResultSetStats {
86+
return createResultSetStats(ri.RowIterator, ri.stmtType)
87+
}
88+
89+
func createResultSetStats(it *spanner.RowIterator, stmtType parser.StatementType) *sppb.ResultSetStats {
8590
// TODO: The Spanner client library should offer an option to get the full
8691
// ResultSetStats, instead of only the RowCount and QueryPlan.
87-
return &sppb.ResultSetStats{
88-
RowCount: &sppb.ResultSetStats_RowCountExact{RowCountExact: ri.RowIterator.RowCount},
89-
QueryPlan: ri.RowIterator.QueryPlan,
92+
stats := &sppb.ResultSetStats{
93+
QueryPlan: it.QueryPlan,
94+
}
95+
if stmtType == parser.StatementTypeDml {
96+
stats.RowCount = &sppb.ResultSetStats_RowCountExact{RowCountExact: it.RowCount}
9097
}
98+
return stats
9199
}
92100

93101
type txResult int
@@ -135,7 +143,7 @@ func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error {
135143
return nil
136144
}
137145

138-
func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (rowIterator, error) {
146+
func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error) {
139147
tx.logger.DebugContext(ctx, "Query", "stmt", stmt.SQL)
140148
if execOptions.PartitionedQueryOptions.AutoPartitionQuery {
141149
if tx.boTx == nil {
@@ -152,7 +160,7 @@ func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement
152160
}
153161
return mi, nil
154162
}
155-
return &readOnlyRowIterator{tx.roTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions)}, nil
163+
return &readOnlyRowIterator{tx.roTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions), stmtType}, nil
156164
}
157165

158166
func (tx *readOnlyTransaction) partitionQuery(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (driver.Rows, error) {
@@ -457,7 +465,7 @@ func (tx *readWriteTransaction) resetForRetry(ctx context.Context) error {
457465
// Query executes a query using the read/write transaction and returns a
458466
// rowIterator that will automatically retry the read/write transaction if the
459467
// transaction is aborted during the query or while iterating the returned rows.
460-
func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (rowIterator, error) {
468+
func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error) {
461469
tx.logger.Debug("Query", "stmt", stmt.SQL)
462470
tx.active = true
463471
if err := tx.maybeRunAutoDmlBatch(ctx); err != nil {
@@ -466,7 +474,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen
466474
// If internal retries have been disabled, we don't need to keep track of a
467475
// running checksum for all results that we have seen.
468476
if !tx.retryAborts() {
469-
return &readOnlyRowIterator{tx.rwTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions)}, nil
477+
return &readOnlyRowIterator{tx.rwTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions), stmtType}, nil
470478
}
471479

472480
// If retries are enabled, we need to use a row iterator that will keep
@@ -477,6 +485,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen
477485
ctx: ctx,
478486
tx: tx,
479487
stmt: stmt,
488+
stmtType: stmtType,
480489
options: execOptions.QueryOptions,
481490
buffer: buffer,
482491
enc: gob.NewEncoder(buffer),

0 commit comments

Comments
 (0)