@@ -2,12 +2,17 @@ package forward
22
33import (
44 "context"
5- "fmt"
65 "github.com/spf13/cobra"
6+ "net/http"
7+ "os"
8+ "os/signal"
9+ "path/filepath"
710 "pbench/log"
811 "pbench/presto"
912 "pbench/utils"
1013 "sync"
14+ "sync/atomic"
15+ "syscall"
1116 "time"
1217)
1318
@@ -17,49 +22,137 @@ var (
1722 RunName string
1823 PollInterval time.Duration
1924
20- runningTasks sync.WaitGroup
25+ runningTasks sync.WaitGroup
26+ failedToForward atomic.Uint32
27+ forwarded atomic.Uint32
2128)
2229
23- type QueryHistory struct {
24- QueryId string `presto:"query_id"`
25- Query string `presto:"query"`
26- Created * time.Time `presto:"created"`
27- }
28-
2930func Run (_ * cobra.Command , _ []string ) {
30- //OutputPath = filepath.Join(OutputPath, RunName)
31- //utils.PrepareOutputDirectory(OutputPath)
32- //
33- //// also start to write logs to the output directory from this point on.
34- //logPath := filepath.Join(OutputPath, "forward.log")
35- //flushLog := utils.InitLogFile(logPath)
36- //defer flushLog()
31+ OutputPath = filepath .Join (OutputPath , RunName )
32+ utils .PrepareOutputDirectory (OutputPath )
3733
38- prestoClusters := PrestoFlagsArray .Assemble ()
34+ // also start to write logs to the output directory from this point on.
35+ logPath := filepath .Join (OutputPath , "forward.log" )
36+ flushLog := utils .InitLogFile (logPath )
37+ defer flushLog ()
38+
39+ ctx , cancel := context .WithCancel (context .Background ())
40+ timeToExit := make (chan os.Signal , 1 )
41+ signal .Notify (timeToExit , syscall .SIGINT , syscall .SIGTERM , syscall .SIGQUIT )
42+ // Handle SIGINT, SIGTERM, and SIGQUIT. When ctx is canceled, in-progress MySQL transactions and InfluxDB operations will roll back.
43+ go func () {
44+ sig := <- timeToExit
45+ if sig != nil {
46+ log .Info ().Msg ("abort forwarding" )
47+ cancel ()
48+ }
49+ }()
50+
51+ prestoClusters := PrestoFlagsArray .Pivot ()
3952 // The design here is to forward the traffic from cluster 0 to the rest.
4053 sourceClusterSize := 0
4154 clients := make ([]* presto.Client , 0 , len (prestoClusters ))
4255 for i , cluster := range prestoClusters {
4356 clients = append (clients , cluster .NewPrestoClient ())
44- if stats , _ , err := clients [i ].GetClusterInfo (context .Background ()); err != nil {
45- log .Fatal ().Err (err ).Msgf ("cannot connect to cluster at position %d" , i )
57+ // Check if we can connect to the cluster.
58+ if stats , _ , err := clients [i ].GetClusterInfo (ctx ); err != nil {
59+ log .Fatal ().Err (err ).Msgf ("cannot connect to cluster at position %d: %s" , i , cluster .ServerUrl )
4660 } else if i == 0 {
4761 sourceClusterSize = stats .ActiveWorkers
4862 } else if stats .ActiveWorkers != sourceClusterSize {
49- log .Warn ().Msgf ("source cluster size does not match target cluster %d size (%d != %d)" , i , stats .ActiveWorkers , sourceClusterSize )
63+ log .Warn ().Msgf ("the source cluster and target cluster %d do not match in size (%d != %d)" , i , sourceClusterSize , stats .ActiveWorkers )
5064 }
5165 }
5266
5367 sourceClient := clients [0 ]
5468 trueValue := true
55- states , _ , err := sourceClient .GetQueryState (context .Background (), & presto.GetQueryStatsOptions {
56- IncludeAllQueries : & trueValue ,
57- IncludeAllQueryProgressStats : nil ,
58- ExcludeResourceGroupPathInfo : nil ,
59- QueryTextSizeLimit : nil ,
60- })
61- if err != nil {
62- log .Fatal ().Err (err ).Msgf ("cannot get query states" )
69+ // lastQueryStateCheckCutoffTime is the query create time of the most recent query in the previous batch.
70+ // We only look at queries created later than this timestamp in the following batch.
71+ lastQueryStateCheckCutoffTime := time.Time {}
72+ // Keep running until the source cluster becomes unavailable or the user interrupts or quits using Ctrl + C or Ctrl + D.
73+ for ctx .Err () == nil {
74+ states , _ , err := sourceClient .GetQueryState (ctx , & presto.GetQueryStatsOptions {IncludeAllQueries : & trueValue })
75+ if err != nil {
76+ log .Error ().Err (err ).Msgf ("failed to get query states" )
77+ break
78+ }
79+ newCutoffTime := time.Time {}
80+ for _ , state := range states {
81+ if ! state .CreateTime .After (lastQueryStateCheckCutoffTime ) {
82+ // We looked at this query in the previous batch.
83+ continue
84+ }
85+ if newCutoffTime .Before (state .CreateTime ) {
86+ newCutoffTime = state .CreateTime
87+ }
88+ runningTasks .Add (1 )
89+ go forwardQuery (ctx , & state , clients )
90+ }
91+ if newCutoffTime .After (lastQueryStateCheckCutoffTime ) {
92+ lastQueryStateCheckCutoffTime = newCutoffTime
93+ }
94+ timer := time .NewTimer (PollInterval )
95+ select {
96+ case <- ctx .Done ():
97+ case <- timer .C :
98+ }
99+ }
100+ runningTasks .Wait ()
101+ // This causes the signal handler to exit.
102+ close (timeToExit )
103+ log .Info ().Uint32 ("forwarded" , forwarded .Load ()).Uint32 ("failed_to_forward" , failedToForward .Load ()).
104+ Msgf ("finished forwarding queries" )
105+ }
106+
107+ func forwardQuery (ctx context.Context , queryState * presto.QueryStateInfo , clients []* presto.Client ) {
108+ defer runningTasks .Done ()
109+ queryInfo , _ , queryInfoErr := clients [0 ].GetQueryInfo (ctx , queryState .QueryId , false , nil )
110+ if queryInfoErr != nil {
111+ log .Error ().Str ("query_id" , queryState .QueryId ).Err (queryInfoErr ).Msg ("failed to get query info for forwarding" )
112+ failedToForward .Add (1 )
113+ return
114+ }
115+ SessionPropertyHeader := clients [0 ].GenerateSessionParamsHeaderValue (queryInfo .Session .CollectSessionProperties ())
116+ successful , failed := atomic.Uint32 {}, atomic.Uint32 {}
117+ forwardedQueries := sync.WaitGroup {}
118+ for i := 1 ; i < len (clients ); i ++ {
119+ forwardedQueries .Add (1 )
120+ go func (client * presto.Client ) {
121+ defer forwardedQueries .Done ()
122+ clientResult , _ , queryErr := client .Query (ctx , queryInfo .Query , func (req * http.Request ) {
123+ if queryInfo .Session .Catalog != nil {
124+ req .Header .Set (presto .CatalogHeader , * queryInfo .Session .Catalog )
125+ }
126+ if queryInfo .Session .Schema != nil {
127+ req .Header .Set (presto .SchemaHeader , * queryInfo .Session .Schema )
128+ }
129+ req .Header .Set (presto .SessionHeader , SessionPropertyHeader )
130+ req .Header .Set (presto .SourceHeader , queryInfo .QueryId )
131+ })
132+ if queryErr != nil {
133+ log .Error ().Str ("source_query_id" , queryInfo .QueryId ).
134+ Str ("target_host" , client .GetHost ()).Err (queryErr ).Msg ("failed to execute query" )
135+ failed .Add (1 )
136+ return
137+ }
138+ rowCount := 0
139+ drainErr := clientResult .Drain (ctx , func (qr * presto.QueryResults ) error {
140+ rowCount += len (qr .Data )
141+ return nil
142+ })
143+ if drainErr != nil {
144+ log .Error ().Str ("source_query_id" , queryInfo .QueryId ).
145+ Str ("target_host" , client .GetHost ()).Err (drainErr ).Msg ("failed to fetch query result" )
146+ failed .Add (1 )
147+ return
148+ }
149+ successful .Add (1 )
150+ log .Info ().Str ("source_query_id" , queryInfo .QueryId ).
151+ Str ("target_host" , client .GetHost ()).Int ("row_count" , rowCount ).Msg ("query executed successfully" )
152+ }(clients [i ])
63153 }
64- fmt .Printf ("%#v" , states )
154+ forwardedQueries .Wait ()
155+ log .Info ().Str ("source_query_id" , queryInfo .QueryId ).Uint32 ("successful" , successful .Load ()).
156+ Uint32 ("failed" , failed .Load ()).Msg ("query forwarding finished" )
157+ forwarded .Add (1 )
65158}
0 commit comments