@@ -17,6 +17,7 @@ package api
1717import (
1818 "context"
1919 "database/sql"
20+ "database/sql/driver"
2021 "fmt"
2122 "strings"
2223 "sync"
@@ -25,6 +26,9 @@ import (
2526 "cloud.google.com/go/spanner"
2627 "cloud.google.com/go/spanner/apiv1/spannerpb"
2728 spannerdriver "github.com/googleapis/go-sql-spanner"
29+ "google.golang.org/grpc/codes"
30+ "google.golang.org/grpc/status"
31+ "google.golang.org/protobuf/types/known/timestamppb"
2832)
2933
3034// CloseConnection looks up the connection with the given poolId and connId and closes it.
@@ -42,6 +46,35 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error {
4246 return conn .close (ctx )
4347}
4448
49+ // BeginTransaction starts a new transaction on the given connection.
50+ // A connection can have at most one transaction at any time. This function therefore returns an error if the
51+ // connection has an active transaction.
52+ func BeginTransaction (ctx context.Context , poolId , connId int64 , txOpts * spannerpb.TransactionOptions ) error {
53+ conn , err := findConnection (poolId , connId )
54+ if err != nil {
55+ return err
56+ }
57+ return conn .BeginTransaction (ctx , txOpts )
58+ }
59+
60+ // Commit commits the current transaction on the given connection.
61+ func Commit (ctx context.Context , poolId , connId int64 ) (* spannerpb.CommitResponse , error ) {
62+ conn , err := findConnection (poolId , connId )
63+ if err != nil {
64+ return nil , err
65+ }
66+ return conn .commit (ctx )
67+ }
68+
69+ // Rollback rollbacks the current transaction on the given connection.
70+ func Rollback (ctx context.Context , poolId , connId int64 ) error {
71+ conn , err := findConnection (poolId , connId )
72+ if err != nil {
73+ return err
74+ }
75+ return conn .rollback (ctx )
76+ }
77+
4578func Execute (ctx context.Context , poolId , connId int64 , executeSqlRequest * spannerpb.ExecuteSqlRequest ) (int64 , error ) {
4679 conn , err := findConnection (poolId , connId )
4780 if err != nil {
@@ -59,23 +92,141 @@ type Connection struct {
5992 backend * sql.Conn
6093}
6194
95+ // spannerConn is an internal interface that contains the internal functions that are used by this API.
96+ // It is implemented by the spannerdriver.conn struct.
97+ type spannerConn interface {
98+ BeginReadOnlyTransaction (ctx context.Context , options * spannerdriver.ReadOnlyTransactionOptions ) (driver.Tx , error )
99+ BeginReadWriteTransaction (ctx context.Context , options * spannerdriver.ReadWriteTransactionOptions ) (driver.Tx , error )
100+ Commit (ctx context.Context ) (* spanner.CommitResponse , error )
101+ Rollback (ctx context.Context ) error
102+ }
103+
62104type queryExecutor interface {
63105 ExecContext (ctx context.Context , query string , args ... any ) (sql.Result , error )
64106 QueryContext (ctx context.Context , query string , args ... any ) (* sql.Rows , error )
65107}
66108
67109func (conn * Connection ) close (ctx context.Context ) error {
68110 conn .closeResults (ctx )
111+ // Rollback any open transactions on the connection.
112+ _ = conn .rollback (ctx )
113+
69114 err := conn .backend .Close ()
70115 if err != nil {
71116 return err
72117 }
73118 return nil
74119}
75120
121+ func (conn * Connection ) BeginTransaction (ctx context.Context , txOpts * spannerpb.TransactionOptions ) error {
122+ var err error
123+ if txOpts .GetReadOnly () != nil {
124+ return conn .beginReadOnlyTransaction (ctx , convertToReadOnlyOpts (txOpts ))
125+ } else if txOpts .GetPartitionedDml () != nil {
126+ err = spanner .ToSpannerError (status .Error (codes .InvalidArgument , "transaction type not supported" ))
127+ } else {
128+ return conn .beginReadWriteTransaction (ctx , convertToReadWriteTransactionOptions (txOpts ))
129+ }
130+ if err != nil {
131+ return err
132+ }
133+ return nil
134+ }
135+
136+ func (conn * Connection ) beginReadOnlyTransaction (ctx context.Context , opts * spannerdriver.ReadOnlyTransactionOptions ) error {
137+ return conn .backend .Raw (func (driverConn any ) (err error ) {
138+ sc , _ := driverConn .(spannerConn )
139+ _ , err = sc .BeginReadOnlyTransaction (ctx , opts )
140+ return err
141+ })
142+ }
143+
144+ func (conn * Connection ) beginReadWriteTransaction (ctx context.Context , opts * spannerdriver.ReadWriteTransactionOptions ) error {
145+ return conn .backend .Raw (func (driverConn any ) (err error ) {
146+ sc , _ := driverConn .(spannerConn )
147+ _ , err = sc .BeginReadWriteTransaction (ctx , opts )
148+ return err
149+ })
150+ }
151+
152+ func (conn * Connection ) commit (ctx context.Context ) (* spannerpb.CommitResponse , error ) {
153+ var response * spanner.CommitResponse
154+ if err := conn .backend .Raw (func (driverConn any ) (err error ) {
155+ spannerConn , _ := driverConn .(spannerConn )
156+ response , err = spannerConn .Commit (ctx )
157+ if err != nil {
158+ return err
159+ }
160+ return nil
161+ }); err != nil {
162+ return nil , err
163+ }
164+
165+ // The commit response is nil for read-only transactions.
166+ if response == nil {
167+ return nil , nil
168+ }
169+ // TODO: Include commit stats
170+ return & spannerpb.CommitResponse {CommitTimestamp : timestamppb .New (response .CommitTs )}, nil
171+ }
172+
173+ func (conn * Connection ) rollback (ctx context.Context ) error {
174+ return conn .backend .Raw (func (driverConn any ) (err error ) {
175+ spannerConn , _ := driverConn .(spannerConn )
176+ return spannerConn .Rollback (ctx )
177+ })
178+ }
179+
180+ func convertToReadOnlyOpts (txOpts * spannerpb.TransactionOptions ) * spannerdriver.ReadOnlyTransactionOptions {
181+ return & spannerdriver.ReadOnlyTransactionOptions {
182+ TimestampBound : convertTimestampBound (txOpts ),
183+ }
184+ }
185+
186+ func convertTimestampBound (txOpts * spannerpb.TransactionOptions ) spanner.TimestampBound {
187+ ro := txOpts .GetReadOnly ()
188+ if ro .GetStrong () {
189+ return spanner .StrongRead ()
190+ } else if ro .GetReadTimestamp () != nil {
191+ return spanner .ReadTimestamp (ro .GetReadTimestamp ().AsTime ())
192+ } else if ro .GetMinReadTimestamp () != nil {
193+ return spanner .ReadTimestamp (ro .GetMinReadTimestamp ().AsTime ())
194+ } else if ro .GetExactStaleness () != nil {
195+ return spanner .ExactStaleness (ro .GetExactStaleness ().AsDuration ())
196+ } else if ro .GetMaxStaleness () != nil {
197+ return spanner .MaxStaleness (ro .GetMaxStaleness ().AsDuration ())
198+ }
199+ return spanner.TimestampBound {}
200+ }
201+
202+ func convertToReadWriteTransactionOptions (txOpts * spannerpb.TransactionOptions ) * spannerdriver.ReadWriteTransactionOptions {
203+ readLockMode := spannerpb .TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED
204+ if txOpts .GetReadWrite () != nil {
205+ readLockMode = txOpts .GetReadWrite ().GetReadLockMode ()
206+ }
207+ return & spannerdriver.ReadWriteTransactionOptions {
208+ TransactionOptions : spanner.TransactionOptions {
209+ IsolationLevel : txOpts .GetIsolationLevel (),
210+ ReadLockMode : readLockMode ,
211+ },
212+ }
213+ }
214+
215+ func convertIsolationLevel (level spannerpb.TransactionOptions_IsolationLevel ) sql.IsolationLevel {
216+ switch level {
217+ case spannerpb .TransactionOptions_SERIALIZABLE :
218+ return sql .LevelSerializable
219+ case spannerpb .TransactionOptions_REPEATABLE_READ :
220+ return sql .LevelRepeatableRead
221+ }
222+ return sql .LevelDefault
223+ }
224+
76225func (conn * Connection ) closeResults (ctx context.Context ) {
77226 conn .results .Range (func (key , value interface {}) bool {
78- // TODO: Implement
227+ if r , ok := value .(* rows ); ok {
228+ _ = r .Close (ctx )
229+ }
79230 return true
80231 })
81232}
0 commit comments