Skip to content

Commit 6fc68b8

Browse files
author
Hoang Phan
committed
Change to using stateful source instead
1 parent 53d008c commit 6fc68b8

File tree

4 files changed

+214
-130
lines changed

4 files changed

+214
-130
lines changed

python/sources/mysql_cdc/main.py

Lines changed: 211 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,173 +1,257 @@
11
from quixstreams import Application
2+
from quixstreams.sources.base import StatefulSource
23
import time
34
import os
45
import json
56
from setup_logger import logger
6-
from mysql_helper import connect_mysql, enable_binlog_if_needed, setup_mysql_cdc, create_binlog_stream, get_changes, perform_initial_snapshot, save_binlog_position
7+
from mysql_helper import connect_mysql, enable_binlog_if_needed, setup_mysql_cdc, create_binlog_stream, get_changes, perform_initial_snapshot
78

89
# Load environment variables (useful when working locally)
910
from dotenv import load_dotenv
1011
load_dotenv()
1112

12-
# Global Variables
13-
MYSQL_SCHEMA = os.environ["MYSQL_SCHEMA"] # MySQL database name
14-
MYSQL_TABLE = os.environ["MYSQL_TABLE"] # MySQL table name
15-
MYSQL_TABLE_NAME = f"{MYSQL_SCHEMA}.{MYSQL_TABLE}"
16-
WAIT_INTERVAL = 0.1
17-
18-
# Initial snapshot configuration
19-
INITIAL_SNAPSHOT = os.getenv("INITIAL_SNAPSHOT", "false").lower() == "true"
20-
SNAPSHOT_BATCH_SIZE = int(os.getenv("SNAPSHOT_BATCH_SIZE", "1000"))
21-
FORCE_SNAPSHOT = os.getenv("FORCE_SNAPSHOT", "false").lower() == "true"
22-
23-
# State management - use Quix state dir if available, otherwise default to "state"
24-
STATE_DIR = os.getenv("Quix__State__Dir", "state")
25-
SNAPSHOT_STATE_FILE = os.path.join(STATE_DIR, f"snapshot_completed_{MYSQL_SCHEMA}_{MYSQL_TABLE}.flag")
13+
class MySqlCdcSource(StatefulSource):
    """Stateful Quix Streams source that captures MySQL change-data-capture
    (CDC) events from the binlog and produces them to Kafka.

    Snapshot-completion flags and binlog positions are kept in the source's
    state store (``self.state``), so a restarted source resumes where the
    previous run left off instead of re-snapshotting.
    """

    def __init__(self, name: str = "mysql-cdc-source"):
        super().__init__(name=name)

        # Load configuration from environment variables.
        self.mysql_schema = os.environ["MYSQL_SCHEMA"]  # MySQL database name
        self.mysql_table = os.environ["MYSQL_TABLE"]  # MySQL table name
        self.mysql_table_name = f"{self.mysql_schema}.{self.mysql_table}"
        self.wait_interval = 0.1  # seconds between binlog polls

        # Initial snapshot configuration.
        self.initial_snapshot = os.getenv("INITIAL_SNAPSHOT", "false").lower() == "true"
        self.snapshot_batch_size = int(os.getenv("SNAPSHOT_BATCH_SIZE", "1000"))
        self.force_snapshot = os.getenv("FORCE_SNAPSHOT", "false").lower() == "true"

        # Connection objects - will be initialized in setup().
        self.conn = None
        self.binlog_stream = None

        # Message buffering: changes are batched for up to flush_interval
        # seconds to reduce network traffic.
        self.buffer = []
        self.last_flush_time = time.time()
        self.flush_interval = 0.5  # 500ms

    def _state_key(self, prefix):
        """Build the per-table state-store key '<prefix>_<schema>_<table>'."""
        return f"{prefix}_{self.mysql_schema}_{self.mysql_table}"

    def setup(self):
        """Initialize the MySQL connection and CDC binlog stream.

        Raises:
            Exception: re-raised after logging if any setup step fails.
        """
        try:
            enable_binlog_if_needed()
            setup_mysql_cdc(self.mysql_table)
            self.conn = connect_mysql()
            self.binlog_stream = create_binlog_stream()
            logger.info("MySQL CDC CONNECTED!")
        except Exception as e:
            logger.error(f"ERROR during MySQL CDC setup - {e}")
            raise

    def is_snapshot_completed(self):
        """Return True if the initial snapshot was already completed and
        FORCE_SNAPSHOT is not set, according to the state store."""
        return self.state.get(self._state_key("snapshot_completed"), False) and not self.force_snapshot

    def mark_snapshot_completed(self):
        """Mark initial snapshot as completed in state store."""
        snapshot_info = {
            "completed_at": time.time(),
            "schema": self.mysql_schema,
            "table": self.mysql_table,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
        }
        self.state.set(self._state_key("snapshot_completed"), True)
        self.state.set(self._state_key("snapshot_info"), snapshot_info)
        logger.info(f"Snapshot completion marked in state store for {self.mysql_table_name}")

    def get_snapshot_info(self):
        """Return the stored snapshot-completion metadata dict, or None."""
        return self.state.get(self._state_key("snapshot_info"), None)

    def save_binlog_position(self, log_file, log_pos):
        """Persist the current binlog position to the state store."""
        position_info = {
            "log_file": log_file,
            "log_pos": log_pos,
            "timestamp": time.time()
        }
        self.state.set(self._state_key("binlog_position"), position_info)

    def get_binlog_position(self):
        """Return the saved binlog position dict, or None if never saved."""
        return self.state.get(self._state_key("binlog_position"), None)

    def perform_initial_snapshot_if_needed(self):
        """Perform the initial table snapshot if enabled and not yet done.

        Sends every snapshot row to Kafka, flushes, then marks completion in
        the state store and flushes again so the flag itself is committed.

        Raises:
            Exception: re-raised after logging if the snapshot fails.
        """
        if not self.initial_snapshot:
            logger.info("Initial snapshot is disabled - starting CDC stream only")
            return

        if self.is_snapshot_completed():
            snapshot_info = self.get_snapshot_info()
            if self.force_snapshot:
                logger.info("Initial snapshot already completed but FORCE_SNAPSHOT=true - performing snapshot again...")
            else:
                # Guard: the completed-flag may exist while the info record is
                # missing, so never call .get() on a possible None.
                completed_at = (snapshot_info or {}).get('timestamp', 'unknown time')
                logger.info(f"Initial snapshot already completed at {completed_at} - skipping")
                return
        else:
            logger.info("Initial snapshot is enabled and not yet completed - performing snapshot...")

        if not self.is_snapshot_completed() or self.force_snapshot:
            try:
                snapshot_changes = perform_initial_snapshot(
                    self.mysql_schema,
                    self.mysql_table,
                    self.snapshot_batch_size
                )

                # Send snapshot data to Kafka immediately
                for change in snapshot_changes:
                    msg = self.serialize(
                        key=self.mysql_table_name,
                        value=change
                    )
                    self.produce(
                        key=msg.key,
                        value=msg.value,
                    )

                # Flush to ensure all snapshot data is sent and commit state
                self.flush()
                logger.info(f"Initial snapshot completed - {len(snapshot_changes)} records sent to Kafka")

                # Mark snapshot as completed
                self.mark_snapshot_completed()
                # Flush again to save the snapshot completion state
                self.flush()

            except Exception as e:
                logger.error(f"Failed to perform initial snapshot: {e}")
                raise

    def process_buffered_messages(self, force=False):
        """Produce buffered CDC messages once the flush interval has passed.

        Args:
            force: when True, flush immediately regardless of the interval.
                Used during shutdown so buffered changes are never dropped.
        """
        current_time = time.time()

        if len(self.buffer) == 0:
            return
        if not force and (current_time - self.last_flush_time) < self.flush_interval:
            return

        logger.debug(f"Processing {len(self.buffer)} buffered messages")

        # Send all buffered messages
        for message in self.buffer:
            msg = self.serialize(
                key=self.mysql_table_name,
                value=message
            )
            self.produce(
                key=msg.key,
                value=msg.value,
            )

        # Save binlog position if available
        if hasattr(self.binlog_stream, 'log_file') and hasattr(self.binlog_stream, 'log_pos'):
            self.save_binlog_position(self.binlog_stream.log_file, self.binlog_stream.log_pos)

        # Flush the producer and commit state changes
        self.flush()

        # Clear the buffer and update flush time
        self.buffer = []
        self.last_flush_time = current_time

        logger.debug("Buffered messages sent and state committed")

    def run(self):
        """Main CDC loop - runs while self.running is True."""
        logger.info(f"Starting MySQL CDC source for {self.mysql_table_name}")

        # Perform initial snapshot if needed
        self.perform_initial_snapshot_if_needed()

        # Log binlog position if available
        saved_position = self.get_binlog_position()
        if saved_position:
            logger.info(f"Resuming from binlog position: {saved_position}")

        # Start CDC loop
        while self.running:
            try:
                # Get changes from MySQL binlog and add them to the buffer
                changes = get_changes(self.binlog_stream, self.mysql_schema, self.mysql_table)
                self.buffer.extend(changes)

                if len(self.buffer) > 0:
                    logger.debug(f"Buffer length: {len(self.buffer)}")

                # Process buffered messages if flush interval has passed
                self.process_buffered_messages()

                # Small sleep to prevent excessive CPU usage
                time.sleep(self.wait_interval)

            except Exception as e:
                logger.error(f"Error in CDC loop: {e}")
                # Still continue running unless it's a fatal error
                time.sleep(1)  # Wait a bit longer on error

    def stop(self):
        """Flush any remaining buffered messages and close MySQL resources."""
        logger.info("Stopping MySQL CDC source")

        # Process any remaining buffered messages.  force=True bypasses the
        # flush-interval check; otherwise messages buffered less than
        # flush_interval seconds ago would be silently lost on shutdown.
        if len(self.buffer) > 0:
            logger.info(f"Processing {len(self.buffer)} remaining buffered messages")
            self.process_buffered_messages(force=True)

        # Clean up connections
        if self.conn:
            self.conn.close()
            logger.info("MySQL connection closed")

        if self.binlog_stream:
            self.binlog_stream.close()
            logger.info("Binlog stream closed")

        super().stop()
160221

161-
if __name__ == "__main__":
222+
def main():
    """Wire the MySQL CDC source into a Quix Streams application and run it.

    Raises:
        ValueError: if the ``output`` environment variable is not set.
    """
    # Create a Quix Application, this manages the connection to the Quix platform
    app = Application()

    # Check the output topic is configured
    output_topic_name = os.getenv("output", "")
    if output_topic_name == "":
        raise ValueError("output_topic environment variable is required")
    # Register the topic with the application: StreamingDataFrame.to_topic
    # expects a Topic object created via app.topic(), not a bare name string.
    output_topic = app.topic(output_topic_name)

    # Create the MySQL CDC source
    mysql_source = MySqlCdcSource(name="mysql-cdc-source")

    # Create a StreamingDataFrame from the source
    sdf = app.dataframe(source=mysql_source)

    # Send CDC data to output topic
    sdf = sdf.to_topic(output_topic)

    # Run the application
    try:
        logger.info("Starting MySQL CDC application")
        app.run()
    except KeyboardInterrupt:
        logger.info("Application interrupted by user")
    except Exception as e:
        logger.error(f"Application error: {e}")
        raise
    finally:
        logger.info("MySQL CDC application stopped")


if __name__ == "__main__":
    main()

python/sources/mysql_cdc/mysql_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def perform_initial_snapshot(schema_name: str, table_name: str, batch_size: int
286286
processed_rows += len(rows)
287287
offset += batch_size
288288

289-
if processed_rows % 10000 == 0: # Log progress every 10k rows
289+
if processed_rows % 50000 == 0: # Log progress every 50k rows
290290
logger.info(f"Snapshot progress: {processed_rows}/{total_rows} rows processed")
291291

292292
logger.info(f"Initial snapshot completed: {processed_rows} rows captured")

0 commit comments

Comments
 (0)