diff --git a/.gitignore b/.gitignore index 994f80e8..388e744a 100644 --- a/.gitignore +++ b/.gitignore @@ -28,8 +28,13 @@ checkpoints.json # CSV Files (except symbols.csv) *.csv !symbols.csv + +*.png +*.txt +*.py !all_symbols.csv + # Environments .env .venv
diff --git a/.sample.env b/.sample.env index 61aec1dd..d7c3d562 100644 --- a/.sample.env +++ b/.sample.env @@ -204,4 +204,3 @@ CSRF_TIME_LIMIT = '' # Examples: 'instance1_session', 'user1_session', 'app_session', etc. SESSION_COOKIE_NAME = 'session' CSRF_COOKIE_NAME = 'csrf_token' -
diff --git a/Live Trading.docx b/Live Trading.docx new file mode 100644 index 00000000..9bf29511 Binary files /dev/null and b/Live Trading.docx differ
diff --git a/blueprints/apikey.py b/blueprints/apikey.py index 8807dcf8..cfd5bf5e 100644 --- a/blueprints/apikey.py +++ b/blueprints/apikey.py @@ -39,7 +39,8 @@ def manage_api_key(): # Generate new API key api_key = generate_api_key() - + logger.info(f"Generated new API key for user: {user_id}") + # Store the API key (auth_db will handle both hashing and encryption) key_id = upsert_api_key(user_id, api_key)
diff --git a/blueprints/core.py b/blueprints/core.py index 8f9c27d7..ba547123 100644 --- a/blueprints/core.py +++ b/blueprints/core.py @@ -44,6 +44,7 @@ def setup(): # Automatically generate and save API key api_key = generate_api_key() + logger.info(f"Generated new API key for user: {username}") key_id = upsert_api_key(username, api_key) if not key_id: logger.error(f"Failed to create API key for user {username}")
diff --git a/broker/flattrade/api/data.py b/broker/flattrade/api/data.py index 11b2c415..7358450d 100644 --- a/broker/flattrade/api/data.py +++ b/broker/flattrade/api/data.py @@ -41,7 +41,7 @@ def get_api_response(endpoint, auth, method="POST", payload=None): data = response.text # Print raw response for debugging - logger.info(f"Raw Response: {data}") + # logger.info(f"Raw Response: {data}") try: return json.loads(data)
diff --git a/broker/flattrade/streaming/flattrade_adapter.py b/broker/flattrade/streaming/flattrade_adapter.py index cbe3be2c..2c36f60e 100644 --- a/broker/flattrade/streaming/flattrade_adapter.py +++ b/broker/flattrade/streaming/flattrade_adapter.py @@ -228,6 +228,14 @@ def normalize(data: Dict[str, Any], msg_type: str) -> Dict[str, Any]: return result +# RedPanda/Kafka imports +try: + from kafka import KafkaProducer + from kafka.errors import KafkaError + KAFKA_AVAILABLE = True +except ImportError: + KAFKA_AVAILABLE = False + class FlattradeWebSocketAdapter(BaseBrokerWebSocketAdapter): """Flattrade WebSocket adapter with improved structure and error handling""" @@ -256,8 +264,17 @@ def _setup_connection_management(self): """Initialize connection management""" self.running = False self.connected = False + + # Kafka/RedPanda specific attributes + self.kafka_enabled = self.redpanda_enabled and KAFKA_AVAILABLE + self.kafka_publish_lock = threading.Lock() self.lock = threading.Lock() self.reconnect_attempts = 0 + + if self.kafka_enabled: + self.logger.info("Kafka publishing enabled for Flattrade adapter") + else: + self.logger.info("Kafka publishing disabled - using ZMQ only") def _setup_normalizers(self): """Initialize data normalizers""" @@ -567,7 +584,7 @@ def _resubscribe_all(self): elif mode == Config.MODE_DEPTH: depth_scrips.add(scrip) self.ws_subscription_refs[scrip]['depth_count'] += 1 - + # Resubscribe in batches if touchline_scrips: scrip_list = '#'.join(touchline_scrips) @@ -603,6 +620,27 @@ def _on_message(self, ws, message): self.logger.error(f"JSON decode error: {e},
message: {message}") except Exception as e: self.logger.error(f"Message processing error: {e}", exc_info=True) + + def publish_market_data(self, topic: str, data: dict) -> None: + # --- 1) ZMQ Publish (existing) --- + if self.zmq_enabled: + self.logger.info(f"[ZMQ PUBLISH] Topic: {topic} | Data: {data}") + self.logger.debug(f"[DEBUG] ZMQ publish call for topic: {topic}, data keys: {list(data.keys())}") + super().publish_market_data(topic, data) + + # --- 2) Kafka Publish (new) --- + if self.kafka_enabled and self.kafka_producer: + try: + with self.kafka_publish_lock: + # The KafkaProducer was set up in BaseBrokerWebSocketAdapter._setup_redpanda() + # We assume the topic already exists or auto‐creation is enabled. + # You can also prefix or map your ZMQ topic to a Kafka topic namespace here. + self.kafka_producer.send("tick_data", key=topic, value=data) + # Optionally flush immediately (costly): + # self.kafka_producer.flush() + self.logger.info(f"[KAFKA PUBLISH] Topic: {"tick_data"} | Key:{topic} | Data: {data}") + except KafkaError as e: + self.logger.error(f"[KAFKA ERROR] Failed to publish to Kafka topic {"tick_data"}:{topic}: {e}") def _process_market_message(self, data: Dict[str, Any]) -> None: """Process market data messages with better error handling""" diff --git a/broker/flattrade/streaming/flattrade_dummy_api.py b/broker/flattrade/streaming/flattrade_dummy_api.py new file mode 100644 index 00000000..43d23eb2 --- /dev/null +++ b/broker/flattrade/streaming/flattrade_dummy_api.py @@ -0,0 +1,202 @@ +import asyncio +import websockets +import json +import random +from datetime import datetime + +# Mock database +valid_tokens = { + "valid_token_123": {"user_id": "FZ15709", "client_id": "74eb594de4a944558aeacd623a714d16"} +} + +# Initialize market data variables +last_price = 1300.0 +volume = 300 +open_price = 1296.0 +high = 1302.0 +low = 1294.0 +close = 1299.5 + +last_price_2 = 425.0 +volume_2 = 1000 +open_price_2 = 422.0 +high_2 = 430.0 +low_2 = 420.0 +close_2 = 426.0 + +async def handle_connection(websocket): + print("Client connected") + authenticated = False + subscribed = False + + try: + # Create a task for sending market data + market_data_task = None + + async for message in websocket: + try: + data = json.loads(message) + print(f"Received: {data}") + + # Authentication handling + if data.get("t") == "c" and not authenticated: + response = { + "t": "ck", + "s": "OK", + "uid": "FZ15709" + } + authenticated = True + await websocket.send(json.dumps(response)) + print("Authentication successful") + + # Subscription handling + elif data.get("t") == "t" and authenticated and not subscribed: + response = { + "t": "tk", + "e": "NSE", + "tk": "2885", + #"ts": "RELIANCE-EQ", + "ti": "1", + "ls": "1", + "lp": str(last_price), + "pc": "0.5", + "v": str(volume), + "o": str(open_price), + "h": str(high), + "l": str(low), + "c": str(close), + "ap": str(last_price) + } + response_2 = { + "t": "tk", + "e": "NSE", + "tk": "11536", + "ts": "TCS-EQ", + "ti": "1", + "ls": "1", + "lp": str(last_price_2), + "pc": "0.5", + "v": str(volume_2), + "o": str(open_price_2), + "h": str(high_2), + "l": str(low_2), + "c": str(close_2), + "ap": str(last_price_2) + } + subscribed = True + await websocket.send(json.dumps(response)) + await websocket.send(json.dumps(response_2)) + print("Subscribed") + + # Start sending market data after subscription + market_data_task = asyncio.create_task(send_market_data(websocket)) + + elif data.get("t") == "u" and subscribed: + if market_data_task: + market_data_task.cancel() + 
try: + await market_data_task + except asyncio.CancelledError: + pass + + response = { + "t": "uk", + "k": "NSE|2885#NSE|11536" + } + subscribed = False + await websocket.send(json.dumps(response)) + print("Unsubscribed") + + + # Unknown message type + else: + if not authenticated: + response = { + "t": "error", + "emsg": "Not authenticated" + } + await websocket.send(json.dumps(response)) + print("Not Authenticated") + + except json.JSONDecodeError: + response = { + "t": "error", + "emsg": "Invalid JSON" + } + await websocket.send(json.dumps(response)) + + except websockets.exceptions.ConnectionClosed: + print("Client disconnected") + if market_data_task: + market_data_task.cancel() + +async def send_market_data(websocket): + """Continuously send market data updates""" + global last_price, volume, open_price, high, low, close + global last_price_2, volume_2, open_price_2, high_2, low_2, close_2 + + print("Starting market data stream") + while True: + try: + # Update market data + last_price += round(random.uniform(-2, 2), 2) + volume += random.randint(100, 1000) + open_price += round(random.uniform(-2, 2), 2) + high += round(random.uniform(-2, 2), 2) + low += round(random.uniform(-2, 2), 2) + close += round(random.uniform(-2, 2), 2) + + last_price_2 += round(random.uniform(-2, 2), 2) + volume_2 += random.randint(100, 1000) + open_price_2 += round(random.uniform(-2, 2), 2) + high_2 += round(random.uniform(-2, 2), 2) + low_2 += round(random.uniform(-2, 2), 2) + close_2 += round(random.uniform(-2, 2), 2) + + touchline_data = { + "t": "tf", + "e": "NSE", + "tk": "2885", + "lp": str(last_price), + "pc": "0.5", + "v": str(volume), + "o": str(open_price), + "h": str(high), + "l": str(low), + "c": str(close), + "ap": str(last_price) + } + + touchline_data_2 = { + "t": "tf", + "e": "NSE", + "tk": "11536", + "lp": str(last_price_2), + "pc": "0.5", + "v": str(volume_2), + "o": str(open_price_2), + "h": str(high_2), + "l": str(low_2), + "c": str(close_2), + "ap": str(last_price_2) + } + + await websocket.send(json.dumps(touchline_data)) + await websocket.send(json.dumps(touchline_data_2)) + print("Sent market data update") + await asyncio.sleep(2) + + except websockets.exceptions.ConnectionClosed: + print("Connection closed while sending market data") + break + except Exception as e: + print(f"Error in market data stream: {e}") + break + +async def main(): + async with websockets.serve(handle_connection, "localhost", 8766): + print("Dummy WebSocket server running on ws://localhost:8766") + await asyncio.Future() # Run forever + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/broker/flattrade/streaming/flattrade_websocket.py b/broker/flattrade/streaming/flattrade_websocket.py index d3a7b715..b6a82b28 100644 --- a/broker/flattrade/streaming/flattrade_websocket.py +++ b/broker/flattrade/streaming/flattrade_websocket.py @@ -15,6 +15,8 @@ class FlattradeWebSocket: # Connection constants WS_URL = "wss://piconnect.flattrade.in/PiConnectWSTp/" + #WS_URL = "ws://localhost:8766" + CONNECTION_TIMEOUT = 15 THREAD_JOIN_TIMEOUT = 5 diff --git a/broker/flattrade/streaming/kafkaconsum.py b/broker/flattrade/streaming/kafkaconsum.py new file mode 100644 index 00000000..490f6838 --- /dev/null +++ b/broker/flattrade/streaming/kafkaconsum.py @@ -0,0 +1,17 @@ +from kafka import KafkaConsumer +import json + +consumer = KafkaConsumer( + #'tick_raw', + 'tick_data', + bootstrap_servers='localhost:9092', + group_id='test-group', + auto_offset_reset='earliest', + 
key_deserializer=lambda k: k.decode('utf-8') if k else None, + value_deserializer=lambda v: json.loads(v.decode('utf-8')) +) + +print("Listening for messages...") + +for msg in consumer: + print(f"{msg.key} => {msg.value}")
diff --git a/broker/flattrade/streaming/kafkaprod.py b/broker/flattrade/streaming/kafkaprod.py new file mode 100644 index 00000000..b9b2f7d7 --- /dev/null +++ b/broker/flattrade/streaming/kafkaprod.py @@ -0,0 +1,20 @@ +from kafka import KafkaProducer +import json + +producer = KafkaProducer( + bootstrap_servers='localhost:9092', + key_serializer=lambda k: k.encode('utf-8'), + value_serializer=lambda v: json.dumps(v).encode('utf-8') +) + +# Sample tick data +tick_data = { + "symbol": "RELIANCE", + "ltp": 2810.50, + "volume": 100, + "timestamp": 1720458436000 +} + +producer.send('tick_raw', key='RELIANCE', value=tick_data) +producer.flush() +print("Tick sent!")
diff --git a/broker/flattrade/streaming/timescaledb.py b/broker/flattrade/streaming/timescaledb.py new file mode 100644 index 00000000..31b7fa05 --- /dev/null +++ b/broker/flattrade/streaming/timescaledb.py @@ -0,0 +1,843 @@ +import psycopg2 +from psycopg2 import sql +from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +from psycopg2.extras import execute_batch +import pytz +import json +import logging +from threading import Lock +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timedelta +import os +import random +from dateutil import parser # For flexible ISO date parsing +import traceback +import argparse +from openalgo import api +import pandas as pd + +from dotenv import load_dotenv + +load_dotenv() + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Check if Kafka is available +try: + from kafka import KafkaConsumer + from kafka.errors import KafkaError + KAFKA_AVAILABLE = True +except ImportError: + KAFKA_AVAILABLE = False + logger.warning("Kafka library not available. 
Install with: pip install kafka-python") + +class TimescaleDBManager: + def __init__(self, dbname=os.getenv('TIMESCALE_DB_NAME'), user=os.getenv('TIMESCALE_DB_USER'), password=os.getenv('TIMESCALE_DB_PASSWORD'), host=os.getenv('TIMESCALE_DB_HOST'), port=os.getenv('TIMESCALE_DB_PORT')): + self.dbname = dbname + self.user = user + self.password = password + self.host = host + self.port = port + self.admin_conn = None + self.app_conn = None + self.logger = logging.getLogger(f"TimeScaleDBManager") + + self.logger.info(f"Initializing TimescaleDB connection to {self.host}:{self.port} as user '{self.user}' for database '{self.dbname}'") + + def _get_admin_connection(self): + """Connection without specifying database (for admin operations)""" + try: + conn = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname='postgres' # Connect to default admin DB + ) + # Set autocommit mode for DDL operations like CREATE DATABASE + conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + return conn + except psycopg2.Error as e: + self.logger.error(f"Failed to connect to PostgreSQL server: {e}") + raise + + def _database_exists(self): + """Check if database exists""" + try: + with self._get_admin_connection() as conn: + with conn.cursor() as cursor: + cursor.execute( + "SELECT 1 FROM pg_database WHERE datname = %s", + (self.dbname,) + ) + return cursor.fetchone() is not None + except Exception as e: + self.logger.error(f"Error checking database existence: {e}") + return False + + + def _create_database(self): + """Create new database with TimescaleDB extension""" + try: + self.logger.info(f"Creating database '{self.dbname}'...") + + # Create database with autocommit connection + conn = self._get_admin_connection() + try: + with conn.cursor() as cursor: + # Create database + cursor.execute( + sql.SQL("CREATE DATABASE {}").format( + sql.Identifier(self.dbname) + ) + ) + self.logger.info(f"Database '{self.dbname}' created successfully") + finally: + conn.close() + + # Connect to new database to install extensions + self.logger.info("Installing TimescaleDB extension...") + conn_newdb = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname=self.dbname + ) + try: + with conn_newdb.cursor() as cursor_new: + cursor_new.execute("CREATE EXTENSION IF NOT EXISTS timescaledb") + conn_newdb.commit() + self.logger.info("TimescaleDB extension installed successfully") + finally: + conn_newdb.close() + + self.logger.info(f"Created database {self.dbname} with TimescaleDB extension") + return True + + except psycopg2.Error as e: + self.logger.error(f"PostgreSQL error creating database: {e}") + return False + except Exception as e: + self.logger.error(f"Error creating database: {e}") + return False + + + def _create_tables(self): + """Create required tables and hypertables""" + commands = [ + """ + CREATE TABLE IF NOT EXISTS ticks ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ticks', 'time', if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS ohlc_1m ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ohlc_1m', 'time', 
if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS ohlc_5m ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ohlc_5m', 'time', if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS ohlc_15m ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ohlc_15m', 'time', if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS ohlc_D ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ohlc_D', 'time', if_not_exists => TRUE) + """ + ] + + try: + conn = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname=self.dbname + ) + try: + with conn.cursor() as cursor: + for i, command in enumerate(commands): + try: + cursor.execute(command) + self.logger.debug(f"Executed command {i+1}/{len(commands)}") + except psycopg2.Error as e: + # Skip hypertable creation if table already exists as hypertable + if "already a hypertable" in str(e): + self.logger.info(f"Table already exists as hypertable, skipping: {e}") + continue + else: + self.logger.error(f"Error executing command {i+1}: {e}") + self.logger.error(f"Command was: {command}") + raise + conn.commit() + self.logger.info("Created tables and hypertables successfully") + finally: + conn.close() + + except psycopg2.Error as e: + self.logger.error(f"PostgreSQL error creating tables: {e}") + raise + except Exception as e: + self.logger.error(f"Error creating tables: {e}") + raise + + def test_connection(self): + """Test database connection""" + try: + conn = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname='postgres' # Test with default database first + ) + conn.close() + self.logger.info("Database connection test successful") + return True + except psycopg2.Error as e: + self.logger.error(f"Database connection test failed: {e}") + return False + + def initialize_database(self): + """Main initialization method""" + # Test connection first + if not self.test_connection(): + raise RuntimeError("Cannot connect to PostgreSQL server. 
Check your connection parameters.") + + if not self._database_exists(): + self.logger.info(f"Database {self.dbname} not found, creating...") + if not self._create_database(): + raise RuntimeError("Failed to create database") + else: + self.logger.info(f"Database {self.dbname} already exists") + + self._create_tables() + + # Return an application connection + try: + self.app_conn = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname=self.dbname + ) + self.logger.info("Database connection established successfully") + return self.app_conn + + except psycopg2.Error as e: + self.logger.error(f"Database connection failed: {e}") + raise + except Exception as e: + self.logger.error(f"Database connection failed: {e}") + raise + +# Integration with your existing code +class MarketDataProcessor: + def __init__(self): + # Initialize TimescaleDBManager and connect to the database + self.db_manager = TimescaleDBManager() + self.db_conn = self.db_manager.initialize_database() + self.logger = logging.getLogger(f"MarketDataProcessor") + + self.consumer = KafkaConsumer( + 'tick_data', + bootstrap_servers='localhost:9092', + group_id='tick-processor', + auto_offset_reset='earliest' + #key_deserializer=lambda k: k.decode('utf-8') if k else None, + #value_deserializer=lambda v: json.loads(v.decode('utf-8')) + ) + + self.logger.info("Starting consumer with configuration:") + self.logger.info(f"Group ID: {self.consumer.config['group_id']}") + self.logger.info(f"Brokers: {self.consumer.config['bootstrap_servers']}") + + self.aggregation_lock = Lock() + + self.executor = ThreadPoolExecutor(max_workers=8) + + # Initialize aggregation buffers + self.reset_aggregation_buffers() + + # volume tracking + # Initialize volume tracking for all timeframes + self.last_period_volume = { + '1m': {}, + '5m': {}, + '15m': {} + } + + def clean_database(self): + """Clear all records from all tables in the database""" + try: + self.logger.info("Cleaning database tables...") + tables = ['ticks', 'ohlc_1m', 'ohlc_5m', 'ohlc_15m', 'ohlc_D'] # Add all your table names here + + with self.db_conn.cursor() as cursor: + # Disable triggers temporarily to avoid hypertable constraints + cursor.execute("SET session_replication_role = 'replica';") + + for table in tables: + try: + cursor.execute(f"TRUNCATE TABLE {table} CASCADE;") + self.logger.info(f"Cleared table: {table}") + except Exception as e: + self.logger.error(f"Error clearing table {table}: {e}") + self.db_conn.rollback() + continue + + # Re-enable triggers + cursor.execute("SET session_replication_role = 'origin';") + self.db_conn.commit() + + self.logger.info("Database cleaning completed successfully") + return True + + except Exception as e: + self.logger.error(f"Database cleaning failed: {e}") + self.db_conn.rollback() + return False + + + def insert_historical_data(self, df, symbol, interval): + """ + Insert historical data into the appropriate database table + + Args: + df (pd.DataFrame): DataFrame containing historical data + symbol (str): Stock symbol (e.g., 'RELIANCE') + interval (str): Time interval ('1m', '5m', '15m', '1d') + """ + try: + if df.empty: + self.logger.warning(f"No data to insert for {symbol} {interval}") + return False + + # Reset index to make timestamp a column + df = df.reset_index() + + # Rename columns to match database schema + df = df.rename(columns={ + 'timestamp': 'time', + 'open': 'open', + 'high': 'high', + 'low': 'low', + 'close': 'close', + 'volume': 'volume' + }) + + # Handle timezone 
conversion differently for intraday vs daily data + df['time'] = pd.to_datetime(df['time']) + if interval == 'D': + # Set to market open time (09:15:00 IST) for each date + df['time'] = df['time'].dt.tz_localize(None) # Remove any timezone + df['time'] = df['time'] + pd.Timedelta(hours=9, minutes=15) + df['time'] = df['time'].dt.tz_localize('Asia/Kolkata') + else: + if df['time'].dt.tz is None: + df['time'] = df['time'].dt.tz_localize('Asia/Kolkata') + else: + df['time'] = df['time'].dt.tz_convert('Asia/Kolkata') + + # Convert to UTC for database storage + df['time'] = df['time'].dt.tz_convert('UTC') + + # Add symbol column + df['symbol'] = symbol + + # Select and order the columns we need (excluding 'oi' which we don't store) + required_columns = ['time', 'symbol', 'open', 'high', 'low', 'close', 'volume'] + df = df[required_columns] + + # Convert numeric columns to appropriate types + numeric_cols = ['open', 'high', 'low', 'close'] + df[numeric_cols] = df[numeric_cols].astype(float) + df['volume'] = df['volume'].astype(int) + + # Determine the target table based on interval + table_name = f'ohlc_{interval.lower()}' + + # Convert DataFrame to list of tuples + records = [tuple(x) for x in df.to_numpy()] + + # Debug: print first record to verify format + self.logger.debug(f"First record sample: {records[0] if records else 'No records'}") + + with self.db_conn.cursor() as cursor: + # Use execute_batch for efficient bulk insertion + execute_batch(cursor, f""" + INSERT INTO {table_name} + (time, symbol, open, high, low, close, volume) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (time, symbol) DO UPDATE SET + open = EXCLUDED.open, + high = EXCLUDED.high, + low = EXCLUDED.low, + close = EXCLUDED.close, + volume = EXCLUDED.volume + """, records) + + self.db_conn.commit() + self.logger.info(f"Successfully inserted {len(df)} records for {symbol} ({interval}) into {table_name}") + return True + + except KeyError as e: + self.logger.error(f"Missing required column in data for {symbol} {interval}: {e}") + self.logger.error(f"Available columns: {df.columns.tolist()}") + return False + except Exception as e: + self.logger.error(f"Error inserting historical data for {symbol} {interval}: {e}") + self.logger.error(traceback.format_exc()) + self.db_conn.rollback() + return False + + + def reset_aggregation_buffers(self): + """Initialize/reset aggregation buffers""" + with self.aggregation_lock: + self.tick_buffer = { + '1m': {}, + '5m': {}, + '15m': {} + } + now = datetime.now(pytz.utc) + self.last_agg_time = { + '1m': self.floor_to_interval(now, 1), + '5m': self.floor_to_interval(now, 5), + '15m': self.floor_to_interval(now, 15) + } + self.aggregation_state = { + '1m': {}, + '5m': {}, + '15m': {} + } + + # Reset volume tracking + self.last_period_volume = { + '1m': {}, + '5m': {}, + '15m': {} + } + + def process_messages(self): + """Main processing loop""" + + self.consumer.subscribe(['tick_data']) + self.logger.info("Started listening messages...") + + try: + while True: + raw_msg = self.consumer.poll(1000.0) + self.logger.info(f"\n\n\n\nReceived messages: {raw_msg}") + + if raw_msg is None: + self.logger.info("No messages received during timeout period ----------->") + continue + + for topic_partition, messages in raw_msg.items(): + for message in messages: + try: + # Extract key and value + key = message.key.decode('utf-8') # 'NSE_RELIANCE_LTP' + value = json.loads(message.value.decode('utf-8')) + + self.logger.info(f"Processing {key}: {value['symbol']}@{value['close']}") + + # Process the 
message + self.process_single_message(key, value) + + except Exception as e: + self.logger.error(f"Error processing message: {e}") + + except KeyboardInterrupt: + self.logger.info("Kafka Consumer Shutting down...") + finally: + self.shutdown() + + def _handle_kafka_error(self, error): + """Handle Kafka protocol errors""" + error_codes = { + KafkaError._PARTITION_EOF: "End of partition", + KafkaError.UNKNOWN_TOPIC_OR_PART: "Topic/partition does not exist", + KafkaError.NOT_COORDINATOR_FOR_GROUP: "Coordinator changed", + KafkaError.ILLEGAL_GENERATION: "Consumer group rebalanced", + KafkaError.UNKNOWN_MEMBER_ID: "Member ID expired" + } + + if error.code() in error_codes: + self.logger.warning(error_codes[error.code()]) + else: + self.logger.error(f"Kafka error [{error.code()}]: {error.str()}") + + + def process_single_message(self, key, value): + """Process extracted tick data""" + try: + # Extract components from key + components = key.split('_') + exchange = components[0] # 'NSE' + symbol = components[1] # 'RELIANCE' + data_type = components[2] # 'LTP' or 'QUOTE' + + # Convert timestamp (handling milliseconds since epoch) + timestamp = value['timestamp'] + if not isinstance(timestamp, (int, float)): + raise ValueError(f"Invalid timestamp type: {type(timestamp)}") + + # Convert to proper datetime object + # Ensure milliseconds (not seconds or microseconds) + if timestamp < 1e12: # Likely in seconds + timestamp *= 1000 + elif timestamp > 1e13: # Likely in microseconds + timestamp /= 1000 + + dt = datetime.fromtimestamp(timestamp / 1000, tz=pytz.UTC) + + # Validate date range + if dt.year < 2020 or dt.year > 2030: + raise ValueError(f"Implausible date {dt} from timestamp {timestamp}") + + # Prepare database record + record = { + 'time': dt, # Convert ms to seconds + 'symbol': symbol, + 'open': float(value['ltp']), + 'high': float(value['ltp']), + 'low': float(value['ltp']), + 'close': float(value['ltp']), + 'volume': int(value['volume']) + } + + #self.logger.info(f"Record---------> {record}") + + # Store in TimescaleDB + self.store_tick(record) + + # Add to aggregation buffers + self.buffer_tick(record) + + # Check for aggregation opportunities + self.check_aggregation(record['time']) + + except Exception as e: + self.logger.error(f"Tick processing failed: {e}") + self.logger.debug(traceback.format_exc()) + + + def store_tick(self, record): + """Store raw tick in database""" + try: + with self.db_conn.cursor() as cursor: + cursor.execute(""" + INSERT INTO ticks (time, symbol, open, high, low, close, volume) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (time, symbol) DO UPDATE SET + open = EXCLUDED.open, + high = EXCLUDED.high, + low = EXCLUDED.low, + close = EXCLUDED.close, + volume = EXCLUDED.volume + """, (record['time'], record['symbol'], record['open'], record['high'], record['low'], record['close'], record['volume'])) + self.db_conn.commit() + except Exception as e: + logger.error(f"Error storing tick: {e}") + self.db_conn.rollback() + + def buffer_tick(self, record): + """Add tick to aggregation buffers""" + with self.aggregation_lock: + for timeframe in ['1m', '5m', '15m']: + minutes = int(timeframe[:-1]) + symbol = record['symbol'] + aligned_time = self.floor_to_interval(record['time'], minutes) + + if symbol not in self.tick_buffer[timeframe]: + self.tick_buffer[timeframe][symbol] = {} + + # Initialize this specific minute bucket + if aligned_time not in self.tick_buffer[timeframe][symbol]: + self.tick_buffer[timeframe][symbol][aligned_time] = { + 'opens': [], + 'highs': [], + 
'lows': [], + 'closes': [], + 'volumes': [], + 'first_tick': None # Track the first tick separately + } + + bucket = self.tick_buffer[timeframe][symbol][aligned_time] + + # For the first tick in this interval, store it separately + if bucket['first_tick'] is None: + bucket['first_tick'] = record + + bucket['opens'].append(record['open']) + bucket['highs'].append(record['high']) + bucket['lows'].append(record['low']) + bucket['closes'].append(record['close']) + bucket['volumes'].append(record['volume']) + + def check_aggregation(self, current_time): + """Check if aggregation should occur for any timeframe""" + timeframes = ['1m', '5m', '15m'] + + for timeframe in timeframes: + agg_interval = timedelta(minutes=int(timeframe[:-1])) + last_agg = self.last_agg_time[timeframe] + + self.logger.info(f"{timeframe}: current_time={current_time}, last_agg={last_agg}, interval={agg_interval}") + + if current_time - last_agg >= agg_interval: + if self.aggregate_data(timeframe, current_time): + self.last_agg_time[timeframe] = self.floor_to_interval(current_time, int(timeframe[:-1])) + + + def floor_to_interval(self, dt, minutes=1): + """Floor a datetime to the start of its minute/5m/15m interval""" + discard = timedelta( + minutes=dt.minute % minutes, + seconds=dt.second, + microseconds=dt.microsecond + ) + return dt - discard + + def aggregate_data(self, timeframe, agg_time): + with self.aggregation_lock: + symbol_buckets = self.tick_buffer[timeframe] + if not symbol_buckets: + return False + + aggregated = [] + table_name = f"ohlc_{timeframe}" + + for symbol, buckets in symbol_buckets.items(): + for bucket_start, data in list(buckets.items()): + if bucket_start >= self.last_agg_time[timeframe] + timedelta(minutes=int(timeframe[:-1])): + # Don't process future buckets + continue + + if not data['opens']: + continue + + try: + # Get OHLC values + if data['first_tick'] is not None: + open_ = data['first_tick']['open'] + else: + open_ = data['opens'][0] + + #open_ = data['opens'][0] + high = max(data['highs']) + low = min(data['lows']) + close = data['closes'][-1] + + # Calculate volume correctly for cumulative data + current_last_volume = data['volumes'][-1] + previous_last_volume = self.last_period_volume[timeframe].get(symbol, current_last_volume) + volume = max(0, current_last_volume - previous_last_volume) + + # Store the current last volume for next period + self.last_period_volume[timeframe][symbol] = current_last_volume + + aggregated.append((bucket_start, symbol, open_, high, low, close, volume)) + + # Remove this bucket to avoid re-aggregation + del self.tick_buffer[timeframe][symbol][bucket_start] + + except Exception as e: + self.logger.error(f"Error aggregating {symbol} for {timeframe}: {e}") + continue + + if aggregated: + try: + with self.db_conn.cursor() as cursor: + execute_batch(cursor, f""" + INSERT INTO {table_name} + (time, symbol, open, high, low, close, volume) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (time, symbol) DO UPDATE SET + open = EXCLUDED.open, + high = EXCLUDED.high, + low = EXCLUDED.low, + close = EXCLUDED.close, + volume = EXCLUDED.volume + """, aggregated) + self.db_conn.commit() + self.logger.info(f"Aggregated {len(aggregated)} symbols to {table_name}") + return True + except Exception as e: + self.logger.error(f"Error aggregating {timeframe} data: {e}") + self.db_conn.rollback() + return False + return False + + def shutdown(self): + """Clean shutdown""" + logger.info("Shutting down processors") + self.executor.shutdown(wait=True) + self.consumer.close() + 
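# Illustrative sketch: how floor_to_interval() buckets tick timestamps and how
# aggregate_data() derives bar volume from cumulative tick volume. The function
# below is a standalone copy for the example only, not an addition to the class.
from datetime import datetime, timedelta

def floor_to_interval_demo(dt: datetime, minutes: int = 5) -> datetime:
    # Same arithmetic as MarketDataProcessor.floor_to_interval: drop the remainder
    # of the minute bucket plus any seconds/microseconds.
    discard = timedelta(minutes=dt.minute % minutes, seconds=dt.second, microseconds=dt.microsecond)
    return dt - discard

# A tick at 10:07:45 falls into the 10:05 five-minute bucket:
assert floor_to_interval_demo(datetime(2024, 7, 8, 10, 7, 45), 5) == datetime(2024, 7, 8, 10, 5, 0)
# Volume arrives as a cumulative figure per tick, so a bucket whose last tick reports
# 10_750 after the previous bucket ended at 10_000 contributes max(0, 10_750 - 10_000) = 750.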
self.db_conn.close() + logger.info("Clean shutdown complete") + +if __name__ == "__main__": + + client = api( + api_key="8009e08498f085ff1a3e7da718c5f4b585eaf9c2b7ce0c72740ab2b5d283d36c", # Replace with your API key + host="http://127.0.0.1:5000" + ) + + # Argument parsing + parser = argparse.ArgumentParser(description='Market Data Processor') + parser.add_argument('--mode', type=str, choices=['live', 'backtest'], required=True, + help='Run mode: "live" for live processing, "backtest" for backtesting') + + parser.add_argument('--from_date', type=str, + help='Start date for backtest (DD-MM-YYYY format)') + parser.add_argument('--to_date', type=str, + help='End date for backtest (DD-MM-YYYY format)') + args = parser.parse_args() + + # Validate arguments + if args.mode == 'backtest': + if not args.from_date or not args.to_date: + parser.error("--from_date and --to_date are required in backtest mode") + + try: + from_date = datetime.strptime(args.from_date, '%d-%m-%Y').date() + to_date = datetime.strptime(args.to_date, '%d-%m-%Y').date() + + if from_date > to_date: + parser.error("--from_date cannot be after --to_date") + + except ValueError as e: + parser.error(f"Invalid date format. Please use DD-MM-YYYY. Error: {e}") + + # Initialize the processor + processor = MarketDataProcessor() + try: + if args.mode == 'live': + # Clean the database at the start of the intraday trading session(9:00 AM IST) + if datetime.now().hour == 9 and datetime.now().minute == 0: + processor.clean_database() + + # Fetch the last 10 days historical data(1 min, 5 min, 15min, D) and insert in the DB + # Dynamic date range: 7 days back to today + end_date = datetime.now().strftime("%Y-%m-%d") + start_date = (datetime.now() - timedelta(days=10)).strftime("%Y-%m-%d") + + # Import symbol list from CSV file + symbol_list = pd.read_csv('symbol_list.csv') + symbol_list = symbol_list['Symbol'].tolist() + + # Fetch historical data for each symbol + for symbol in symbol_list: + for interval in ["1m", "5m", "15m", "D"]: + df = client.history( + symbol=symbol, + exchange='NSE', + interval=interval, + start_date=start_date, + end_date=end_date + ) + #print(df.head()) + # Insert historical data into the database + processor.insert_historical_data(df, symbol, interval) + + # Process the real-time data + processor.process_messages() + + elif args.mode == 'backtest': + logger.info(f"Running in backtest mode from {args.from_date} to {args.to_date}") + # Clean the database + processor.clean_database() + + # Load historical data for the specified date range + # Fetch the last 10 days historical data(1 min, 5 min, 15min, D) and insert in the DB + # Dynamic date range: 7 days back to today + end_date = to_date.strftime("%Y-%m-%d") + start_date = from_date.strftime("%Y-%m-%d") + + # Import symbol list from CSV file + symbol_list = pd.read_csv('symbol_list_backtest.csv') + symbol_list = symbol_list['Symbol'].tolist() + + # Fetch historical data for each symbol + for symbol in symbol_list: + for interval in ["1m", "5m", "15m", "D"]: + df = client.history( + symbol=symbol, + exchange='NSE', + interval=interval, + start_date=start_date, + end_date=end_date + ) + + #print(df.head()) + # Insert historical data into the database + processor.insert_historical_data(df, symbol, interval) + + # Process data in simulation mode + + except Exception as e: + logger.error(f"Fatal error: {e}") + processor.shutdown() diff --git a/database/auth_db.py b/database/auth_db.py index eba8cdcc..e2ecae90 100644 --- a/database/auth_db.py +++ b/database/auth_db.py @@ 
-48,9 +48,9 @@ def get_encryption_key(): engine = create_engine( DATABASE_URL, - pool_size=50, - max_overflow=100, - pool_timeout=10 + pool_size=300, + max_overflow=600, + pool_timeout=60 ) db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine)) diff --git a/requirements.txt b/requirements.txt index 18bf27cf..c938f0fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -128,3 +128,12 @@ wsproto==1.2.0 wtforms==3.1.2 zipp==3.19.2 zmq==0.0.0 +kafka-python==2.2.15 +asyncio-mqtt==0.16.2 +redpanda==0.6.0 +python-snappy==0.7.3 +psycopg2==2.9.10 +google==3.0.0 +matplotlib==3.10.3 +tabulate==0.9.0 +ta-lib==0.6.4 diff --git a/strategies/LIVE_TRADING_README.md b/strategies/LIVE_TRADING_README.md new file mode 100644 index 00000000..c8b5f668 --- /dev/null +++ b/strategies/LIVE_TRADING_README.md @@ -0,0 +1,273 @@ +# Live Trading Engine + +A high-frequency, multi-symbol live trading system that exactly mirrors the backtesting logic with real-time execution. + +## πŸ—οΈ Architecture + +### Core Components + +1. **LiveTradingEngine** - Main orchestrator +2. **PositionManager** - Handles positions and constraints +3. **LiveDataManager** - Real-time data and indicators +4. **OrderManager** - Order execution through API +5. **TradingDashboard** - Monitoring interface + +### Data Flow + +``` +TimescaleDB (Historical + Real-time) + ↓ +LiveDataManager (Indicators) + ↓ +Symbol Scanner (Entry Signals) + ↓ +PositionManager (Risk Checks) + ↓ +OrderManager (Execution) + ↓ +Position Monitor (Exit Signals) +``` + +## πŸš€ Quick Start + +### 1. Setup Environment + +```bash +# Set database environment variables +export TIMESCALE_DB_USER="your_user" +export TIMESCALE_DB_PASSWORD="your_password" +export TIMESCALE_DB_HOST="localhost" +export TIMESCALE_DB_PORT="5432" +export TIMESCALE_DB_NAME="your_db" +``` + +### 2. Configure Symbols + +Edit `symbol_list_live.csv` with your trading symbols: +```csv +Symbol +RELIANCE +TCS +INFY +... +``` + +### 3. Start Live Trading + +```bash +# Full live trading +python run_live_trading.py --symbols live + +# Test mode with limited symbols +python run_live_trading.py --symbols test --max-symbols 10 + +# Dry run (no actual orders) +python run_live_trading.py --dry-run --debug +``` + +### 4. Monitor with Dashboard + +```bash +# In separate terminal +python trading_dashboard.py +``` + +## βš™οΈ Configuration + +### Trading Hours +- **Market**: 09:15 - 15:30 IST +- **Trading**: 09:20 - 15:20 IST (buffer for startup/shutdown) + +### Risk Constraints +- **Max Open Positions**: 3 +- **Max Daily Trades**: 5 +- **Max Strategy Use**: 1 per day per strategy +- **Position Size**: 30% of capital per trade + +### Scanning Frequency +- **Symbol Scan**: Every 1 second +- **Position Monitor**: Every 2 seconds +- **Status Update**: Every 60 seconds + +## πŸ“Š Trading Logic + +### Entry Strategies + +**15-Minute Timeframe:** +- **Strategy 8**: Short entry on bearish breakout +- **Strategy 12**: Long entry on bullish breakout + +**5-Minute Timeframe:** +- **Strategy 10**: Short entry on bearish momentum +- **Strategy 11**: Long entry on bullish momentum +- **Strategy 9**: Alternative short strategy + +### Exit Conditions + +1. **Trailing Stop**: ATR-based (1.5x ATR distance) +2. **End of Day**: Force close at 15:20 +3. 
**Manual**: Via position manager + +### Indicators Used + +- **ATR (14)**: Volatility measurement +- **EMA (50, 100, 200)**: Trend direction +- **Zero-Lag MACD**: Momentum signals +- **Range Analysis**: Breakout detection +- **Volume Analysis**: Confirmation + +## πŸ”§ Advanced Configuration + +### Custom Position Sizing + +```python +# In live_config.py +TRADING_CONFIG = { + 'capital': 200000, # 2 Lakh + 'leverage': 10, # 10x leverage + 'capital_alloc_pct': 25, # 25% per trade +} +``` + +### Risk Management + +```python +# Modify constraints +TRADING_CONFIG = { + 'max_positions': 5, # Allow 5 positions + 'max_daily_trades': 10, # Allow 10 trades + 'trail_atr_multiple': 2.0, # Wider trailing stops +} +``` + +## πŸ“ File Structure + +``` +strategies/ +β”œβ”€β”€ live_trading_engine.py # Main engine +β”œβ”€β”€ live_config.py # Configuration +β”œβ”€β”€ run_live_trading.py # Startup script +β”œβ”€β”€ trading_dashboard.py # Monitoring +β”œβ”€β”€ symbol_list_live.csv # Trading symbols +β”œβ”€β”€ symbol_list_test.csv # Test symbols +└── LIVE_TRADING_README.md # This file +``` + +## πŸ“ˆ Monitoring + +### Real-time Dashboard +- Open positions with P&L +- Today's trade history +- Constraint status +- System statistics +- Market status + +### Log Files +- `live_trading_YYYYMMDD.log` - Daily trading logs +- Real-time position updates +- Entry/exit confirmations +- Error tracking + +### Key Metrics +- **Win Rate**: Percentage of profitable trades +- **Daily P&L**: Running profit/loss +- **Position Utilization**: Active vs max positions +- **Strategy Performance**: Per-strategy results + +## 🚨 Safety Features + +### Pre-trade Validation +- Position limit checks +- Daily trade limits +- Strategy usage limits +- Market hours validation + +### Error Handling +- API connection failures +- Database disconnections +- Invalid market data +- Order execution errors + +### Emergency Shutdown +- **Ctrl+C**: Graceful shutdown +- **SIGTERM**: Service shutdown +- **Market Close**: Auto position closure + +## πŸ” Troubleshooting + +### Common Issues + +**"No data for symbol"** +```bash +# Check if historical data exists +python -c " +import psycopg2 +conn = psycopg2.connect(...) 
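# Illustrative sketch relating to the Configuration section above: how the TRADING_CONFIG
# values could translate into a per-trade position size. The exact formula lives in
# live_trading_engine.py (not shown in this diff), so treat the calculation below as an
# assumption: notional = capital * alloc% * leverage.
capital = 200_000          # from the "Custom Position Sizing" example
capital_alloc_pct = 25     # percent of capital allocated per trade
leverage = 10              # broker leverage multiple
entry_price = 2_810.50     # hypothetical entry price

notional = capital * (capital_alloc_pct / 100) * leverage   # 200000 * 0.25 * 10 = 500000
quantity = int(notional // entry_price)                     # 500000 / 2810.50 -> 177 shares
print(notional, quantity)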
+cursor = conn.cursor() +cursor.execute('SELECT COUNT(*) FROM ohlc_15m WHERE symbol = %s', ('RELIANCE',)) +print(cursor.fetchone()) +" +``` + +**"API connection failed"** +- Check API key and host configuration +- Verify OpenAlgo server is running +- Check network connectivity + +**"Database connection failed"** +- Verify environment variables +- Check TimescaleDB service status +- Validate connection parameters + +### Debug Mode + +```bash +python run_live_trading.py --debug --dry-run +``` + +This enables: +- Detailed logging +- No actual orders +- Step-by-step execution traces + +## πŸ“‹ Production Checklist + +### Before Going Live + +- [ ] Test with paper trading +- [ ] Verify API connectivity +- [ ] Check position sizing calculations +- [ ] Validate risk constraints +- [ ] Test emergency shutdown +- [ ] Monitor resource usage +- [ ] Set up alerting + +### Daily Operations + +- [ ] Check market calendar +- [ ] Verify system resources +- [ ] Review previous day's performance +- [ ] Monitor log files +- [ ] Check data quality +- [ ] Validate positions at EOD + +## πŸ†˜ Support + +### Logs Location +- Main logs: `live_trading_YYYYMMDD.log` +- Error logs: Check console output +- Position logs: Within main log file + +### Performance Metrics +- Memory usage: Monitor via dashboard +- CPU usage: Check system stats +- Scan latency: Review debug logs + +### Emergency Contacts +- Trading desk: [Your contact] +- Technical support: [Your contact] +- Risk management: [Your contact] + +--- + +**⚠️ Important**: This is a live trading system. Always test thoroughly before deploying with real money. Past performance does not guarantee future results. diff --git a/strategies/backtest_engine.py b/strategies/backtest_engine.py new file mode 100644 index 00000000..2234431b --- /dev/null +++ b/strategies/backtest_engine.py @@ -0,0 +1,1192 @@ +# strategy/backtest_engine.py +import pandas as pd +from datetime import datetime, timedelta, time +import psycopg2 +import logging +import matplotlib.pyplot as plt +import matplotlib +import os +import pytz +IST = pytz.timezone('Asia/Kolkata') +import gc +import numpy as np +import talib + +matplotlib.use('Agg') # use Anti-Grain Geometry backend (non-GUI) + +class BacktestEngine: + def __init__(self, conn, symbol, start_date, end_date, lookback_days=10, tp_pct=1.5, sl_pct=1.5, trail_activation_pct=0.9, + trail_stop_gap_pct=0.2, trail_increment_pct=0.2, capital=100000, leverage=5, capital_alloc_pct=30): + self.conn = conn + self.symbol = symbol + self.start_date = datetime.strptime(start_date, "%Y-%m-%d").date() + self.end_date = datetime.strptime(end_date, "%Y-%m-%d").date() + self.lookback_days = lookback_days + self.logger = logging.getLogger(f"BacktestEngine[{symbol}]") + self.position = 0 + self.entry_price = None + self.trades = [] + self.tp_pct = tp_pct / 100 + self.sl_pct = sl_pct / 100 + self.trail_activation_pct = trail_activation_pct / 100 + self.trail_stop_gap_pct = trail_stop_gap_pct / 100 + self.trail_increment_pct = trail_increment_pct / 100 + self.capital = capital + self.leverage = leverage + self.capital_alloc_pct = capital_alloc_pct + + def daterange(self): + for n in range(int((self.end_date - self.start_date).days) + 1): + yield self.start_date + timedelta(n) + + def fetch_lookback_data(self, end_day, interval, lookback_days): + lookback_start = end_day - timedelta(days=lookback_days) + #self.logger.info(f"Fetching lookback data from {lookback_start} to {end_day}") + query = f""" + SELECT * FROM ohlc_{interval} + WHERE symbol = %s AND time >= %s 
AND time < %s + ORDER BY time ASC + """ + df = pd.read_sql(query, self.conn, params=(self.symbol, lookback_start, end_day + timedelta(days=1))) + if df.empty: + self.logger.warning(f"No data found for {self.symbol} {interval} between {lookback_start} and {end_day}") + + return df + + def fetch_lookback_data_2(self, start_day, end_day, interval): + + #self.logger.info(f"Fetching lookback data from {lookback_start} to {end_day}") + if interval == "nifty_15m": + interval = "15m" + + query = f""" + SELECT * FROM ohlc_{interval} + WHERE symbol = %s AND time >= %s AND time < %s + ORDER BY time ASC + """ + if interval == '1h': + df = pd.read_sql(query, self.conn, params=('NIFTY', start_day, end_day + timedelta(days=1))) + elif interval == 'nifty_15m': + df = pd.read_sql(query, self.conn, params=('NIFTY', start_day, end_day + timedelta(days=1))) + else: + df = pd.read_sql(query, self.conn, params=(self.symbol, start_day, end_day + timedelta(days=1))) + + if df.empty: + if interval == '1h': + self.logger.warning(f"No data found for NIFTY {interval} between {start_day} and {end_day}") + elif interval == 'nifty_15m': + self.logger.warning(f"No data found for NIFTY {interval} between {start_day} and {end_day}") + else: + self.logger.warning(f"No data found for {self.symbol} {interval} between {start_day} and {end_day}") + + return df + + def exclude_first_30min(self, group): + """Calculate expanding mean of range excluding first 30 minutes""" + try: + if group.empty or 'range' not in group.columns: + return pd.Series(dtype=float) + + mask = ~( + (group['time'].dt.time >= time(3, 45)) & + (group['time'].dt.time < time(4, 15)) # Time in UTC - hence 3.45 to 4.15 + ) + filtered_group = group[mask] + if filtered_group.empty: + return pd.Series(dtype=float, index=group.index) + + result = filtered_group['range'].expanding().mean() + # Ensure we return a Series aligned with the original group + return result.reindex(group.index).ffill() + except Exception as e: + # Return a Series of NaNs if there's an error + return pd.Series(float('nan'), index=group.index) + + def zero_lag_ema(self, series, period): + ema1 = series.ewm(span=period, adjust=False).mean() + ema2 = ema1.ewm(span=period, adjust=False).mean() + return ema1 + (ema1 - ema2) + + def hull_moving_average(self, close, window=50): + """HMA reduces lag significantly vs EMA.""" + wma_half = talib.WMA(close, timeperiod=window//2) + wma_full = talib.WMA(close, timeperiod=window) + hma = talib.WMA(2 * wma_half - wma_full, timeperiod=int(np.sqrt(window))) + return hma + + def zero_lag_macd(self, close, fast=12, slow=26, signal=9): + """MACD using Zero-Lag EMAs (TEMA).""" + ema_fast = talib.TEMA(close, timeperiod=fast) + ema_slow = talib.TEMA(close, timeperiod=slow) + macd = ema_fast - ema_slow + signal_line = talib.TEMA(macd, timeperiod=signal) + return macd, signal_line + + def relative_momentum_index(self, close, window=50): + """RMI is a more responsive alternative to ADX.""" + delta = close.diff(1) + gain = delta.where(delta > 0, 0) + loss = -delta.where(delta < 0, 0) + avg_gain = gain.rolling(window).mean() + avg_loss = loss.rolling(window).mean() + rmi = 100 * (avg_gain / (avg_gain + avg_loss)) + return rmi + + def classify_trend(self, row, interval): + # Primary Conditions (1H) + ema_bullish = row['close'] > row[f'nifty_{interval}_ema_50'] > row[f'nifty_{interval}_ema_200'] + ema_bearish = row['close'] < row[f'nifty_{interval}_ema_50'] < row[f'nifty_{interval}_ema_200'] + #hma_bullish = row['close'] > row[f'nifty_{interval}_hma_50'] > 
row[f'nifty_{interval}_hma_200'] + #hma_bearish = row['close'] < row[f'nifty_{interval}_hma_50'] < row[f'nifty_{interval}_hma_200'] + rmi_strong = row[f'nifty_{interval}_RMI'] > 60 # RMI > 60 = strong trend + adx_strong = row[f'nifty_{interval}_adx'] > 20 + di_bullish = row[f'nifty_{interval}_+DI'] > row[f'nifty_{interval}_-DI'] + di_bearish = row[f'nifty_{interval}_-DI'] > row[f'nifty_{interval}_+DI'] + macd_bullish = row[f'nifty_{interval}_MACD'] > row[f'nifty_{interval}_MACD_Signal'] + macd_bearish = row[f'nifty_{interval}_MACD'] < row[f'nifty_{interval}_MACD_Signal'] + #volume_ok = row['Volume_Spike'] + + # Trend Logic + if ema_bullish and adx_strong and di_bullish and macd_bullish: + return 1 + elif ema_bearish and adx_strong and di_bearish and macd_bearish: + return -1 + else: + return 0 + + + def classify_trend_2(self, row, interval): + # Strong Uptrend + if (row['close'] > row[f'nifty_{interval}_ema_50'] > row[f'nifty_{interval}_ema_200']) and \ + (row[f'nifty_{interval}_MACD'] > row[f'nifty_{interval}_MACD_Signal']) and \ + (row[f'nifty_{interval}_adx'] > 25) and (row[f'nifty_{interval}_+DI'] > row[f'nifty_{interval}_-DI']): + return 2 + + # Strong Downtrend + elif (row['close'] < row[f'nifty_{interval}_ema_50'] < row[f'nifty_{interval}_ema_200']) and \ + (row[f'nifty_{interval}_MACD'] < row[f'nifty_{interval}_MACD_Signal']) and \ + (row[f'nifty_{interval}_adx'] > 25) and (row[f'nifty_{interval}_-DI'] > row[f'nifty_{interval}_+DI']): + return -2 + + # Up-Sideways + elif (row['close'] > row[f'nifty_{interval}_ema_50']) and \ + (row[f'nifty_{interval}_adx'] < 25) and \ + (abs(row[f'nifty_{interval}_MACD'] - row[f'nifty_{interval}_MACD_Signal']) < 1.0): # MACD hovering near signal + return 1 + + # Down-Sideways + elif (row['close'] < row[f'nifty_{interval}_ema_50']) and \ + (row[f'nifty_{interval}_adx'] < 25) and \ + (abs(row[f'nifty_{interval}_MACD'] - row[f'nifty_{interval}_MACD_Signal']) < 1.0): + return -1 + + # Neutral Sideways + else: + return 0 + + def classify_trend_3(self, row, interval): + rsi_1h = row[f'nifty_{interval}_RSI'] + adx_1h = row[f'nifty_{interval}_adx'] + rvol_1h = row[f'nifty_{interval}_RVOL'] + hh, hl, ll, lh = row[f'nifty_{interval}_HH'], row[f'nifty_{interval}_HL'], row[f'nifty_{interval}_LL'], row[f'nifty_{interval}_LH'] + + rsi_d = row['rsi_14'] + adx_d = row['adx_14'] + + # Pullback detection + pullback_bull = rsi_d > 55 and rsi_1h < rsi_d - 10 and rsi_1h > 40 + pullback_bear = rsi_d < 45 and rsi_1h > rsi_d + 10 and rsi_1h < 60 + + # Classification + if hh and hl and rsi_1h > 60 and adx_1h > 25 and rvol_1h > 1.5 and rsi_d > 55 and adx_d > 20: + return 2 + elif hh and hl and 50 < rsi_1h <= 60 and adx_1h < 25 and rsi_d > 55: + return 1 + elif pullback_bull: + return 1 + elif ll and lh and rsi_1h < 40 and adx_1h > 25 and rvol_1h > 1.5 and rsi_d < 45 and adx_d > 20: + return -2 + elif ll and lh and 40 <= rsi_1h < 50 and adx_1h < 25 and rsi_d < 45: + return -1 + elif pullback_bear: + return -1 + elif 45 <= rsi_1h <= 55 and adx_1h < 20: + return 0 + else: + return 0 + + def calculate_all_indicators_once(self, df_all_dict): + """ + Calculate all indicators once for the entire dataset + Returns: Dictionary with pre-calculated dataframes + """ + self.logger.info(f"Calculating indicators once for entire dataset for {self.symbol}") + + # Extract dataframes + df_15m = df_all_dict['15m'].copy() + df_5m = df_all_dict['5m'].copy() + df_1m = df_all_dict['1m'].copy() + df_daily = df_all_dict['d'].copy() + #df_nifty_1h = df_all_dict['1h'].copy() + df_nifty_15m = 
df_all_dict['nifty_15m'].copy() + + # === ENSURE TIME COLUMNS ARE DATETIME === + for df in [df_15m, df_5m, df_1m, df_daily, df_nifty_15m]: + if not df.empty: + df['time'] = pd.to_datetime(df['time']) + + # === EARLY RETURN IF CRITICAL DATA IS MISSING === + if df_15m.empty or df_5m.empty or df_daily.empty or df_nifty_15m.empty: + self.logger.warning(f"Missing critical data for {self.symbol} - skipping indicator calculations") + return { + '15m': df_15m, + '5m': df_5m, + '1m': df_1m, + 'd': df_daily, + 'nifty_15m': df_nifty_15m + } + + # === DAILY INDICATORS (ATR & Volume) === + atr_period = 14 + volume_period = 14 + df_daily['prev_close'] = df_daily['close'].shift(1) + df_daily['tr1'] = df_daily['high'] - df_daily['low'] + df_daily['tr2'] = abs(df_daily['high'] - df_daily['prev_close']) + df_daily['tr3'] = abs(df_daily['low'] - df_daily['prev_close']) + df_daily['tr'] = df_daily[['tr1', 'tr2', 'tr3']].max(axis=1) + df_daily['atr_10'] = df_daily['tr'].ewm(span=10, adjust=False).mean() + df_daily['volume_10'] = df_daily['volume'].rolling(window=10).mean() + df_daily['atr_14'] = df_daily['tr'].ewm(span=14, adjust=False).mean() + df_daily['volume_14'] = df_daily['volume'].rolling(window=14).mean() + df_daily['close_10'] = df_daily['close'].rolling(window=10).mean() + df_daily['close_14'] = df_daily['close'].rolling(window=14).mean() + df_daily['rsi_14'] = talib.RSI(df_daily['close'], timeperiod=14) + df_daily['adx_14'] = talib.ADX(df_daily['high'], df_daily['low'], df_daily['close'], timeperiod=14) + df_daily.drop(['prev_close', 'tr1', 'tr2', 'tr3', 'tr'], axis=1, inplace=True) + + # Add date columns for merging + df_15m['date'] = pd.to_datetime(df_15m['time'].dt.date) + df_5m['date'] = pd.to_datetime(df_5m['time'].dt.date) + df_daily['date'] = pd.to_datetime(df_daily['time'].dt.date) + #df_nifty_1h['date'] = pd.to_datetime(df_nifty_1h['time'].dt.date) + df_nifty_15m['date'] = pd.to_datetime(df_nifty_15m['time'].dt.date) + + # Merge ATR from daily data + df_15m = df_15m.merge(df_daily[['date', 'atr_10', 'volume_10', 'close_10', 'atr_14', 'volume_14', 'close_14']], on='date', how='left') + df_5m = df_5m.merge(df_daily[['date', 'atr_10', 'volume_10', 'close_10', 'atr_14', 'volume_14', 'close_14']], on='date', how='left') + + # === HOURLY INDICATORS (Nifty 50EMA) === + # df_nifty_1h = df_nifty_1h.merge(df_daily[['date', 'atr_10', 'volume_10', 'close_10', 'atr_14', 'volume_14', 'close_14', 'rsi_14', 'adx_14']], on='date', how='left') + + # df_nifty_1h['nifty_1hr_ema_50'] = df_nifty_1h['close'].ewm(span=50, adjust=False).mean() + # df_nifty_1h['nifty_1hr_ema_200'] = df_nifty_1h['close'].ewm(span=200, adjust=False).mean() + # df_nifty_1h['nifty_1hr_adx'] = talib.ADX(df_nifty_1h['high'], df_nifty_1h['low'], df_nifty_1h['close'], timeperiod=50) + # df_nifty_1h['nifty_1hr_+DI'] = talib.PLUS_DI(df_nifty_1h['high'], df_nifty_1h['low'], df_nifty_1h['close'], timeperiod=50) + # df_nifty_1h['nifty_1hr_-DI'] = talib.MINUS_DI(df_nifty_1h['high'], df_nifty_1h['low'], df_nifty_1h['close'], timeperiod=50) + # df_nifty_1h['nifty_1hr_MACD'], df_nifty_1h['nifty_1hr_MACD_Signal'], _ = talib.MACD(df_nifty_1h['close'], fastperiod=12, slowperiod=26, signalperiod=9) + # df_nifty_1h['nifty_1hr_Volume_MA20'] = talib.MA(df_nifty_1h['volume'], timeperiod=20) + # df_nifty_1h['nifty_1hr_Volume_Spike'] = df_nifty_1h['volume'] > 1.5 * df_nifty_1h['nifty_1hr_Volume_MA20'] + # df_nifty_1h['nifty_1hr_RSI'] = talib.RSI(df_nifty_1h['close'], timeperiod=14) + # df_nifty_1h['nifty_1hr_volume_sma_20'] = 
df_nifty_1h['volume'].rolling(window=20).mean() + # df_nifty_1h['nifty_1hr_RVOL'] = df_nifty_1h['volume'] / df_nifty_1h['nifty_1hr_volume_sma_20'] + + # # Price structure + # df_nifty_1h['nifty_1hr_HH'] = df_nifty_1h['high'] > df_nifty_1h['high'].shift(1) + # df_nifty_1h['nifty_1hr_HL'] = df_nifty_1h['low'] > df_nifty_1h['low'].shift(1) + # df_nifty_1h['nifty_1hr_LL'] = df_nifty_1h['low'] < df_nifty_1h['low'].shift(1) + # df_nifty_1h['nifty_1hr_LH'] = df_nifty_1h['high'] < df_nifty_1h['high'].shift(1) + + # # Trend (HMA replaces EMA) + # df_nifty_1h['nifty_1hr_hma_50'] = self.hull_moving_average(df_nifty_1h['close'], window=50) + # df_nifty_1h['nifty_1hr_hma_200'] = self.hull_moving_average(df_nifty_1h['close'], window=200) + # #df_nifty_1h['nifty_1hr_MACD'], df_nifty_1h['nifty_1hr_MACD_Signal'] = self.zero_lag_macd(df_nifty_1h['close']) + # df_nifty_1h['nifty_1hr_RMI'] = self.relative_momentum_index(df_nifty_1h['close']) + + # #df_nifty_1h['nifty_trend'] = np.where(df_nifty_1h['nifty_1hr_ema_50'] > (df_nifty_1h['nifty_1hr_ema_200'] * 1.01), 1, np.where(df_nifty_1h['nifty_1hr_ema_50'] < (df_nifty_1h['nifty_1hr_ema_200'] * 0.99), -1, 0)) + # #df_nifty_1h['nifty_trend'] = np.where(df_nifty_1h['nifty_1hr_hma_50'] > (df_nifty_1h['nifty_1hr_hma_200'] * 1.01), 1, np.where(df_nifty_1h['nifty_1hr_hma_50'] < (df_nifty_1h['nifty_1hr_hma_200'] * 0.99), -1, 0)) + # df_nifty_1h['nifty_trend'] = df_nifty_1h.apply(self.classify_trend, args=('1hr',), axis=1) + # #df_nifty_1h['nifty_trend'] = df_nifty_1h.apply(lambda row: self.classify_trend_2(row), axis=1) + # #df_nifty_1h['nifty_trend'] = df_nifty_1h.apply(lambda row: self.classify_trend_3(row), axis=1) + + # # Merge trend into the 5min and 15min df + # # Keep only the first record each day + # df_nifty_1h = df_nifty_1h.groupby('date').last().reset_index() + # df_nifty_1h['nifty_trend'] = df_nifty_1h['nifty_trend'].shift(1) + # df_15m = df_15m.merge(df_nifty_1h[['date', 'nifty_trend']], on='date', how='left') + # df_5m = df_5m.merge(df_nifty_1h[['date', 'nifty_trend']], on='date', how='left') + + # === NIFTY 15m INDICATORS (Nifty 50EMA) === + df_nifty_15m = df_nifty_15m.merge(df_daily[['date', 'atr_10', 'volume_10', 'close_10', 'atr_14', 'volume_14', 'close_14', 'rsi_14', 'adx_14']], on='date', how='left') + + df_nifty_15m['nifty_15m_ema_50'] = df_nifty_15m['close'].ewm(span=50, adjust=False).mean() + df_nifty_15m['nifty_15m_ema_200'] = df_nifty_15m['close'].ewm(span=200, adjust=False).mean() + df_nifty_15m['nifty_15m_adx'] = talib.ADX(df_nifty_15m['high'], df_nifty_15m['low'], df_nifty_15m['close'], timeperiod=125) + df_nifty_15m['nifty_15m_+DI'] = talib.PLUS_DI(df_nifty_15m['high'], df_nifty_15m['low'], df_nifty_15m['close'], timeperiod=125) + df_nifty_15m['nifty_15m_-DI'] = talib.MINUS_DI(df_nifty_15m['high'], df_nifty_15m['low'], df_nifty_15m['close'], timeperiod=125) + df_nifty_15m['nifty_15m_MACD'], df_nifty_15m['nifty_15m_MACD_Signal'], _ = talib.MACD(df_nifty_15m['close'], fastperiod=20, slowperiod=50, signalperiod=10) + df_nifty_15m['nifty_15m_Volume_MA20'] = talib.MA(df_nifty_15m['volume'], timeperiod=20) + df_nifty_15m['nifty_15m_Volume_Spike'] = df_nifty_15m['volume'] > 1.5 * df_nifty_15m['nifty_15m_Volume_MA20'] + df_nifty_15m['nifty_15m_RSI'] = talib.RSI(df_nifty_15m['close'], timeperiod=14) + df_nifty_15m['nifty_15m_volume_sma_20'] = df_nifty_15m['volume'].rolling(window=20).mean() + df_nifty_15m['nifty_15m_RVOL'] = df_nifty_15m['volume'] / df_nifty_15m['nifty_15m_volume_sma_20'] + + df_nifty_15m['nifty_15m_RMI'] = 
self.relative_momentum_index(df_nifty_15m['close']) + + df_nifty_15m['nifty_trend_15m'] = df_nifty_15m.apply(self.classify_trend, args=('15m',), axis=1) + #df_nifty_15m['nifty_trend'] = df_nifty_15m.apply(lambda row: self.classify_trend_2(row), axis=1) + #df_nifty_15m['nifty_trend'] = df_nifty_15m.apply(lambda row: self.classify_trend_3(row), axis=1) + + # Merge trend into the 5min and 15min df + # Keep only the first record each day + df_nifty_15m = df_nifty_15m.groupby('date').last().reset_index() + df_nifty_15m['nifty_trend_15m'] = df_nifty_15m['nifty_trend_15m'].shift(1) + df_15m = df_15m.merge(df_nifty_15m[['date', 'nifty_trend_15m']], on='date', how='left') + df_5m = df_5m.merge(df_nifty_15m[['date', 'nifty_trend_15m']], on='date', how='left') + + + # === 5MIN INDICATORS (EMAs) === + df_5m['ema_50'] = df_5m['close'].ewm(span=50, adjust=False).mean() + df_5m['ema_100'] = df_5m['close'].ewm(span=100, adjust=False).mean() + df_5m['ema_200'] = df_5m['close'].ewm(span=200, adjust=False).mean() + + # === RANGE CALCULATIONS - 15m === + df_15m['range'] = df_15m['high'] - df_15m['low'] + df_15m['date_only'] = df_15m['time'].dt.date + df_15m['avg_range_all'] = df_15m.groupby('date_only')['range'].expanding().mean().reset_index(level=0, drop=True) + + # Fix for "Cannot set a DataFrame with multiple columns" error + try: + avg_ex_first_30min_15m = df_15m.groupby('date_only').apply(self.exclude_first_30min) + # Ensure we get a Series, not DataFrame + if isinstance(avg_ex_first_30min_15m, pd.DataFrame): + avg_ex_first_30min_15m = avg_ex_first_30min_15m.iloc[:, 0] # Take first column + avg_ex_first_30min_15m = avg_ex_first_30min_15m.reset_index(level=0, drop=True) + df_15m['avg_range_ex_first_30min'] = avg_ex_first_30min_15m.reindex(df_15m.index).ffill() + except Exception as e: + self.logger.warning(f"Error calculating avg_range_ex_first_30min for 15m: {e}, using avg_range_all instead") + df_15m['avg_range_ex_first_30min'] = df_15m['avg_range_all'] + df_15m['is_range_bullish'] = ( + (df_15m['range'] > 0.7 * df_15m['avg_range_ex_first_30min']) & + (df_15m['close'] > df_15m['open']) & + (df_15m['close'] > (((df_15m['high'] - df_15m['open']) * 0.5) + df_15m['open'])) + ) + df_15m['is_range_bearish'] = ( + (df_15m['range'] > 0.7 * df_15m['avg_range_ex_first_30min']) & + (df_15m['close'] < df_15m['open']) & + (df_15m['close'] < (((df_15m['open'] - df_15m['low']) * 0.5) + df_15m['low'])) + ) + df_15m.drop('date_only', axis=1, inplace=True) + + # === RANGE CALCULATIONS - 5m === + df_5m['range'] = df_5m['high'] - df_5m['low'] + df_5m['date_only'] = df_5m['time'].dt.date + df_5m['avg_range_all'] = df_5m.groupby('date_only')['range'].expanding().mean().reset_index(level=0, drop=True) + + # Fix for "Cannot set a DataFrame with multiple columns" error + try: + avg_ex_first_30min_5m = df_5m.groupby('date_only').apply(self.exclude_first_30min) + # Ensure we get a Series, not DataFrame + if isinstance(avg_ex_first_30min_5m, pd.DataFrame): + avg_ex_first_30min_5m = avg_ex_first_30min_5m.iloc[:, 0] # Take first column + avg_ex_first_30min_5m = avg_ex_first_30min_5m.reset_index(level=0, drop=True) + df_5m['avg_range_ex_first_30min'] = avg_ex_first_30min_5m.reindex(df_5m.index).ffill() + except Exception as e: + self.logger.warning(f"Error calculating avg_range_ex_first_30min for 5m: {e}, using avg_range_all instead") + df_5m['avg_range_ex_first_30min'] = df_5m['avg_range_all'] + df_5m['is_range_bullish'] = ( + (df_5m['range'] > 0.7 * df_5m['avg_range_ex_first_30min']) & + (df_5m['close'] > df_5m['open']) & + 
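+        # Illustrative sketch: exclude_first_30min() is referenced above but not shown
+        # in this diff. The intent appears to be an expanding mean of 'range' that
+        # ignores the first 30 minutes of the session; a rough equivalent, assuming
+        # candle times are stored in UTC (09:45 IST = 04:15 UTC), could be:
+        def _exclude_first_30min_sketch(day_df):
+            later = day_df[day_df['time'].dt.time >= time(4, 15)]
+            return later['range'].expanding().mean()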
(df_5m['close'] > (((df_5m['high'] - df_5m['open']) * 0.5) + df_5m['open'])) + ) + df_5m['is_range_bearish'] = ( + (df_5m['range'] > 0.7 * df_5m['avg_range_ex_first_30min']) & + (df_5m['close'] < df_5m['open']) & + (df_5m['close'] < (((df_5m['open'] - df_5m['low']) * 0.5) + df_5m['low'])) + ) + df_5m.drop('date_only', axis=1, inplace=True) + + # === ZERO LAG MACD (15m only) === + fast_period, slow_period, signal_period = 12, 26, 9 + df_15m['fast_zlema'] = self.zero_lag_ema(df_15m['close'], fast_period) + df_15m['slow_zlema'] = self.zero_lag_ema(df_15m['close'], slow_period) + df_15m['zl_macd'] = df_15m['fast_zlema'] - df_15m['slow_zlema'] + df_15m['zl_signal'] = df_15m['zl_macd'].ewm(span=signal_period, adjust=False).mean() + df_15m['zl_hist'] = df_15m['zl_macd'] - df_15m['zl_signal'] + + # Generate MACD Signals + df_15m['zl_macd_signal'] = 0 + df_15m.loc[(df_15m['zl_macd'] > df_15m['zl_signal']) & + (df_15m['zl_macd'].shift(1) <= df_15m['zl_signal'].shift(1)), 'zl_macd_signal'] = 1 + df_15m.loc[(df_15m['zl_macd'] < df_15m['zl_signal']) & + (df_15m['zl_macd'].shift(1) >= df_15m['zl_signal'].shift(1)), 'zl_macd_signal'] = -1 + df_15m.drop(['fast_zlema', 'slow_zlema', 'zl_macd', 'zl_signal', 'zl_hist'], axis=1, inplace=True) + + # === SINGLE PRINT CALCULATIONS - 15m === + df_15m['is_first_bullish_confirmed'] = False + df_15m['is_first_bearish_confirmed'] = False + df_15m['candle_count'] = df_15m.groupby(df_15m['date']).cumcount() + 1 + df_15m['cum_high_prev'] = df_15m.groupby('date')['high'].expanding().max().shift(1).reset_index(level=0, drop=True) + df_15m['cum_low_prev'] = df_15m.groupby('date')['low'].expanding().min().shift(1).reset_index(level=0, drop=True) + df_15m['cum_high'] = df_15m.groupby('date')['high'].expanding().max().reset_index(level=0, drop=True) + df_15m['cum_low'] = df_15m.groupby('date')['low'].expanding().min().reset_index(level=0, drop=True) + df_15m['sp_confirmed_bullish'] = ( + (df_15m['close'] > df_15m['cum_high_prev']) & + (df_15m['close'] > df_15m['open']) & + (df_15m['candle_count'] >= 2) + ) + df_15m['sp_confirmed_bearish'] = ( + (df_15m['close'] < df_15m['cum_low_prev']) & + (df_15m['close'] < df_15m['open']) & + (df_15m['candle_count'] >= 2) + ) + + # Mark first confirmations + bullish_conf_15m = df_15m[df_15m['sp_confirmed_bullish']] + bearish_conf_15m = df_15m[df_15m['sp_confirmed_bearish']] + first_bullish_idx_15m = bullish_conf_15m.groupby('date').head(1).index + first_bearish_idx_15m = bearish_conf_15m.groupby('date').head(1).index + df_15m.loc[first_bullish_idx_15m, 'is_first_bullish_confirmed'] = True + df_15m.loc[first_bearish_idx_15m, 'is_first_bearish_confirmed'] = True + + # SP levels for 15m + sp_levels_bullish_15m = df_15m[df_15m['is_first_bullish_confirmed']][['date', 'close', 'cum_high_prev']] + sp_levels_bearish_15m = df_15m[df_15m['is_first_bearish_confirmed']][['date', 'close', 'cum_low_prev']] + sp_levels_bullish_15m['sp_high_bullish'] = sp_levels_bullish_15m['close'] + sp_levels_bullish_15m['sp_low_bullish'] = sp_levels_bullish_15m['cum_high_prev'] + sp_levels_bearish_15m['sp_high_bearish'] = sp_levels_bearish_15m['cum_low_prev'] + sp_levels_bearish_15m['sp_low_bearish'] = sp_levels_bearish_15m['close'] + sp_levels_bullish_15m.drop(['close', 'cum_high_prev'], axis=1, inplace=True) + sp_levels_bearish_15m.drop(['close', 'cum_low_prev'], axis=1, inplace=True) + + # Merge back for 15m + df_15m = df_15m.merge(sp_levels_bullish_15m, on='date', how='left') + df_15m = df_15m.merge(sp_levels_bearish_15m, on='date', how='left') + + # Forward fill 
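+        # Illustrative sketch: zero_lag_ema() is referenced above but not shown in this
+        # diff. The usual ZLEMA construction de-lags the input before smoothing; the
+        # real helper may differ in detail:
+        def _zero_lag_ema_sketch(series, period):
+            lag = (period - 1) // 2
+            de_lagged = 2 * series - series.shift(lag)  # compensate for the EMA's lag
+            return de_lagged.ewm(span=period, adjust=False).mean()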
SP levels for 15m + df_15m['sp_high_bullish'] = df_15m.groupby('date')['sp_high_bullish'].transform(lambda x: x.ffill() if x.notna().any() else x) + df_15m['sp_low_bullish'] = df_15m.groupby('date')['sp_low_bullish'].transform(lambda x: x.ffill() if x.notna().any() else x) + df_15m['sp_high_bearish'] = df_15m.groupby('date')['sp_high_bearish'].transform(lambda x: x.ffill() if x.notna().any() else x) + df_15m['sp_low_bearish'] = df_15m.groupby('date')['sp_low_bearish'].transform(lambda x: x.ffill() if x.notna().any() else x) + + # Set pre-confirmation values to NaN for 15m + df_15m.loc[~df_15m['sp_confirmed_bullish'].cummax(), ['sp_high_bullish', 'sp_low_bullish']] = None + df_15m.loc[~df_15m['sp_confirmed_bearish'].cummax(), ['sp_high_bearish', 'sp_low_bearish']] = None + + # Calculate SP range percentages for 15m + df_15m['sp_bullish_range_pct'] = (df_15m['sp_high_bullish'] - df_15m['sp_low_bullish']) / df_15m['sp_low_bullish'] * 100 + df_15m['sp_bearish_range_pct'] = (df_15m['sp_high_bearish'] - df_15m['sp_low_bearish']) / df_15m['sp_low_bearish'] * 100 + df_15m['cum_sp_bullish'] = df_15m.groupby('date')['sp_confirmed_bullish'].cumsum() + df_15m['cum_sp_bearish'] = df_15m.groupby('date')['sp_confirmed_bearish'].cumsum() + + # === SINGLE PRINT CALCULATIONS - 5m === + df_5m['is_first_bullish_confirmed'] = False + df_5m['is_first_bearish_confirmed'] = False + df_5m['candle_count'] = df_5m.groupby(df_5m['date']).cumcount() + 1 + df_5m['cum_high_prev'] = df_5m.groupby('date')['high'].expanding().max().shift(1).reset_index(level=0, drop=True) + df_5m['cum_low_prev'] = df_5m.groupby('date')['low'].expanding().min().shift(1).reset_index(level=0, drop=True) + df_5m['cum_high'] = df_5m.groupby('date')['high'].expanding().max().reset_index(level=0, drop=True) + df_5m['cum_low'] = df_5m.groupby('date')['low'].expanding().min().reset_index(level=0, drop=True) + df_5m['sp_confirmed_bullish'] = ( + (df_5m['close'] > df_5m['cum_high_prev']) & + (df_5m['close'] > df_5m['open']) & + (df_5m['candle_count'] >= 2) + ) + df_5m['sp_confirmed_bearish'] = ( + (df_5m['close'] < df_5m['cum_low_prev']) & + (df_5m['close'] < df_5m['open']) & + (df_5m['candle_count'] >= 2) + ) + + # Mark first confirmations for 5m + bullish_conf_5m = df_5m[df_5m['sp_confirmed_bullish']] + bearish_conf_5m = df_5m[df_5m['sp_confirmed_bearish']] + first_bullish_idx_5m = bullish_conf_5m.groupby('date').head(1).index + first_bearish_idx_5m = bearish_conf_5m.groupby('date').head(1).index + df_5m.loc[first_bullish_idx_5m, 'is_first_bullish_confirmed'] = True + df_5m.loc[first_bearish_idx_5m, 'is_first_bearish_confirmed'] = True + + # SP levels for 5m + sp_levels_bullish_5m = df_5m[df_5m['is_first_bullish_confirmed']][['date', 'close', 'cum_high_prev']] + sp_levels_bearish_5m = df_5m[df_5m['is_first_bearish_confirmed']][['date', 'close', 'cum_low_prev']] + sp_levels_bullish_5m['sp_high_bullish'] = sp_levels_bullish_5m['close'] + sp_levels_bullish_5m['sp_low_bullish'] = sp_levels_bullish_5m['cum_high_prev'] + sp_levels_bearish_5m['sp_high_bearish'] = sp_levels_bearish_5m['cum_low_prev'] + sp_levels_bearish_5m['sp_low_bearish'] = sp_levels_bearish_5m['close'] + sp_levels_bullish_5m.drop(['close', 'cum_high_prev'], axis=1, inplace=True) + sp_levels_bearish_5m.drop(['close', 'cum_low_prev'], axis=1, inplace=True) + + # Merge back for 5m + df_5m = df_5m.merge(sp_levels_bullish_5m, on='date', how='left') + df_5m = df_5m.merge(sp_levels_bearish_5m, on='date', how='left') + + # Forward fill SP levels for 5m + df_5m['sp_high_bullish'] = 
df_5m.groupby('date')['sp_high_bullish'].transform(lambda x: x.ffill() if x.notna().any() else x) + df_5m['sp_low_bullish'] = df_5m.groupby('date')['sp_low_bullish'].transform(lambda x: x.ffill() if x.notna().any() else x) + df_5m['sp_high_bearish'] = df_5m.groupby('date')['sp_high_bearish'].transform(lambda x: x.ffill() if x.notna().any() else x) + df_5m['sp_low_bearish'] = df_5m.groupby('date')['sp_low_bearish'].transform(lambda x: x.ffill() if x.notna().any() else x) + + # Set pre-confirmation values to NaN for 5m + df_5m.loc[~df_5m['sp_confirmed_bullish'].cummax(), ['sp_high_bullish', 'sp_low_bullish']] = None + df_5m.loc[~df_5m['sp_confirmed_bearish'].cummax(), ['sp_high_bearish', 'sp_low_bearish']] = None + + # Calculate SP range percentages for 5m + df_5m['sp_bullish_range_pct'] = (df_5m['sp_high_bullish'] - df_5m['sp_low_bullish']) / df_5m['sp_low_bullish'] * 100 + df_5m['sp_bearish_range_pct'] = (df_5m['sp_high_bearish'] - df_5m['sp_low_bearish']) / df_5m['sp_low_bearish'] * 100 + df_5m['cum_sp_bullish'] = df_5m.groupby('date')['sp_confirmed_bullish'].cumsum() + df_5m['cum_sp_bearish'] = df_5m.groupby('date')['sp_confirmed_bearish'].cumsum() + + # === VOLUME & RANGE CALCULATIONS - 15m === + df_15m['cum_intraday_volume'] = df_15m.groupby('date')['volume'].cumsum() + df_15m['curtop'] = df_15m.groupby('date')['high'].cummax() + df_15m['curbot'] = df_15m.groupby('date')['low'].cummin() + df_15m['predicted_today_high'] = df_15m['curbot'] + df_15m['atr_10'] + df_15m['predicted_today_low'] = df_15m['curtop'] - df_15m['atr_10'] + df_15m['today_range'] = df_15m['curtop'] - df_15m['curbot'] + df_15m['today_range_pct_10'] = df_15m['today_range'] / df_15m['atr_10'] + df_15m['today_range_pct_14'] = df_15m['today_range'] / df_15m['atr_14'] + df_15m['volume_range_pct_10'] = (df_15m['cum_intraday_volume'] / df_15m['volume_10']) / df_15m['today_range_pct_10'] + df_15m['volume_range_pct_14'] = (df_15m['cum_intraday_volume'] / df_15m['volume_14']) / df_15m['today_range_pct_14'] + + # === VOLUME & RANGE CALCULATIONS - 5m === + df_5m['cum_intraday_volume'] = df_5m.groupby('date')['volume'].cumsum() + df_5m['curtop'] = df_5m.groupby('date')['high'].cummax() + df_5m['curbot'] = df_5m.groupby('date')['low'].cummin() + df_5m['predicted_today_high'] = df_5m['curbot'] + df_5m['atr_10'] + df_5m['predicted_today_low'] = df_5m['curtop'] - df_5m['atr_10'] + df_5m['today_range'] = df_5m['curtop'] - df_5m['curbot'] + df_5m['today_range_pct_10'] = df_5m['today_range'] / df_5m['atr_10'] + df_5m['today_range_pct_14'] = df_5m['today_range'] / df_5m['atr_14'] + df_5m['volume_range_pct_10'] = (df_5m['cum_intraday_volume'] / df_5m['volume_10']) / df_5m['today_range_pct_10'] + df_5m['volume_range_pct_14'] = (df_5m['cum_intraday_volume'] / df_5m['volume_14']) / df_5m['today_range_pct_14'] + + # === STRATEGY DEFINITIONS === + # Strategy 8 & 12 (15m) + df_15m['s_8'] = ( + (df_15m['time'].dt.time >= time(4, 0)) & + (df_15m['time'].dt.time < time(8, 15)) & + (df_15m['cum_sp_bullish'] >= 1) & + (df_15m['sp_bullish_range_pct'] > 0.8) & + (df_15m['sp_bullish_range_pct'] < 1.3) & + (df_15m['zl_macd_signal'] == -1) & + (df_15m['volume_range_pct_10'] > 1) & + (df_15m['atr_10'] / df_15m['close_10'] < 0.04) & + (df_15m['nifty_trend_15m'] >= 0) + ) + df_15m['strategy_8'] = False + first_true_idx_8 = df_15m[df_15m['s_8']].groupby('date').head(1).index + df_15m.loc[first_true_idx_8, 'strategy_8'] = True + + df_15m['s_12'] = ( + (df_15m['time'].dt.time >= time(4, 0)) & + (df_15m['time'].dt.time < time(8, 15)) & + 
(df_15m['cum_sp_bearish'] >= 1) & + (df_15m['sp_bearish_range_pct'] > 1) & + (df_15m['zl_macd_signal'] == 1) & + (df_15m['volume_range_pct_10'] > 0) & + (df_15m['volume_range_pct_10'] < 0.4) & + (df_15m['atr_10'] / df_15m['close_10'] < 0.04) & + (df_15m['nifty_trend_15m'] <= 0) + ) + df_15m['strategy_12'] = False + first_true_idx_12 = df_15m[df_15m['s_12']].groupby('date').head(1).index + df_15m.loc[first_true_idx_12, 'strategy_12'] = True + + + # Strategy 10 & 11 (5m) + df_5m['s_10'] = ( + (df_5m['time'].dt.time >= time(3, 50)) & + (df_5m['time'].dt.time < time(8, 15)) & + (df_5m['cum_sp_bearish'] >= 1) & + (df_5m['sp_bearish_range_pct'] > 0.6) & + (df_5m['close'] < df_5m['ema_50']) & + (df_5m['close'] < df_5m['ema_100']) & + (df_5m['close'] < df_5m['ema_200']) & + (df_5m['is_range_bearish']) & + (df_5m['volume_range_pct_10'] > 0.3) & + (df_5m['volume_range_pct_10'] < 0.7) & + (df_5m['atr_10'] / df_5m['close_10'] > 0.04) & + (df_5m['nifty_trend_15m'] != 1) + ) + df_5m['strategy_10'] = False + first_true_idx_10 = df_5m[df_5m['s_10']].groupby('date').head(1).index + df_5m.loc[first_true_idx_10, 'strategy_10'] = True + + df_5m['s_11'] = ( + (df_5m['time'].dt.time >= time(3, 50)) & + (df_5m['time'].dt.time < time(8, 15)) & + (df_5m['cum_sp_bullish'] >= 1) & + (df_5m['sp_bullish_range_pct'] > 0.8) & + (df_5m['close'] > df_5m['ema_50']) & + (df_5m['close'] > df_5m['ema_100']) & + (df_5m['close'] > df_5m['ema_200']) & + (df_5m['is_range_bullish']) & + (df_5m['volume_range_pct_10'] > 0) & + (df_5m['volume_range_pct_10'] < 0.3) & + (df_5m['atr_10'] / df_5m['close_10'] > 0.04) & + (df_5m['nifty_trend_15m'] != -1) + ) + df_5m['strategy_11'] = False + first_true_idx_11 = df_5m[df_5m['s_11']].groupby('date').head(1).index + df_5m.loc[first_true_idx_11, 'strategy_11'] = True + + df_5m['s_9'] = ( + (df_5m['time'].dt.time >= time(3, 50)) & + (df_5m['time'].dt.time < time(8, 15)) & + (df_5m['cum_sp_bullish'] >= 1) & + (df_5m['sp_bullish_range_pct'] > 0.8) & + (df_5m['close'] > df_5m['ema_50']) & + (df_5m['close'] > df_5m['ema_100']) & + (df_5m['close'] > df_5m['ema_200']) & + (df_5m['is_range_bullish']) & + (df_5m['volume_range_pct_10'] > 0.3) & + (df_5m['volume_range_pct_10'] < 0.6) & + (df_5m['atr_10'] / df_5m['close_10'] < 0.04) & + (df_5m['nifty_trend_15m'] != 1) + ) + df_5m['strategy_9'] = False + first_true_idx_9 = df_5m[df_5m['s_9']].groupby('date').head(1).index + df_5m.loc[first_true_idx_9, 'strategy_9'] = True + + # Clean up date columns + df_15m.drop('date', axis=1, inplace=True) + df_5m.drop('date', axis=1, inplace=True) + + self.logger.info(f"βœ… All indicators calculated once for {self.symbol}") + + #df_15m.to_csv('15m.csv', index=False) + #df_5m.to_csv('5m.csv', index=False) + + return { + '15m': df_15m, + '5m': df_5m, + '1m': df_1m, + 'd': df_daily, + #'1h': df_nifty_1h, + 'nifty_15m': df_nifty_15m + } + + def get_day_data_optimized(self, df_with_indicators, day, lookback_days): + """ + Fast slicing of pre-calculated data for a specific day + """ + if df_with_indicators.empty: + return df_with_indicators + + start_date = day - timedelta(days=lookback_days) + end_date = day + timedelta(days=1) # Include full day + + # Ensure time column is datetime + if not pd.api.types.is_datetime64_any_dtype(df_with_indicators['time']): + df_with_indicators['time'] = pd.to_datetime(df_with_indicators['time']) + + # Use boolean indexing (faster than .loc for large datasets) + mask = (df_with_indicators['time'].dt.date >= start_date) & (df_with_indicators['time'].dt.date <= day) + return 
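+        # Illustrative worked example of the "first signal per day" pattern used for the
+        # strategy columns above (hypothetical data, assuming pandas as pd):
+        #   raw   = pd.DataFrame({'date': ['d1', 'd1', 'd2'], 's_8': [True, True, True]})
+        #   first = raw[raw['s_8']].groupby('date').head(1).index   # -> Index([0, 2])
+        #   raw['strategy_8'] = False
+        #   raw.loc[first, 'strategy_8'] = True   # only the earliest bar per day stays flagged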
df_with_indicators[mask].copy() + + def get_today_data_only(self, df_with_indicators, day): + """Get only today's data (for trading logic)""" + if df_with_indicators.empty: + return df_with_indicators + + # Ensure time column is datetime + if not pd.api.types.is_datetime64_any_dtype(df_with_indicators['time']): + df_with_indicators['time'] = pd.to_datetime(df_with_indicators['time']) + + start_dt = datetime.combine(day, time.min).replace(tzinfo=pytz.UTC) + end_dt = datetime.combine(day, time.max).replace(tzinfo=pytz.UTC) + + mask = (df_with_indicators['time'] >= start_dt) & (df_with_indicators['time'] <= end_dt) + return df_with_indicators[mask].copy() + + + def pre_index_data_by_date(self, df_dict): + """Pre-index all dataframes by date for O(1) lookups""" + indexed_data = {} + for interval, df in df_dict.items(): + df['date'] = df['time'].dt.date + indexed_data[interval] = df.groupby('date') + return indexed_data + + def run(self): + self.logger.info(f"Running backtest from {self.start_date} to {self.end_date}") + trades = [] + self.trailing_chart_data = [] + in_position = False + + # === STEP 1: Fetch all data once === + df_all_dict = { + '15m': self.fetch_lookback_data_2(self.start_date - timedelta(days=20), self.end_date, '15m'), + '5m': self.fetch_lookback_data_2(self.start_date - timedelta(days=10), self.end_date, '5m'), + '1m': self.fetch_lookback_data_2(self.start_date - timedelta(days=3), self.end_date, '1m'), + 'd': self.fetch_lookback_data_2(self.start_date - timedelta(days=20), self.end_date, 'd'), + #'1h': self.fetch_lookback_data_2(self.start_date - timedelta(days=45), self.end_date, '1h'), + 'nifty_15m': self.fetch_lookback_data_2(self.start_date - timedelta(days=20), self.end_date, 'nifty_15m'), + } + + # === STEP 2: Calculate all indicators once === + try: + data_with_indicators = self.calculate_all_indicators_once(df_all_dict) + except Exception as e: + self.logger.error(f"Error in calculate_all_indicators_once for {self.symbol}: {e}") + raise + + # === STEP 3: Process each day with pre-calculated data === + for day in self.daterange(): + try: + # Fast slicing of pre-calculated data (no more daily filtering/copying) + df = self.get_day_data_optimized(data_with_indicators['15m'], day, 12) + df_5min = self.get_day_data_optimized(data_with_indicators['5m'], day, 6) + df_min = self.get_day_data_optimized(data_with_indicators['1m'], day, 2) + #df_daily = self.get_day_data_optimized(data_with_indicators['d'], day, 20) + #df_nifty_1h = self.get_day_data_optimized(data_with_indicators['1h'], day, 45) + except Exception as e: + self.logger.error(f"Error in get_day_data_optimized for {self.symbol} on {day}: {e}") + continue + + if df.empty or df_5min.empty or df_min.empty: + continue + + # Get today's data only (for trading logic) + df_today = self.get_today_data_only(df, day) + df_today_5min = self.get_today_data_only(df_5min, day) + df_min_today = self.get_today_data_only(df_min, day) + + # Safety check: Skip if strategy columns are missing + required_strategy_cols = ['strategy_8', 'strategy_12', 'strategy_10', 'strategy_11', 'strategy_9'] + if not all(col in df_today.columns for col in required_strategy_cols[:2]) and not df_today.empty: + self.logger.warning(f"Missing 15m strategy columns for {self.symbol} on {day}, skipping") + continue + if not all(col in df_today_5min.columns for col in required_strategy_cols[2:]) and not df_today_5min.empty: + self.logger.warning(f"Missing 5m strategy columns for {self.symbol} on {day}, skipping") + continue + + entry_row = None + + # 15m 
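+            # Illustrative sketch: daterange() is called by run() above but not shown in
+            # this diff. A minimal generator over calendar days that skips weekends
+            # (exchange holidays would still need a separate calendar) might look like:
+            def _daterange_sketch(start_date, end_date):
+                day = start_date
+                while day <= end_date:
+                    if day.weekday() < 5:      # Monday-Friday only
+                        yield day
+                    day += timedelta(days=1)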
dataframe - Strategy 8 and 12 + for i in range(1, len(df_today)): + prev = df_today.iloc[i - 1] + curr = df_today.iloc[i] + + curr_ist_time = curr['time'].astimezone(IST).time() + curr_ist_time_2 = curr['time'].astimezone(IST) + + # Safely access strategy columns with defaults + buy = curr.get('strategy_12', False) + sell = False + short = curr.get('strategy_8', False) + cover = False + + # ENTRY: Long + if not in_position and buy and curr_ist_time <= time(14, 30) and self.position == 0: + self.logger.info(f"[{curr_ist_time_2}] LONG ENTRY triggered at {curr['close']}") + self.position = 1 + entry_row = curr + capital_per_trade = self.capital * self.leverage * (self.capital_alloc_pct / 100) + quantity = int(capital_per_trade / entry_row['close']) + trailing_active = False + trail_stop = None + last_trail_price = None + trail_history = [] + in_position = True + self.entry_price = entry_row['close'] + strategy_id = '12' + entry_time = curr['time'] + timedelta(minutes=14) # 15min as we are working on 15min candles, else replace it with 5min + #self.logger.info(f"Entry confirmed: position={self.position}, entry_price={entry_row['close']}, in_position={in_position}") + break # Exit the 15m loop after entry + + # ENTRY: Short + if not in_position and short and curr_ist_time <= time(14, 30) and self.position == 0: + self.logger.info(f"[{curr_ist_time_2}] SHORT ENTRY triggered at {curr['close']}") + self.position = -1 + entry_row = curr + capital_per_trade = self.capital * self.leverage * (self.capital_alloc_pct / 100) + quantity = int(capital_per_trade / entry_row['close']) + trailing_active = False + trail_stop = None + last_trail_price = None + trail_history = [] + in_position = True + self.entry_price = entry_row['close'] + strategy_id = '8' + entry_time = curr['time'] + timedelta(minutes=14) # 15min as we are working on 15min candles, else replace it with 5min + #self.logger.info(f"Entry confirmed: position={self.position}, entry_price={entry_row['close']}, in_position={in_position}") + break # Exit the 15m loop after entry + + #self.logger.info(f"[{curr_ist_time_2}] Position check: position={self.position}, in_position={in_position}, entry_row={entry_row is not None}") + # EXIT + # TRADE MONITORING WHILE IN POSITION + if self.position != 0 and entry_row is not None and in_position: + # Find the corresponding minute in df_min after our entry + min_entries = df_min_today[df_min_today['time'] >= entry_time] + + for i in range(len(min_entries)): + curr_min = min_entries.iloc[i] + curr_min_ist_time = curr_min['time'].astimezone(IST).time() + curr_min_ist_time_2 = curr_min['time'].astimezone(IST) + price = curr_min['close'] + entry_price = entry_row['close'] + exit_row = None + exit_reason = None + #self.logger.info(f"[{curr_ist_time}] TRADE MONITORING: Position: {self.position}, Entry Price: {entry_price}, Current Price: {price}") + + is_long = self.position == 1 + is_short = self.position == -1 + price_change = price - entry_price + pct_move = price_change / entry_price + + # Targets + tp_hit = pct_move >= self.tp_pct if is_long else pct_move <= -self.tp_pct + sl_hit = pct_move <= -self.sl_pct if is_long else pct_move >= self.sl_pct + trail_trigger_pct = self.trail_activation_pct if is_long else -self.trail_activation_pct + trail_increment = self.trail_increment_pct * entry_price + trail_gap = self.trail_stop_gap_pct * entry_price + + #self.logger.info(f"[{curr_ist_time}] Evaluating exit for {'SHORT' if self.position == -1 else 'LONG'} at price {price}") + #self.logger.info(f"[{curr_ist_time_2}] 
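+                    # Illustrative sketch: the exit checks above are plain percentage
+                    # arithmetic on the entry price. The long side condenses to the helper
+                    # below (shorts mirror the signs); trailing then starts once pct_move
+                    # reaches trail_activation_pct, and the stop is stepped by
+                    # trail_increment_pct * entry_price as price advances.
+                    def _long_exit_sketch(entry_price, price, tp_pct, sl_pct):
+                        pct_move = (price - entry_price) / entry_price
+                        if pct_move >= tp_pct:
+                            return "TP"
+                        if pct_move <= -sl_pct:
+                            return "SL"
+                        return None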
Evaluating exit for {'SHORT' if self.position == -1 else 'LONG'} at price {price}") + + # TP + if tp_hit: + exit_row = curr_min + exit_reason = "TP" + self.logger.info(f"[{curr_min_ist_time_2}] TP HIT at {price}") + + # SL + elif sl_hit: + exit_row = curr_min + exit_reason = "SL" + self.logger.info(f"[{curr_min_ist_time_2}] SL HIT at {price}") + + # Trailing Stop + elif trailing_active: + # Price moved enough to adjust trail + if abs(price - last_trail_price) >= trail_increment: + if is_long: + trail_stop += trail_increment + else: + trail_stop -= trail_increment + last_trail_price = price + trail_history.append({ + 'time': curr_min_ist_time_2, + 'value': trail_stop + }) + + # Price hit trailing stop + if (is_long and price <= trail_stop) or (is_short and price >= trail_stop): + exit_row = curr_min + exit_reason = "TRAIL" + self.logger.info(f"[{curr_min_ist_time_2}] TRAIL STOP HIT at {price} vs stop {trail_stop}") + + # Activate trailing + elif (pct_move >= self.trail_activation_pct if is_long else pct_move <= -self.trail_activation_pct): + trailing_active = True + trail_stop = price - trail_gap if is_long else price + trail_gap + last_trail_price = price + + # Strategy exit signal + # elif (crossunder if is_long else crossover): + # exit_row = curr + # exit_reason = "CROSSOVER" if is_long else "CROSSUNDER" + # self.logger.info(f"[{curr_ist_time_2}] CROSSOVER HIT for SHORT at EMA fast {curr['ema_fast']} vs slow {curr['ema_slow']}") + + # Force exit based on direction and time (in UTC) + if exit_row is None: + #self.logger.info(f"[{curr_ist_time_2}] No exit condition met, checking force exit conditions") + if is_long and curr_min_ist_time >= time(15, 9): # 3:10 PM IST + exit_row = curr_min + exit_reason = "FORCE_EXIT_1510" + self.logger.info(f"[{curr_min_ist_time_2}] FORCE EXIT triggered for {'LONG' if is_long else 'SHORT'}") + + elif is_short and curr_min_ist_time >= time(14, 39): # 2:40 PM IST + exit_row = curr_min + exit_reason = "FORCE_EXIT_1440" + self.logger.info(f"[{curr_min_ist_time_2}] FORCE EXIT triggered for {'LONG' if is_long else 'SHORT'}") + + # Process exit + if exit_row is not None: + self.logger.info(f"[{curr_min_ist_time_2}] EXITING TRADE: {exit_reason} at {price}") + trade_df = df_min_today[(df_min_today['time'] >= entry_time) & (df_min_today['time'] <= curr_min['time'])] + trade_close = trade_df['close'] + + mae = (trade_close.min() - entry_price if is_long else entry_price - trade_close.max()) if not trade_df.empty else 0 + mfe = (trade_close.max() - entry_price if is_long else entry_price - trade_close.min()) if not trade_df.empty else 0 + pnl = (price - entry_price if is_long else entry_price - price) + + trades.append({ + 'symbol': self.symbol, + 'strategy': strategy_id, + 'quantity': quantity, + 'entry_time': entry_time.astimezone(IST).strftime("%Y-%m-%d %H:%M:%S"), + 'entry_price': round(entry_price, 2), + 'exit_time': exit_row['time'].astimezone(IST).strftime("%Y-%m-%d %H:%M:%S"), + 'exit_price': round(price, 2), + 'exit_reason': exit_reason, + 'gross_pnl': round(pnl * quantity, 2), + 'cumulative_pnl': 0, # Filled later + 'mae': round(mae * quantity, 2), + 'mfe': round(mfe * quantity, 2), + 'holding_period': str(exit_row['time'] - entry_row['time']), + 'direction': "LONG" if is_long else "SHORT", + 'capital_used': round(quantity * price, 2), + 'tax': round(quantity * price * 2 * 0.0002, 2), + 'brokerage': 0, + 'net_pnl': round((pnl * quantity) - (2 * quantity * price * 0.0002) - 0, 2), + }) + + if exit_reason == "TRAIL": + self.trailing_chart_data.append({ + 
'symbol': self.symbol, + 'entry_time': entry_time, + 'exit_time': exit_row['time'], + 'entry_price': price, + 'df': df_min_today.copy(), + 'trail_stops': trail_history + }) + + # Reset + self.position = 0 + self.entry_price = None + entry_time = None + entry_row = None + trailing_active = False + trail_stop = None + last_trail_price = None + in_position = False + break # Exit the 1m loop after exit + + + entry_row = None + # 5m dataframe - Strategy 9, 10 and 11 + for i in range(1, len(df_today_5min)): + prev = df_today_5min.iloc[i - 1] + curr = df_today_5min.iloc[i] + + curr_ist_time = curr['time'].astimezone(IST).time() + curr_ist_time_2 = curr['time'].astimezone(IST) + + # Safely access strategy columns with defaults + buy = curr.get('strategy_11', False) + sell = False + short = curr.get('strategy_10', False) or curr.get('strategy_9', False) + cover = False + + # ENTRY: Long + if not in_position and buy and curr_ist_time <= time(14, 30) and self.position == 0: + self.logger.info(f"[{curr_ist_time_2}] LONG ENTRY triggered at {curr['close']}") + self.position = 1 + entry_row = curr + capital_per_trade = self.capital * self.leverage * (self.capital_alloc_pct / 100) + quantity = int(capital_per_trade / entry_row['close']) + trailing_active = False + trail_stop = None + last_trail_price = None + trail_history = [] + in_position = True + self.entry_price = entry_row['close'] + strategy_id = '11' + entry_time = curr['time'] + timedelta(minutes=4) # 5min as we are working on 5min candles, else replace it with 5min + #self.logger.info(f"Entry confirmed: position={self.position}, entry_price={entry_row['close']}, in_position={in_position}") + break # Exit the 15m loop after entry + + # ENTRY: Short + if not in_position and short and curr_ist_time <= time(14, 30) and self.position == 0: + self.logger.info(f"[{curr_ist_time_2}] SHORT ENTRY triggered at {curr['close']}") + self.position = -1 + entry_row = curr + capital_per_trade = self.capital * self.leverage * (self.capital_alloc_pct / 100) + quantity = int(capital_per_trade / entry_row['close']) + trailing_active = False + trail_stop = None + last_trail_price = None + trail_history = [] + in_position = True + self.entry_price = entry_row['close'] + #strategy_id = '10' + strategy_id = '10' if curr.get('strategy_10', False) else '9' + entry_time = curr['time'] + timedelta(minutes=4) # 5min as we are working on 5min candles, else replace it with 15min + #self.logger.info(f"Entry confirmed: position={self.position}, entry_price={entry_row['close']}, in_position={in_position}") + break # Exit the 15m loop after entry + + #self.logger.info(f"[{curr_ist_time_2}] Position check: position={self.position}, in_position={in_position}, entry_row={entry_row is not None}") + # EXIT + # TRADE MONITORING WHILE IN POSITION + if self.position != 0 and entry_row is not None and in_position: + # Find the corresponding minute in df_min after our entry + min_entries = df_min_today[df_min_today['time'] >= entry_time] + + for i in range(len(min_entries)): + curr_min = min_entries.iloc[i] + curr_min_ist_time = curr_min['time'].astimezone(IST).time() + curr_min_ist_time_2 = curr_min['time'].astimezone(IST) + price = curr_min['close'] + entry_price = entry_row['close'] + exit_row = None + exit_reason = None + #self.logger.info(f"[{curr_ist_time}] TRADE MONITORING: Position: {self.position}, Entry Price: {entry_price}, Current Price: {price}") + + is_long = self.position == 1 + is_short = self.position == -1 + price_change = price - entry_price + pct_move = price_change / 
entry_price + + # Targets + tp_hit = pct_move >= self.tp_pct if is_long else pct_move <= -self.tp_pct + sl_hit = pct_move <= -self.sl_pct if is_long else pct_move >= self.sl_pct + trail_trigger_pct = self.trail_activation_pct if is_long else -self.trail_activation_pct + trail_increment = self.trail_increment_pct * entry_price + trail_gap = self.trail_stop_gap_pct * entry_price + + #self.logger.info(f"[{curr_ist_time}] Evaluating exit for {'SHORT' if self.position == -1 else 'LONG'} at price {price}") + #self.logger.info(f"[{curr_ist_time_2}] Evaluating exit for {'SHORT' if self.position == -1 else 'LONG'} at price {price}") + + # TP + if tp_hit: + exit_row = curr_min + exit_reason = "TP" + self.logger.info(f"[{curr_min_ist_time_2}] TP HIT at {price}") + + # SL + elif sl_hit: + exit_row = curr_min + exit_reason = "SL" + self.logger.info(f"[{curr_min_ist_time_2}] SL HIT at {price}") + + # Trailing Stop + elif trailing_active: + # Price moved enough to adjust trail + if abs(price - last_trail_price) >= trail_increment: + if is_long: + trail_stop += trail_increment + else: + trail_stop -= trail_increment + last_trail_price = price + trail_history.append({ + 'time': curr_min_ist_time_2, + 'value': trail_stop + }) + + # Price hit trailing stop + if (is_long and price <= trail_stop) or (is_short and price >= trail_stop): + exit_row = curr_min + exit_reason = "TRAIL" + self.logger.info(f"[{curr_min_ist_time_2}] TRAIL STOP HIT at {price} vs stop {trail_stop}") + + # Activate trailing + elif (pct_move >= self.trail_activation_pct if is_long else pct_move <= -self.trail_activation_pct): + trailing_active = True + trail_stop = price - trail_gap if is_long else price + trail_gap + last_trail_price = price + + # Strategy exit signal + # elif (crossunder if is_long else crossover): + # exit_row = curr + # exit_reason = "CROSSOVER" if is_long else "CROSSUNDER" + # self.logger.info(f"[{curr_ist_time_2}] CROSSOVER HIT for SHORT at EMA fast {curr['ema_fast']} vs slow {curr['ema_slow']}") + + # Force exit based on direction and time (in UTC) + if exit_row is None: + #self.logger.info(f"[{curr_ist_time_2}] No exit condition met, checking force exit conditions") + if is_long and curr_min_ist_time >= time(15, 9): # 3:10 PM IST + exit_row = curr_min + exit_reason = "FORCE_EXIT_1510" + self.logger.info(f"[{curr_min_ist_time_2}] FORCE EXIT triggered for {'LONG' if is_long else 'SHORT'}") + + elif is_short and curr_min_ist_time >= time(14, 39): # 2:40 PM IST + exit_row = curr_min + exit_reason = "FORCE_EXIT_1440" + self.logger.info(f"[{curr_min_ist_time_2}] FORCE EXIT triggered for {'LONG' if is_long else 'SHORT'}") + + # Process exit + if exit_row is not None: + self.logger.info(f"[{curr_min_ist_time_2}] EXITING TRADE: {exit_reason} at {price}") + trade_df = df_min_today[(df_min_today['time'] >= entry_time) & (df_min_today['time'] <= curr_min['time'])] + trade_close = trade_df['close'] + + mae = (trade_close.min() - entry_price if is_long else entry_price - trade_close.max()) if not trade_df.empty else 0 + mfe = (trade_close.max() - entry_price if is_long else entry_price - trade_close.min()) if not trade_df.empty else 0 + pnl = (price - entry_price if is_long else entry_price - price) + + trades.append({ + 'symbol': self.symbol, + 'strategy': strategy_id, + 'quantity': quantity, + 'entry_time': entry_time.astimezone(IST).strftime("%Y-%m-%d %H:%M:%S"), + 'entry_price': round(entry_price, 2), + 'exit_time': exit_row['time'].astimezone(IST).strftime("%Y-%m-%d %H:%M:%S"), + 'exit_price': round(price, 2), + 
'exit_reason': exit_reason, + 'gross_pnl': round(pnl * quantity, 2), + 'cumulative_pnl': 0, # Filled later + 'mae': round(mae * quantity, 2), + 'mfe': round(mfe * quantity, 2), + 'holding_period': str(exit_row['time'] - entry_row['time']), + 'direction': "LONG" if is_long else "SHORT", + 'capital_used': round(quantity * price, 2), + 'tax': round(quantity * price * 2 * 0.0002, 2), + 'brokerage': 0, + 'net_pnl': round((pnl * quantity) - (2 * quantity * price * 0.0002) - 0, 2), + }) + + if exit_reason == "TRAIL": + self.trailing_chart_data.append({ + 'symbol': self.symbol, + 'entry_time': entry_time, + 'exit_time': exit_row['time'], + 'entry_price': price, + 'df': df_min_today.copy(), + 'trail_stops': trail_history + }) + + # Reset + self.position = 0 + self.entry_price = None + entry_time = None + entry_row = None + trailing_active = False + trail_stop = None + last_trail_price = None + in_position = False + break # only one trade per symbol per strategy per day + + self.trades = pd.DataFrame(trades) + return self.trades + + def export_trail_charts(self, output_dir="trail_charts"): + os.makedirs(output_dir, exist_ok=True) + + for i, trade in enumerate(self.trailing_chart_data): + df = trade['df'] + entry_time = trade['entry_time'].tz_convert('Asia/Kolkata') + exit_time = trade['exit_time'].tz_convert('Asia/Kolkata') + entry_price = trade['entry_price'] + symbol = trade['symbol'] + trail_stops = trade['trail_stops'] + + trade_df = df[(df['time'] >= entry_time) & (df['time'] <= exit_time)].copy() + if trade_df.empty: + continue + + df['time'] = df['time'].dt.tz_convert('Asia/Kolkata') + trade_df['time'] = trade_df['time'].dt.tz_convert('Asia/Kolkata') + plt.figure(figsize=(12, 6)) + plt.plot(trade_df['time'], trade_df['close'], label='Price', color='blue') + plt.axhline(entry_price, color='green', linestyle='--', label='Entry Price') + plt.axvline(entry_time, color='green', linestyle=':', label='Entry Time') + plt.axvline(exit_time, color='red', linestyle=':', label='Exit Time') + + if trail_stops: + trail_df = pd.DataFrame(trail_stops) + plt.plot(trail_df['time'], trail_df['value'], color='orange', linestyle='--', label='Trailing Stop') + + plt.title(f"{symbol} TRAIL Exit [{entry_time} β†’ {exit_time}]") + plt.xlabel("Time") + plt.ylabel("Price") + plt.legend() + plt.grid(True) + plt.tight_layout() + + filename = f"{symbol}_trail_{entry_time.strftime('%Y%m%d_%H%M%S')}.png" + plt.savefig(os.path.join(output_dir, filename)) + plt.close() + + print(f"πŸ“ˆ Saved trailing chart: {filename}") + diff --git a/strategies/depth_example.py b/strategies/depth_example.py index 59952d31..b4040124 100644 --- a/strategies/depth_example.py +++ b/strategies/depth_example.py @@ -7,7 +7,7 @@ # Initialize feed client with explicit parameters client = api( - api_key="7653f710c940cdf1d757b5a7d808a60f43bc7e9c0239065435861da2869ec0fc", # Replace with your API key + api_key="8009e08498f085ff1a3e7da718c5f4b585eaf9c2b7ce0c72740ab2b5d283d36c", # Replace with your API key host="http://127.0.0.1:5000", # Replace with your API host ws_url="ws://127.0.0.1:8765" # Explicit WebSocket URL (can be different from REST API host) ) diff --git a/strategies/ema_crossover.py b/strategies/ema_crossover.py index 1579cd03..689eceda 100644 --- a/strategies/ema_crossover.py +++ b/strategies/ema_crossover.py @@ -6,7 +6,7 @@ from datetime import datetime, timedelta # Get API key from openalgo portal -api_key = 'your-openalgo-api-key' +api_key = '8009e08498f085ff1a3e7da718c5f4b585eaf9c2b7ce0c72740ab2b5d283d36c' # Set the strategy details and 
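+# Illustrative sketch: hard-coding a live API key in these strategy scripts is easy to
+# leak through version control. A safer pattern, assuming python-dotenv (already used
+# elsewhere in this repo) and a hypothetical OPENALGO_API_KEY entry in .env:
+import os
+from dotenv import load_dotenv
+
+def load_api_key(var_name: str = "OPENALGO_API_KEY") -> str:
+    """Read the OpenAlgo API key from the environment instead of the source file."""
+    load_dotenv()
+    key = os.getenv(var_name)
+    if not key:
+        raise RuntimeError(f"{var_name} is not set; add it to your .env file")
+    return key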
trading parameters diff --git a/strategies/ltp_example.py b/strategies/ltp_example.py index 5e960ce1..0b44dbc0 100644 --- a/strategies/ltp_example.py +++ b/strategies/ltp_example.py @@ -7,7 +7,7 @@ # Initialize feed client with explicit parameters client = api( - api_key="7653f710c940cdf1d757b5a7d808a60f43bc7e9c0239065435861da2869ec0fc", # Replace with your API key + api_key="8009e08498f085ff1a3e7da718c5f4b585eaf9c2b7ce0c72740ab2b5d283d36c", # Replace with your API key host="http://127.0.0.1:5000", # Replace with your API host ws_url="ws://127.0.0.1:8765" # Explicit WebSocket URL (can be different from REST API host) ) diff --git a/strategies/quote_example.py b/strategies/quote_example.py index 58c9dbfa..c530c0b9 100644 --- a/strategies/quote_example.py +++ b/strategies/quote_example.py @@ -7,15 +7,15 @@ # Initialize feed client with explicit parameters client = api( - api_key="7653f710c940cdf1d757b5a7d808a60f43bc7e9c0239065435861da2869ec0fc", # Replace with your API key + api_key="8009e08498f085ff1a3e7da718c5f4b585eaf9c2b7ce0c72740ab2b5d283d36c", # Replace with your API key host="http://127.0.0.1:5000", # Replace with your API host ws_url="ws://127.0.0.1:8765" # Explicit WebSocket URL (can be different from REST API host) ) # MCX instruments for testing instruments_list = [ - {"exchange": "NSE_INDEX", "symbol": "NIFTY"}, - {"exchange": "NSE", "symbol": "INFY"}, + #{"exchange": "NSE_INDEX", "symbol": "NIFTY"}, + {"exchange": "NSE", "symbol": "RELIANCE"}, {"exchange": "NSE", "symbol": "TCS"} ] @@ -28,10 +28,10 @@ def on_data_received(data): client.subscribe_quote(instruments_list, on_data_received=on_data_received) # Poll Quote data a few times -for i in range(100): - print(f"\nPoll {i+1}:") - print(client.get_quotes()) - time.sleep(0.5) +for i in range(3600): + #print(f"\nPoll {i+1}:") + #print(client.get_quotes()) + time.sleep(1) # Cleanup client.unsubscribe_quote(instruments_list) diff --git a/strategies/supertrend.py b/strategies/supertrend.py index c006a435..9e2c60b5 100644 --- a/strategies/supertrend.py +++ b/strategies/supertrend.py @@ -6,7 +6,7 @@ from datetime import datetime, timedelta # Get API key from openalgo portal -api_key = 'your-openalgo-api-key' +api_key = '8009e08498f085ff1a3e7da718c5f4b585eaf9c2b7ce0c72740ab2b5d283d36c' # Set the strategy details and trading parameters strategy = "Supertrend Python" @@ -81,7 +81,7 @@ def supertrend_strategy(): try: # Dynamic date range: 7 days back to today end_date = datetime.now().strftime("%Y-%m-%d") - start_date = (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") + start_date = (datetime.now() - timedelta(days=30)).strftime("%Y-%m-%d") # Fetch 1-minute historical data using OpenAlgo df = client.history( @@ -163,7 +163,7 @@ def supertrend_strategy(): continue # Wait before the next cycle - time.sleep(15) + time.sleep(5) if __name__ == "__main__": print("Starting Supertrend Strategy...") diff --git a/strategies/timescaledb.py b/strategies/timescaledb.py new file mode 100644 index 00000000..f0253f7f --- /dev/null +++ b/strategies/timescaledb.py @@ -0,0 +1,1907 @@ +import psycopg2 +from psycopg2 import sql +from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +from psycopg2.extras import execute_batch +from psycopg2 import pool +from kafka import KafkaConsumer +from kafka.errors import KafkaError +import pytz +import json +import logging +from threading import Lock +from concurrent.futures import ThreadPoolExecutor, as_completed +from itertools import product, groupby +from operator import itemgetter +from datetime import 
datetime, timedelta +import os +import random +from dateutil import parser # For flexible ISO date parsing +import traceback +import argparse +import signal +import sys +from openalgo import api +import pandas as pd +from backtest_engine import BacktestEngine +from trading_engine import LiveTradingEngine +import glob +from concurrent.futures import ThreadPoolExecutor +import time +from tabulate import tabulate +from colorama import Fore, Back, Style, init +from textwrap import wrap +from collections import defaultdict +import shutil + +# Initialize colorama +init(autoreset=True) + +# Suppress User warnings in output +from warnings import filterwarnings +filterwarnings("ignore", category=UserWarning) + +from dotenv import load_dotenv + +load_dotenv() + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Check if Kafka is available +try: + KAFKA_AVAILABLE = True +except ImportError: + KAFKA_AVAILABLE = False + logger.warning("Kafka library not available. Install with: pip install kafka-python") + +class TimescaleDBManager: + def __init__(self, dbname=os.getenv('TIMESCALE_DB_NAME'), dbname_live=os.getenv('TIMESCALE_DB_NAME_LIVE'), user=os.getenv('TIMESCALE_DB_USER'), password=os.getenv('TIMESCALE_DB_PASSWORD'), host=os.getenv('TIMESCALE_DB_HOST'), port=os.getenv('TIMESCALE_DB_PORT')): + self.dbname = dbname + self.dbname_live = dbname_live + self.user = user + self.password = password + self.host = host + self.port = port + self.admin_conn = None + self.app_conn = None + self.logger = logging.getLogger(f"TimeScaleDBManager") + + self.logger.info(f"Initializing TimescaleDB connection to {self.host}:{self.port} as user '{self.user}' for database '{self.dbname}'") + + def _get_admin_connection(self): + """Connection without specifying database (for admin operations)""" + try: + conn = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname='postgres' # Connect to default admin DB + ) + # Set autocommit mode for DDL operations like CREATE DATABASE + conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + return conn + except psycopg2.Error as e: + self.logger.error(f"Failed to connect to PostgreSQL server: {e}") + raise + + def _database_exists(self, dbname): + """Check if database exists""" + try: + with self._get_admin_connection() as conn: + with conn.cursor() as cursor: + cursor.execute( + "SELECT 1 FROM pg_database WHERE datname = %s", + (dbname,) + ) + return cursor.fetchone() is not None + except Exception as e: + self.logger.error(f"Error checking database existence: {e}") + return False + + + def _create_database(self, dbname): + """Create new database with TimescaleDB extension""" + try: + self.logger.info(f"Creating database '{dbname}'...") + + # Create database with autocommit connection + conn = self._get_admin_connection() + try: + with conn.cursor() as cursor: + # Create database + cursor.execute( + sql.SQL("CREATE DATABASE {}").format( + sql.Identifier(dbname) + ) + ) + self.logger.info(f"Database '{dbname}' created successfully") + finally: + conn.close() + + # Connect to new database to install extensions + self.logger.info("Installing TimescaleDB extension...") + conn_newdb = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname=dbname + ) + try: + with conn_newdb.cursor() as cursor_new: + cursor_new.execute("CREATE EXTENSION IF NOT EXISTS timescaledb") + conn_newdb.commit() + self.logger.info("TimescaleDB extension 
installed successfully") + finally: + conn_newdb.close() + + self.logger.info(f"Created database {dbname} with TimescaleDB extension") + return True + + except psycopg2.Error as e: + self.logger.error(f"PostgreSQL error creating database: {e}") + return False + except Exception as e: + self.logger.error(f"Error creating database: {e}") + return False + + + def _create_tables(self, dbname): + """Create required tables and hypertables""" + commands = [ + """ + CREATE TABLE IF NOT EXISTS ticks ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ticks', 'time', if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS ohlc_1m ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ohlc_1m', 'time', if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS ohlc_5m ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + atr_10 DECIMAL(18, 2), + volume_10 BIGINT, + atr_14 DECIMAL(18, 2), + volume_14 BIGINT, + nifty_trend_15m INT, + curbot DECIMAL(18, 2), + curtop DECIMAL(18, 2), + cum_intraday_volume BIGINT, + strategy_8 INT, + strategy_9 INT, + strategy_10 INT, + strategy_11 INT, + strategy_12 INT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ohlc_5m', 'time', if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS ohlc_15m ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + atr_10 DECIMAL(18, 2), + volume_10 BIGINT, + atr_14 DECIMAL(18, 2), + volume_14 BIGINT, + nifty_trend_15m INT, + curbot DECIMAL(18, 2), + curtop DECIMAL(18, 2), + cum_intraday_volume BIGINT, + strategy_8 INT, + strategy_9 INT, + strategy_10 INT, + strategy_11 INT, + strategy_12 INT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ohlc_15m', 'time', if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS ohlc_1h ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ohlc_1h', 'time', if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS ohlc_D ( + time TIMESTAMPTZ NOT NULL, + symbol VARCHAR(20) NOT NULL, + open DECIMAL(18, 2), + high DECIMAL(18, 2), + low DECIMAL(18, 2), + close DECIMAL(18, 2), + volume BIGINT, + PRIMARY KEY (time, symbol) + ) + """, + """ + SELECT create_hypertable('ohlc_D', 'time', if_not_exists => TRUE) + """, + """ + CREATE TABLE IF NOT EXISTS trades ( + id SERIAL PRIMARY KEY, + symbol VARCHAR(20) NOT NULL, + strategy VARCHAR(10) NOT NULL, + quantity INTEGER NOT NULL, + entry_time TIMESTAMPTZ NOT NULL, + entry_price DECIMAL(18, 2) NOT NULL, + exit_time TIMESTAMPTZ NOT NULL, + exit_price DECIMAL(18, 2) NOT NULL, + exit_reason VARCHAR(20) NOT NULL, + gross_pnl DECIMAL(18, 2) NOT NULL, + direction VARCHAR(10) NOT NULL, + capital_used DECIMAL(18, 2) NOT NULL, + tax DECIMAL(18, 2) NOT NULL, + brokerage DECIMAL(18, 2) 
NOT NULL, + net_pnl DECIMAL(18, 2) NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() + ) + """, + """CREATE INDEX IF NOT EXISTS idx_ticks_symbol_time ON ticks (time, symbol)""", + """CREATE INDEX IF NOT EXISTS idx_ohlc_1m_symbol_time ON ohlc_1m (time, symbol)""", + """CREATE INDEX IF NOT EXISTS idx_ohlc_5m_symbol_time ON ohlc_5m (time, symbol)""", + """CREATE INDEX IF NOT EXISTS idx_ohlc_15m_symbol_time ON ohlc_15m (time, symbol)""", + """CREATE INDEX IF NOT EXISTS idx_ohlc_1h_symbol_time ON ohlc_1h (time, symbol)""", + """CREATE INDEX IF NOT EXISTS idx_ohlc_d_symbol_time ON ohlc_D (time, symbol)""" + ] + + try: + conn = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname=dbname + ) + try: + with conn.cursor() as cursor: + for i, command in enumerate(commands): + try: + cursor.execute(command) + self.logger.debug(f"Executed command {i+1}/{len(commands)}") + except psycopg2.Error as e: + # Skip hypertable creation if table already exists as hypertable + if "already a hypertable" in str(e): + self.logger.info(f"Table already exists as hypertable, skipping: {e}") + continue + else: + self.logger.error(f"Error executing command {i+1}: {e}") + self.logger.error(f"Command was: {command}") + raise + conn.commit() + self.logger.info("Created tables and hypertables successfully") + finally: + conn.close() + + except psycopg2.Error as e: + self.logger.error(f"PostgreSQL error creating tables: {e}") + raise + except Exception as e: + self.logger.error(f"Error creating tables: {e}") + raise + + def test_connection(self): + """Test database connection""" + try: + conn = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname='postgres' # Test with default database first + ) + conn.close() + self.logger.info("Database connection test successful") + return True + except psycopg2.Error as e: + self.logger.error(f"Database connection test failed: {e}") + return False + + def initialize_database(self, dbname): + """Main initialization method""" + # Test connection first + if not self.test_connection(): + raise RuntimeError("Cannot connect to PostgreSQL server. 
Check your connection parameters.") + + if not self._database_exists(dbname): + self.logger.info(f"Database {dbname} not found, creating...") + if not self._create_database(dbname): + raise RuntimeError("Failed to create database") + else: + self.logger.info(f"Database {dbname} already exists") + + self._create_tables(dbname) + + # Return an application connection + try: + app_conn = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname=dbname + ) + self.logger.info("Database connection established successfully") + return app_conn + + except psycopg2.Error as e: + self.logger.error(f"Database connection failed: {e}") + raise + except Exception as e: + self.logger.error(f"Database connection failed: {e}") + raise + +# Integration with your existing code +class MarketDataProcessor: + def __init__(self): + # Initialize TimescaleDBManager and connect to the database + self.db_manager = TimescaleDBManager() + self.db_conn = self.db_manager.initialize_database('openalgo') + self.db_live_conn = self.db_manager.initialize_database('openalgo_live') + self.logger = logging.getLogger(f"MarketDataProcessor") + self.interrupt_flag = False # Add interrupt flag + + self.consumer = KafkaConsumer( + 'tick_data', + bootstrap_servers='localhost:9092', + group_id='tick-processor', + auto_offset_reset='earliest' + #key_deserializer=lambda k: k.decode('utf-8') if k else None, + #value_deserializer=lambda v: json.loads(v.decode('utf-8')) + ) + + self.logger.info("Starting consumer with configuration:") + self.logger.info(f"Group ID: {self.consumer.config['group_id']}") + self.logger.info(f"Brokers: {self.consumer.config['bootstrap_servers']}") + + self.aggregation_lock = Lock() + + self.executor = ThreadPoolExecutor(max_workers=8) + + # Initialize aggregation buffers + self.reset_aggregation_buffers() + + # volume tracking + # Initialize volume tracking for all timeframes + self.last_period_volume = { + '1m': {}, + '5m': {}, + '15m': {} + } + + # self.client = api( + # api_key="8009e08498f085ff1a3e7da718c5f4b585eaf9c2b7ce0c72740ab2b5d283d36c", # Replace with your API key + # host="http://127.0.0.1:5000" + # ) + + # # Initialize live trading engine + # self.trading_engine = LiveTradingEngine( + # api_key=API_CONFIG['api_key'], + # host=API_CONFIG['host'], + # db_config=DB_CONFIG, + # symbols=symbols + # ) + + def clean_database(self): + """Clear all records from all tables in the database""" + try: + self.logger.info("Cleaning database tables...") + tables = ['ticks', 'ohlc_1m', 'ohlc_5m', 'ohlc_15m', 'ohlc_1h', 'ohlc_D'] # Add all your table names here + + with self.db_live_conn.cursor() as cursor: + # Disable triggers temporarily to avoid hypertable constraints + cursor.execute("SET session_replication_role = 'replica';") + + for table in tables: + try: + cursor.execute(f"TRUNCATE TABLE {table} CASCADE;") + self.logger.info(f"Cleared table: {table}") + except Exception as e: + self.logger.error(f"Error clearing table {table}: {e}") + self.db_live_conn.rollback() + continue + + # Re-enable triggers + cursor.execute("SET session_replication_role = 'origin';") + self.db_live_conn.commit() + + self.logger.info("Database cleaning completed successfully") + return True + + except Exception as e: + self.logger.error(f"Database cleaning failed: {e}") + self.db_live_conn.rollback() + return False + + + def insert_historical_data(self, df, symbol, interval, mode): + """ + Insert historical data into the appropriate database table + + Args: + df (pd.DataFrame): 
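+        # Illustrative sketch: the KafkaConsumer above is built with its key/value
+        # deserializers commented out, so messages arrive as raw bytes. If the producer
+        # publishes JSON, decoding can live in the consumer itself; a sketch assuming
+        # kafka-python and the same 'tick_data' topic:
+        def _make_json_consumer_sketch(bootstrap_servers='localhost:9092'):
+            return KafkaConsumer(
+                'tick_data',
+                bootstrap_servers=bootstrap_servers,
+                group_id='tick-processor',
+                auto_offset_reset='earliest',
+                key_deserializer=lambda k: k.decode('utf-8') if k else None,
+                value_deserializer=lambda v: json.loads(v.decode('utf-8')),
+            )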
DataFrame containing historical data + symbol (str): Stock symbol (e.g., 'RELIANCE') + interval (str): Time interval ('1m', '5m', '15m', '1d') + """ + try: + if df.empty: + self.logger.warning(f"No data to insert for {symbol} {interval}") + return False + + # Reset index to make timestamp a column + df = df.reset_index() + + # Rename columns to match database schema + df = df.rename(columns={ + 'timestamp': 'time', + 'open': 'open', + 'high': 'high', + 'low': 'low', + 'close': 'close', + 'volume': 'volume' + }) + + # Handle timezone conversion differently for intraday vs daily data + df['time'] = pd.to_datetime(df['time']) + if interval == 'D': + # Set to market open time (09:15:00 IST) for each date + df['time'] = df['time'].dt.tz_localize(None) # Remove any timezone + df['time'] = df['time'] + pd.Timedelta(hours=9, minutes=15) + df['time'] = df['time'].dt.tz_localize('Asia/Kolkata') + else: + if df['time'].dt.tz is None: + df['time'] = df['time'].dt.tz_localize('Asia/Kolkata') + else: + df['time'] = df['time'].dt.tz_convert('Asia/Kolkata') + + # Convert to UTC for database storage + df['time'] = df['time'].dt.tz_convert('UTC') + + # Add symbol column + df['symbol'] = symbol + + # Select and order the columns we need (excluding 'oi' which we don't store) + required_columns = ['time', 'symbol', 'open', 'high', 'low', 'close', 'volume'] + df = df[required_columns] + + # Convert numeric columns to appropriate types + numeric_cols = ['open', 'high', 'low', 'close'] + df[numeric_cols] = df[numeric_cols].astype(float) + df['volume'] = df['volume'].astype(int) + + # Determine the target table based on interval + table_name = f'ohlc_{interval.lower()}' + + # Convert DataFrame to list of tuples + records = [tuple(x) for x in df.to_numpy()] + + # Debug: print first record to verify format + self.logger.debug(f"First record sample: {records[0] if records else 'No records'}") + + if mode == "backtest": + conn = self.db_conn + elif mode == "live": + conn = self.db_live_conn + + with conn.cursor() as cursor: + # Use execute_batch for efficient bulk insertion + execute_batch(cursor, f""" + INSERT INTO {table_name} + (time, symbol, open, high, low, close, volume) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (time, symbol) DO UPDATE SET + open = EXCLUDED.open, + high = EXCLUDED.high, + low = EXCLUDED.low, + close = EXCLUDED.close, + volume = EXCLUDED.volume + """, records) + + conn.commit() + self.logger.info(f"Successfully inserted {len(df)} records for {symbol} ({interval}) into {table_name}") + return True + + except KeyError as e: + self.logger.error(f"Missing required column in data for {symbol} {interval}: {e}") + self.logger.error(f"Available columns: {df.columns.tolist()}") + return False + except Exception as e: + self.logger.error(f"Error inserting historical data for {symbol} {interval}: {e}") + self.logger.error(traceback.format_exc()) + conn.rollback() + return False + + + def reset_aggregation_buffers(self): + """Initialize/reset aggregation buffers""" + with self.aggregation_lock: + self.tick_buffer = { + '1m': {}, + '5m': {}, + '15m': {} + } + now = datetime.now(pytz.utc) + self.last_agg_time = { + '1m': self.floor_to_interval(now, 1), + '5m': self.floor_to_interval(now, 5), + '15m': self.floor_to_interval(now, 15) + } + self.aggregation_state = { + '1m': {}, + '5m': {}, + '15m': {} + } + + # Reset volume tracking + self.last_period_volume = { + '1m': {}, + '5m': {}, + '15m': {} + } + + def group_missing_dates(self, missing_dates): + """ + Groups missing dates into continuous 
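+        # Illustrative note on the grouping key used by group_missing_dates():
+        # enumerate() pairs each date with its position, and position - date.toordinal()
+        # stays constant exactly while dates are consecutive, so each unbroken run of
+        # missing dates lands in one group. A compact standalone equivalent:
+        def _group_consecutive_dates_sketch(dates):
+            ranges = []
+            for _, run in groupby(enumerate(sorted(dates)), lambda x: x[0] - x[1].toordinal()):
+                run = [d for _, d in run]
+                ranges.append((run[0], run[-1]))
+            return ranges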
ranges. + + Example: + [2025-05-01, 2025-05-02, 2025-05-03, 2025-05-05] + β†’ [(2025-05-01, 2025-05-03), (2025-05-05, 2025-05-05)] + """ + sorted_dates = sorted(missing_dates) + ranges = [] + for _, g in groupby(enumerate(sorted_dates), lambda x: x[0] - x[1].toordinal()): + group = list(map(itemgetter(1), g)) + ranges.append((group[0], group[-1])) + return ranges + + def chunk_dates(self, start_date, end_date, chunk_size_days): + current = start_date + while current <= end_date: + next_chunk = min(current + timedelta(days=chunk_size_days - 1), end_date) + yield current, next_chunk + current = next_chunk + timedelta(days=1) + + def get_existing_dates(self, symbol, interval): + table_name = f"ohlc_d" + query = f""" + SELECT DISTINCT DATE(time AT TIME ZONE 'Asia/Kolkata') as trade_date + FROM {table_name} + WHERE symbol = %s + ORDER BY trade_date; + """ + try: + with self.db_conn.cursor() as cursor: + cursor.execute(query, (symbol,)) + rows = cursor.fetchall() + return set(row[0] for row in rows) + except Exception as e: + self.logger.error(f"Error fetching existing dates: {e}") + return set() + + + def fetch_missing_data(self, symbol, interval, client, start_date, end_date, mode): + try: + existing_dates = self.get_existing_dates(symbol, interval) + all_dates = pd.date_range(start=start_date, end=end_date, freq='D').date + missing_dates = sorted(set(all_dates) - set(existing_dates)) + missing_ranges = self.group_missing_dates(missing_dates) + self.logger.info(f"Missing dates for {symbol} {interval}: {missing_ranges}") + + for range_start, range_end in missing_ranges: + condition_1 = range_start.weekday() in [5, 6] and range_end.weekday() in [5, 6] and (range_end - range_start).days <= 2 + condition_2 = range_start.weekday() in [0, 1, 2, 3, 4] and range_end.weekday() in [0, 1, 2, 3, 4] and (range_end - range_start).days == 0 + if (condition_1 or condition_2): # Skip weekends + continue + + # Add retry logic with exponential backoff + max_retries = 3 + for attempt in range(max_retries): + try: + df = client.history( + symbol=symbol, + exchange='NSE_INDEX' if symbol.startswith('NIFTY') else 'NSE', + interval=interval, + start_date=range_start.strftime('%Y-%m-%d'), + end_date=range_end.strftime('%Y-%m-%d') + ) + + # Check if df is dictionary (error response) + if isinstance(df, dict): + if 'timeout' in str(df.get('message', '')).lower(): + if attempt < max_retries - 1: + wait_time = (2 ** attempt) + random.uniform(0, 1) # Exponential backoff + self.logger.warning(f"[{symbol}] ⏳ Timeout on attempt {attempt + 1}, retrying in {wait_time:.1f}s...") + time.sleep(wait_time) + continue + self.logger.warning(f"[{symbol}] ⚠️ API Response error! No data on {range_start}") + self.logger.info(f"API Response: {df}") + break # Exit retry loop + + # Success - process the dataframe + if hasattr(df, 'empty') and not df.empty: + self.insert_historical_data(df, symbol, interval, mode) + else: + self.logger.warning(f"[{symbol}] ⚠️ Empty Dataframe! 
No data on {range_start}") + break # Exit retry loop on success + + except Exception as retry_e: + if attempt < max_retries - 1: + wait_time = (2 ** attempt) + random.uniform(0, 1) + self.logger.warning(f"[{symbol}] ⏳ Error on attempt {attempt + 1}: {retry_e}, retrying in {wait_time:.1f}s...") + time.sleep(wait_time) + else: + self.logger.error(f"[{symbol}] ❌ Failed after {max_retries} attempts: {retry_e}") + + # Add delay between requests to reduce server load + time.sleep(random.uniform(0.1, 0.3)) + + except Exception as e: + self.logger.error(f"[{symbol}] ❌ Error during fetch: {e}") + + + def fetch_historical_data(self, symbol, interval, client, start_date, end_date, mode): + try: + # Check for interrupt flag + if self.interrupt_flag: + self.logger.info(f"[{symbol}] πŸ›‘ Task interrupted before starting") + return + + # Add retry logic with exponential backoff + max_retries = 3 + for attempt in range(max_retries): + try: + # Check for interrupt before each retry + if self.interrupt_flag: + self.logger.info(f"[{symbol}] πŸ›‘ Task interrupted during retry {attempt + 1}") + return + df = client.history( + symbol=symbol, + exchange='NSE_INDEX' if symbol.startswith('NIFTY') else 'NSE', + interval=interval, + start_date=start_date, + end_date=end_date + ) + + # Check if df is dictionary (error response) + if isinstance(df, dict): + if 'timeout' in str(df.get('message', '')).lower(): + if attempt < max_retries - 1: + wait_time = (2 ** attempt) + random.uniform(0, 1) # Exponential backoff + self.logger.warning(f"[{symbol}] ⏳ Timeout on attempt {attempt + 1}, retrying in {wait_time:.1f}s...") + time.sleep(wait_time) + continue + self.logger.warning(f"[{symbol}] ⚠️ API Response error! No data on {start_date}") + self.logger.info(f"API Response: {df}") + return # Exit function + + # Success - process the dataframe + if hasattr(df, 'empty') and not df.empty: + self.insert_historical_data(df, symbol, interval, mode) + else: + self.logger.warning(f"[{symbol}] ⚠️ Empty Dataframe! 
No data on {start_date}") + return # Exit function on success + + except Exception as retry_e: + if attempt < max_retries - 1: + wait_time = (2 ** attempt) + random.uniform(0, 1) + self.logger.warning(f"[{symbol}] ⏳ Error on attempt {attempt + 1}: {retry_e}, retrying in {wait_time:.1f}s...") + time.sleep(wait_time) + else: + self.logger.error(f"[{symbol}] ❌ Failed after {max_retries} attempts: {retry_e}") + + # Add delay between requests to reduce server load + time.sleep(random.uniform(0.1, 0.3)) + + except Exception as e: + self.logger.error(f"[{symbol}] ❌ Error during fetch: {e}") + + def process_symbol_interval(self, symbol, interval, client, start_date, end_date, mode): + """Process a single symbol-interval pair""" + try: + if interval == "5m" or interval == "1m": + # Chunk the dates into smaller ranges to avoid timeout + s_d = datetime.strptime(start_date, "%Y-%m-%d").date() + e_d = datetime.strptime(end_date, "%Y-%m-%d").date() + #self.logger.info(f"Fetching data for {symbol} with interval {interval} from {s_d} to {e_d}") + + for chunk_start, chunk_end in self.chunk_dates(start_date=s_d, end_date=e_d, chunk_size_days=10): + self.fetch_historical_data( + symbol, + interval, + client, + chunk_start.strftime("%Y-%m-%d"), + chunk_end.strftime("%Y-%m-%d"), + mode + ) + else: + self.fetch_historical_data(symbol, interval, client, start_date, end_date, mode) + except Exception as e: + self.logger.error(f"Error processing {symbol} {interval}: {str(e)}") + + + def process_messages(self): + """Main processing loop""" + + self.consumer.subscribe(['tick_data']) + self.logger.info("Started listening messages...") + + try: + while True: + raw_msg = self.consumer.poll(1000.0) + self.logger.info(f"\n\n\n\nReceived messages: {raw_msg}") + + if raw_msg is None: + self.logger.info("No messages received during timeout period ----------->") + continue + + for topic_partition, messages in raw_msg.items(): + for message in messages: + try: + # Extract key and value + key = message.key.decode('utf-8') # 'NSE_RELIANCE_LTP' + value = json.loads(message.value.decode('utf-8')) + + self.logger.info(f"Processing {key}: {value['symbol']}@{value['close']}") + + # Process the message + self.process_single_message(key, value) + + except Exception as e: + self.logger.error(f"Error processing message: {e}") + + except KeyboardInterrupt: + self.logger.info("Kafka Consumer Shutting down...") + finally: + self.shutdown() + + def _handle_kafka_error(self, error): + """Handle Kafka protocol errors""" + error_codes = { + KafkaError._PARTITION_EOF: "End of partition", + KafkaError.UNKNOWN_TOPIC_OR_PART: "Topic/partition does not exist", + KafkaError.NOT_COORDINATOR_FOR_GROUP: "Coordinator changed", + KafkaError.ILLEGAL_GENERATION: "Consumer group rebalanced", + KafkaError.UNKNOWN_MEMBER_ID: "Member ID expired" + } + + if error.code() in error_codes: + self.logger.warning(error_codes[error.code()]) + else: + self.logger.error(f"Kafka error [{error.code()}]: {error.str()}") + + + def process_single_message(self, key, value): + """Process extracted tick data""" + try: + # Extract components from key + components = key.split('_') + exchange = components[0] # 'NSE' + symbol = components[1] # 'RELIANCE' + data_type = components[2] # 'LTP' or 'QUOTE' + + # Convert timestamp (handling milliseconds since epoch) + timestamp = value['timestamp'] + if not isinstance(timestamp, (int, float)): + raise ValueError(f"Invalid timestamp type: {type(timestamp)}") + + # Convert to proper datetime object + # Ensure milliseconds (not seconds or 
microseconds) + if timestamp < 1e12: # Likely in seconds + timestamp *= 1000 + elif timestamp > 1e13: # Likely in microseconds + timestamp /= 1000 + + dt = datetime.fromtimestamp(timestamp / 1000, tz=pytz.UTC) + + # Validate date range + if dt.year < 2020 or dt.year > 2030: + raise ValueError(f"Implausible date {dt} from timestamp {timestamp}") + + # Prepare database record + record = { + 'time': dt, # Convert ms to seconds + 'symbol': symbol, + 'open': float(value['ltp']), + 'high': float(value['ltp']), + 'low': float(value['ltp']), + 'close': float(value['ltp']), + 'volume': int(value['volume']) + } + + #self.logger.info(f"Record---------> {record}") + + # Store in TimescaleDB + self.store_tick(record) + + # Add to aggregation buffers + self.buffer_tick(record) + + # Check for aggregation opportunities + self.check_aggregation(record['time']) + + except Exception as e: + self.logger.error(f"Tick processing failed: {e}") + self.logger.debug(traceback.format_exc()) + + + def store_tick(self, record): + """Store raw tick in database""" + try: + with self.db_live_conn.cursor() as cursor: + cursor.execute(""" + INSERT INTO ticks (time, symbol, open, high, low, close, volume) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (time, symbol) DO UPDATE SET + open = EXCLUDED.open, + high = EXCLUDED.high, + low = EXCLUDED.low, + close = EXCLUDED.close, + volume = EXCLUDED.volume + """, (record['time'], record['symbol'], record['open'], record['high'], record['low'], record['close'], record['volume'])) + self.db_live_conn.commit() + except Exception as e: + logger.error(f"Error storing tick: {e}") + self.db_live_conn.rollback() + + def buffer_tick(self, record): + """Add tick to aggregation buffers""" + with self.aggregation_lock: + for timeframe in ['1m', '5m', '15m']: + minutes = int(timeframe[:-1]) + symbol = record['symbol'] + aligned_time = self.floor_to_interval(record['time'], minutes) + + if symbol not in self.tick_buffer[timeframe]: + self.tick_buffer[timeframe][symbol] = {} + + # Initialize this specific minute bucket + if aligned_time not in self.tick_buffer[timeframe][symbol]: + self.tick_buffer[timeframe][symbol][aligned_time] = { + 'opens': [], + 'highs': [], + 'lows': [], + 'closes': [], + 'volumes': [], + 'first_tick': None # Track the first tick separately + } + + bucket = self.tick_buffer[timeframe][symbol][aligned_time] + + # For the first tick in this interval, store it separately + if bucket['first_tick'] is None: + bucket['first_tick'] = record + + bucket['opens'].append(record['open']) + bucket['highs'].append(record['high']) + bucket['lows'].append(record['low']) + bucket['closes'].append(record['close']) + bucket['volumes'].append(record['volume']) + + def check_aggregation(self, current_time): + """Check if aggregation should occur for any timeframe""" + timeframes = ['1m', '5m', '15m'] + + for timeframe in timeframes: + agg_interval = timedelta(minutes=int(timeframe[:-1])) + last_agg = self.last_agg_time[timeframe] + + self.logger.info(f"{timeframe}: current_time={current_time}, last_agg={last_agg}, interval={agg_interval}") + + if current_time - last_agg >= agg_interval: + if self.aggregate_data(timeframe, current_time): + self.last_agg_time[timeframe] = self.floor_to_interval(current_time, int(timeframe[:-1])) + + + def floor_to_interval(self, dt, minutes=1): + """Floor a datetime to the start of its minute/5m/15m interval""" + discard = timedelta( + minutes=dt.minute % minutes, + seconds=dt.second, + microseconds=dt.microsecond + ) + return dt - discard + + def 
aggregate_data(self, timeframe, agg_time):
+        """Aggregate buffered ticks into OHLC candles for the given timeframe."""
+        with self.aggregation_lock:
+            symbol_buckets = self.tick_buffer[timeframe]
+            if not symbol_buckets:
+                return False
+
+            aggregated = []
+            table_name = f"ohlc_{timeframe}"
+
+            for symbol, buckets in symbol_buckets.items():
+                for bucket_start, data in list(buckets.items()):
+                    if bucket_start >= self.last_agg_time[timeframe] + timedelta(minutes=int(timeframe[:-1])):
+                        # Don't process future buckets
+                        continue
+
+                    if not data['opens']:
+                        continue
+
+                    try:
+                        # Get OHLC values (open comes from the first tick of the interval)
+                        if data['first_tick'] is not None:
+                            open_ = data['first_tick']['open']
+                        else:
+                            open_ = data['opens'][0]
+
+                        high = max(data['highs'])
+                        low = min(data['lows'])
+                        close = data['closes'][-1]
+
+                        # Calculate volume correctly for cumulative data
+                        current_last_volume = data['volumes'][-1]
+                        previous_last_volume = self.last_period_volume[timeframe].get(symbol, current_last_volume)
+                        volume = max(0, current_last_volume - previous_last_volume)
+
+                        # Store the current last volume for next period
+                        self.last_period_volume[timeframe][symbol] = current_last_volume
+
+                        candle = {
+                            'time': bucket_start,
+                            'symbol': symbol,
+                            'open': open_,
+                            'high': high,
+                            'low': low,
+                            'close': close,
+                            'volume': volume
+                        }
+
+                        aggregated.append(candle)
+
+                        # Notify trading engine of new candle
+                        # self.trading_engine.on_new_candle(symbol, timeframe, candle)
+
+                        # Remove this bucket to avoid re-aggregation
+                        del self.tick_buffer[timeframe][symbol][bucket_start]
+
+                    except Exception as e:
+                        self.logger.error(f"Error aggregating {symbol} for {timeframe}: {e}")
+                        continue
+
+            if aggregated:
+                try:
+                    with self.db_live_conn.cursor() as cursor:
+                        execute_batch(cursor, f"""
+                            INSERT INTO {table_name}
+                            (time, symbol, open, high, low, close, volume)
+                            VALUES (%s, %s, %s, %s, %s, %s, %s)
+                            ON CONFLICT (time, symbol) DO UPDATE SET
+                                open = EXCLUDED.open,
+                                high = EXCLUDED.high,
+                                low = EXCLUDED.low,
+                                close = EXCLUDED.close,
+                                volume = EXCLUDED.volume
+                        """, [(c['time'], c['symbol'], c['open'], c['high'], c['low'], c['close'], c['volume']) for c in aggregated])
+                    self.db_live_conn.commit()
+                    self.logger.info(f"Aggregated {len(aggregated)} candles to {table_name}")
+                    return True
+                except Exception as e:
+                    self.logger.error(f"Error aggregating {timeframe} data: {e}")
+                    self.db_live_conn.rollback()
+                    return False
+            return False
+
+    def shutdown(self):
+        """Clean shutdown"""
+        logger.info("Shutting down processors")
+        # The live trading engine is optional (its initialization is currently
+        # commented out in __init__), so only stop it if it was created.
+        if getattr(self, 'trading_engine', None):
+            self.trading_engine.stop()
+        self.executor.shutdown(wait=True)
+        self.consumer.close()
+        self.db_conn.close()
+        self.db_live_conn.close()
+        logger.info("Clean shutdown complete")
+
+if __name__ == "__main__":
+    # Global variables for signal handling
+    processor = None
+    interrupt_requested = False
+
+    def signal_handler(signum, frame):
+        """Handle Ctrl+C interrupt gracefully"""
+        global interrupt_requested, processor
+        if interrupt_requested:
+            print("\n🚨 Force quit! Terminating immediately...")
+            sys.exit(1)
+
+        interrupt_requested = True
+        print("\n🛑 Interrupt signal received! Stopping gracefully...")
+        print("   Press Ctrl+C again to force quit.")
+
+        if processor:
+            processor.interrupt_flag = True
+
+    # Set up signal handler
+    signal.signal(signal.SIGINT, signal_handler)
+
+    client = api(
+        api_key="8009e08498f085ff1a3e7da718c5f4b585eaf9c2b7ce0c72740ab2b5d283d36c",  # Replace with your API key
+        host="http://127.0.0.1:5000"
+    )
+    # Start the timer
+    start_time = time.time()
+
+    # Argument parsing
+    parser = argparse.ArgumentParser(description='Market Data Processor')
+    parser.add_argument('--mode', type=str, choices=['live', 'backtest'], required=True,
+                        help='Run mode: "live" for live processing, "backtest" for backtesting')
+    parser.add_argument('--from_date', type=str,
+                        help='Start date for backtest (DD-MM-YYYY format)')
+    parser.add_argument('--to_date', type=str,
+                        help='End date for backtest (DD-MM-YYYY format)')
+    parser.add_argument('--backtest_folder', type=str,
+                        help='Folder to store backtest data')
+    parser.add_argument('--live_folder', type=str,
+                        help='Base output logs directory for live data')
+    args = parser.parse_args()
+
+    # Validate arguments for backtest mode
+    if args.mode == 'backtest':
+        if not args.from_date or not args.to_date or not args.backtest_folder:
+            parser.error("--from_date, --to_date and --backtest_folder are required in backtest mode")
+
+        try:
+            from_date = datetime.strptime(args.from_date, '%d-%m-%Y').date()
+            to_date = datetime.strptime(args.to_date, '%d-%m-%Y').date()
+
+            if from_date > to_date:
+                parser.error("--from_date cannot be after --to_date")
+
+        except ValueError as e:
+            parser.error(f"Invalid date format. Please use DD-MM-YYYY. Error: {e}")
+
+    # Validate arguments for live mode
+    if args.mode == 'live':
+        if not args.from_date or not args.live_folder:
+            parser.error("--from_date and --live_folder are required in live mode")
+
+        try:
+            from_date = datetime.strptime(args.from_date, '%d-%m-%Y').date()
+
+        except ValueError as e:
+            parser.error(f"Invalid date format. Please use DD-MM-YYYY. 
Error: {e}") + + # Initialize the processor (update global variable for signal handler) + processor = MarketDataProcessor() + try: + if args.mode == 'live': + logger.info(f"Running in live mode for the date {args.from_date}") + + # Fetch the last 20 days historical data(1 min, 5 min, 15min, 1hr, D) and insert in the DB + start_date = (from_date - timedelta(days=20)).strftime("%Y-%m-%d") + end_date = from_date.strftime("%Y-%m-%d") + + # Cleaning the today's live folder + live_output_dir = args.live_folder + output_dir = os.path.join(live_output_dir) + os.makedirs(output_dir, exist_ok=True) + for filename in os.listdir(output_dir): + if filename.endswith(".csv"): + os.remove(os.path.join(output_dir, filename)) + + # Clean the database at the start of the intraday trading session(9:00 AM IST) + if datetime.now().hour == 9 and datetime.now().minute == 0: + processor.clean_database() + + # Import symbol list from CSV file + symbol_list = pd.read_csv('symbol_list.csv') + symbol_list = symbol_list['Symbol'].tolist() + + # Fetch historical data for each symbol + intervals = ["D", "15m", "5m", "1m"] + + # Create all combinations of (symbol, interval) + symbol_interval_pairs = list(product(symbol_list, intervals)) + + # Add "NIFTY" "15m" in the symbol_interval_pairs + symbol_interval_pairs.append(("NIFTY", "15m")) + + # Fetch historical data of each symbol for the past 20 days + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [] + try: + for symbol, interval in symbol_interval_pairs: + time.sleep(1.5) # Increased delay to reduce server pressure + futures.append( + executor.submit( + processor.process_symbol_interval, + symbol, + interval, + client, + start_date, + end_date, + "live" + ) + ) + + # Wait for all tasks to complete with proper interrupt handling + completed = 0 + total = len(futures) + + for future in futures: + try: + future.result(timeout=30) # 30 second timeout per task + completed += 1 + if completed % 10 == 0: # Progress update every 10 tasks + logger.info(f"πŸ“Š Progress: {completed}/{total} tasks completed") + except TimeoutError: + logger.warning(f"⏰ Task timed out, continuing with next task") + future.cancel() + except Exception as e: + logger.error(f"❌ Task failed: {e}") + + logger.info(f"βœ… All {total} data fetching tasks completed") + + except KeyboardInterrupt: + logger.warning("\nπŸ›‘ KeyboardInterrupt received! 
Stopping all tasks...") + + # Set interrupt flag to stop running tasks + processor.interrupt_flag = True + + # Cancel all pending futures + cancelled_count = 0 + for future in futures: + if future.cancel(): + cancelled_count += 1 + + logger.info(f"πŸ“ Cancelled {cancelled_count} pending tasks") + + # Force shutdown the executor + logger.info("πŸ”„ Shutting down executor...") + executor.shutdown(wait=False) + + # Give running tasks a moment to cleanup and check interrupt flag + logger.info("⏳ Waiting for running tasks to cleanup...") + time.sleep(3) + + logger.info("πŸ›‘ Data fetching interrupted by user") + raise # Re-raise to exit the program + + # Start the Trading Engine + processor.trading_engine.start() + + # Process the real-time data + processor.process_messages() + + elif args.mode == 'backtest': + logger.info(f"Running in backtest mode from {args.from_date} to {args.to_date}") + + end_date = to_date.strftime("%Y-%m-%d") + start_date = (from_date - timedelta(days=20)).strftime("%Y-%m-%d") + + # Cleaning the backtest results folder + base_output_dir = args.backtest_folder + output_dir = os.path.join(base_output_dir) + os.makedirs(output_dir, exist_ok=True) + for filename in os.listdir(output_dir): + if filename.endswith(".csv"): + os.remove(os.path.join(output_dir, filename)) + + # Copy source files for traceability + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + source_files = [ + ('backtest_engine.py', os.path.join(current_dir, 'backtest_engine.py')), + ('timescaledb.py', os.path.join(current_dir, 'timescaledb.py')) + ] + + logger.info(f"πŸ“ Copying source files to {output_dir} for traceability...") + + for filename, source_path in source_files: + if os.path.exists(source_path): + # Include timestamp in filename for versioning + base_name = filename.replace('.py', '') + dest_filename = f"source_{base_name}_{timestamp}.py" + dest_path = os.path.join(output_dir, dest_filename) + + shutil.copy2(source_path, dest_path) + logger.info(f"βœ… Copied {filename} β†’ {dest_filename}") + else: + logger.warning(f"⚠️ Source file not found: {source_path}") + + # Also create a backtest info file with run details + info_file = os.path.join(output_dir, f"backtest_info_{timestamp}.txt") + with open(info_file, 'w') as f: + f.write(f"Backtest Run Information\n") + f.write(f"========================\n") + f.write(f"Run Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"From Date: {args.from_date}\n") + f.write(f"To Date: {args.to_date}\n") + f.write(f"Output Directory: {output_dir}\n") + f.write(f"Symbol Count: {len(pd.read_csv('symbol_list_backtest.csv'))}\n") + f.write(f"Source Files: backtest_engine_{timestamp}.py, timescaledb_{timestamp}.py\n") + + logger.info(f"πŸ“ Created backtest info file: backtest_info_{timestamp}.txt") + + except Exception as e: + logger.error(f"❌ Error copying source files: {e}") + # Don't fail the backtest if file copying fails + + # Import symbol list from CSV file + symbol_list = pd.read_csv('symbol_list_backtest.csv') + symbol_list = symbol_list['Symbol'].tolist() + + # Fetch historical data for each symbol + intervals = ["D", "15m", "5m", "1m"] + + # Create all combinations of (symbol, interval) + symbol_interval_pairs = list(product(symbol_list, intervals)) + + # Add "NIFTY" "1h" in the symbol_interval_pairs + symbol_interval_pairs.append(("NIFTY", "1h")) + symbol_interval_pairs.append(("NIFTY", "15m")) + + + # UNCOMMENT THIS BLOCK FOR FETCHING HISTORICAL DATA FOR ALL 
INTERVALS + # with ThreadPoolExecutor(max_workers=2) as executor: # Reduced to prevent server overload + # futures = [] + + # try: + # for symbol, interval in symbol_interval_pairs: + # time.sleep(1.5) # Increased delay to reduce server pressure + # futures.append( + # executor.submit( + # processor.process_symbol_interval, + # symbol, + # interval, + # client, + # start_date, + # end_date, + # "backtest" + # ) + # ) + + # # Wait for all tasks to complete with proper interrupt handling + # completed = 0 + # total = len(futures) + + # for future in futures: + # try: + # future.result(timeout=30) # 30 second timeout per task + # completed += 1 + # if completed % 10 == 0: # Progress update every 10 tasks + # logger.info(f"πŸ“Š Progress: {completed}/{total} tasks completed") + # except TimeoutError: + # logger.warning(f"⏰ Task timed out, continuing with next task") + # future.cancel() + # except Exception as e: + # logger.error(f"❌ Task failed: {e}") + + # logger.info(f"βœ… All {total} data fetching tasks completed") + + # except KeyboardInterrupt: + # logger.warning("\nπŸ›‘ KeyboardInterrupt received! Stopping all tasks...") + + # # Set interrupt flag to stop running tasks + # processor.interrupt_flag = True + + # # Cancel all pending futures + # cancelled_count = 0 + # for future in futures: + # if future.cancel(): + # cancelled_count += 1 + + # logger.info(f"πŸ“ Cancelled {cancelled_count} pending tasks") + + # # Force shutdown the executor + # logger.info("πŸ”„ Shutting down executor...") + # executor.shutdown(wait=False) + + # # Give running tasks a moment to cleanup and check interrupt flag + # logger.info("⏳ Waiting for running tasks to cleanup...") + # time.sleep(3) + + # logger.info("πŸ›‘ Data fetching interrupted by user") + # raise # Re-raise to exit the program + + + # Without threading + # for symbol, interval in symbol_interval_pairs: + # if interval == "5m" or interval == "1m": + # # Chunk the dates into smaller ranges to avoid timeout + # s_d = datetime.strptime(start_date, "%Y-%m-%d").date() + # e_d = datetime.strptime(end_date, "%Y-%m-%d").date() + # logger.info(f"Fetching data for {symbol} with interval {interval} from {s_d} to {e_d}") + + # for chunk_start, chunk_end in processor.chunk_dates(start_date=s_d, end_date=e_d, chunk_size_days=10): + # processor.fetch_historical_data(symbol, interval, client, chunk_start.strftime("%Y-%m-%d") , chunk_end.strftime("%Y-%m-%d")) + # else: + # processor.fetch_historical_data(symbol, interval, client, start_date, end_date) + + + # Process data in simulation mode + def run_backtest_for_symbol(symbol, connection_pool, start_date, end_date, base_output_dir): + # get connection from pool + conn = connection_pool.getconn() + + try: + engine = BacktestEngine( + conn=conn, + symbol=symbol, + start_date=start_date.strftime("%Y-%m-%d"), + end_date=end_date.strftime("%Y-%m-%d") + ) + trades_df = engine.run() + + output_dir = os.path.join(base_output_dir) + os.makedirs(output_dir, exist_ok=True) + + trades_file = os.path.join(output_dir, f"backtest_trades_{symbol}.csv") + summary_file = os.path.join(output_dir, f"summary_{symbol}.csv") + + trades_df.to_csv(trades_file, index=False) + + #if hasattr(engine, 'export_trail_charts'): + # engine.export_trail_charts() + # + logger.info(f"βœ… Backtest completed for {symbol} β†’ Trades: {len(trades_df)} β†’ Saved: backtest_trades_{symbol}.csv") + + finally: + connection_pool.putconn(conn) + + + def aggregate_all_summaries(base_output_dir="backtest_results", output_filename="master_summary.csv"): + """ + 
Aggregate all summary files from the organized folder structure + """ + # Find all summary files recursively + summary_pattern = os.path.join(base_output_dir, "**", "backtest_trades_*.csv") + summary_files = glob.glob(summary_pattern, recursive=True) + + if not summary_files: + logger.warning(f"No summary files found in {base_output_dir}") + return + + logger.info(f"Found {len(summary_files)} summary files to aggregate") + + all_dfs = [] + skipped_files = [] + + for file_path in summary_files: + try: + # Check if file is empty first + if os.path.getsize(file_path) == 0: + logger.warning(f"Skipping empty file: {file_path}") + skipped_files.append(file_path) + continue + + # Try to read the CSV file + df = pd.read_csv(file_path) + + # Check if DataFrame is empty or has no columns + if df.empty: + logger.warning(f"Skipping empty DataFrame from file: {file_path}") + skipped_files.append(file_path) + continue + + if len(df.columns) == 0: + logger.warning(f"Skipping file with no columns: {file_path}") + skipped_files.append(file_path) + continue + + # Add metadata columns to track source + df['source_file'] = os.path.basename(file_path) + df['folder_path'] = os.path.dirname(file_path) + all_dfs.append(df) + + except pd.errors.EmptyDataError: + #logger.warning(f"Skipping empty CSV file: {file_path}") + skipped_files.append(file_path) + continue + except pd.errors.ParserError as e: + logger.error(f"Parser error reading {file_path}: {e}") + skipped_files.append(file_path) + continue + except FileNotFoundError: + logger.error(f"File not found: {file_path}") + skipped_files.append(file_path) + continue + except PermissionError: + logger.error(f"Permission denied reading file: {file_path}") + skipped_files.append(file_path) + continue + except Exception as e: + logger.error(f"Unexpected error reading {file_path}: {e}") + skipped_files.append(file_path) + continue + + # Log summary of processing + processed_files = len(summary_files) - len(skipped_files) + logger.info(f"πŸ“Š Processing summary: {processed_files} files processed, {len(skipped_files)} files skipped") + + if all_dfs: + try: + master_df = pd.concat(all_dfs, ignore_index=True) + + # Convert datetime columns before sorting (fix for .dt accessor error) + if 'entry_time' in master_df.columns: + master_df['entry_time'] = pd.to_datetime(master_df['entry_time'], errors='coerce') + if 'exit_time' in master_df.columns: + master_df['exit_time'] = pd.to_datetime(master_df['exit_time'], errors='coerce') + + # Only sort if entry_time column exists and has valid data + if 'entry_time' in master_df.columns and not master_df['entry_time'].isna().all(): + master_df.sort_values(by='entry_time', inplace=True) + else: + logger.warning("entry_time column missing or contains no valid dates - skipping sort") + + # Save master trades in the base output directory + master_output_path_raw = os.path.join(base_output_dir, "backtest_trades_master_raw.csv") + master_df.to_csv(master_output_path_raw, index=False) + + # Process the trades with constraints + logger.info(f"Master DF columns before constraints: {list(master_df.columns)}") + logger.info(f"Master DF strategy values: {master_df['strategy'].unique() if 'strategy' in master_df.columns else 'No strategy column'}") + + master_df = process_trades_with_constraints(master_df) + + # Save master trades in the base output directory + master_output_path = os.path.join(base_output_dir, "backtest_trades_master.csv") + master_df.to_csv(master_output_path, index=False) + + # STRATEGY SUMMARY + logger.info(f"Creating strategy 
summary for {len(master_df)} trades") + logger.info(f"Available columns: {list(master_df.columns)}") + if 'strategy' in master_df.columns: + logger.info(f"Strategy values: {master_df['strategy'].unique()}") + else: + logger.error("No 'strategy' column found in master_df!") + + strat_summary = master_df.groupby('strategy').agg( + tot_trades=('gross_pnl', 'count'), + proftrades=('net_pnl', lambda x: (x > 0).sum()), + losstrades=('net_pnl', lambda x: (x < 0).sum()), + win_rate=('net_pnl', lambda x: (x > 0).mean() * 100), + gross_pnl=('gross_pnl', 'sum'), + brokerage=('brokerage', 'sum'), + tax_amount=('tax', 'sum'), + net_pnl__=('net_pnl', 'sum'), + avg_pnl__=('net_pnl', 'mean'), + max_dd__=('net_pnl', lambda x: (x.cumsum() - x.cumsum().cummax()).min()), + avg_dd_=('net_pnl', lambda x: + (lambda dd: dd[dd > 0].mean() if (dd > 0).any() else 0)( + x.cumsum() - x.cumsum().cummax() + ) + ) + ).reset_index() + strat_summary = strat_summary.round(2) + + # MONTH SUMMARY + master_df['entry_time'] = pd.to_datetime(master_df['entry_time']) + master_df['month'] = master_df['entry_time'].dt.to_period('M').astype(str) + + month_summary = master_df.groupby('month').agg( + tot_trades=('gross_pnl', 'count'), + proftrades=('net_pnl', lambda x: (x > 0).sum()), + losstrades=('net_pnl', lambda x: (x < 0).sum()), + win_rate=('net_pnl', lambda x: (x > 0).mean() * 100), + gross_pnl=('gross_pnl', 'sum'), + brokerage=('brokerage', 'sum'), + tax_amount=('tax', 'sum'), + net_pnl__=('net_pnl', 'sum'), + avg_pnl__=('net_pnl', 'mean'), + max_dd__=('net_pnl', lambda x: (x.cumsum() - x.cumsum().cummax()).min()), + avg_dd_=('net_pnl', lambda x: + (lambda dd: dd[dd > 0].mean() if (dd > 0).any() else 0)( + x.cumsum() - x.cumsum().cummax() + ) + ) + ).reset_index() + month_summary = month_summary.round(2) + + # MONTH-STRATEGY SUMMARY + month_strat_summary = master_df.groupby(['month', 'strategy']).agg( + tottrades=('gross_pnl', 'count'), + p_trades=('net_pnl', lambda x: (x > 0).sum()), + l_trades=('net_pnl', lambda x: (x < 0).sum()), + win_rate=('net_pnl', lambda x: (x > 0).mean() * 100), + gross_pnl=('gross_pnl', 'sum'), + brokerage=('brokerage', 'sum'), + tax_amt=('tax', 'sum'), + net_pnl_=('net_pnl', 'sum'), + avg_pnl_=('net_pnl', 'mean'), + max_dd__=('net_pnl', lambda x: (x.cumsum() - x.cumsum().cummax()).min()), + avg_dd__=('net_pnl', lambda x: + (lambda dd: dd[dd > 0].mean() if (dd > 0).any() else 0)( + x.cumsum() - x.cumsum().cummax() + ) + ) + ).reset_index() + month_strat_summary = month_strat_summary.round(2) + + # Save strategy summary in the base output directory + master_output_path = os.path.join(base_output_dir, "master_summary_by_strategy.csv") + strat_summary.to_csv(master_output_path, index=False) + + # Save month summary in the base output directory + master_output_path = os.path.join(base_output_dir, "master_summary_by_month.csv") + month_summary.to_csv(master_output_path, index=False) + + # Save month-strategy summary in the base output directory + master_output_path = os.path.join(base_output_dir, "master_summary_by_strategy_month.csv") + month_strat_summary.to_csv(master_output_path, index=False) + + # Remove the backtest_trades_* except backtest_trades_master.csv + for file in os.listdir(base_output_dir): + if file.startswith("backtest_trades_") and file != "backtest_trades_master.csv" and file != "backtest_trades_master_raw.csv": + file_path = os.path.join(base_output_dir, file) + os.remove(file_path) + + print(f"πŸ“ˆ Total symbols processed: {len(master_df)}") + print(f"πŸ“‹ Valid files: 
{len(all_dfs)}, Skipped files: {len(skipped_files)}") + + if not strat_summary.empty: + print_aggregate_totals_1(strat_summary, 'PERFORMANCE STRATEGY_WISE') + + if not month_summary.empty: + print_aggregate_totals_1(month_summary, 'PERFORMANCE MONTH_WISE') + + if not month_strat_summary.empty: + print_aggregate_totals_2(month_strat_summary, "PERFORMANCE MONTH_STRATEGY_WISE") + + except Exception as e: + logger.error(f"Error creating master summary: {e}") + logger.error(f"Number of DataFrames to concatenate: {len(all_dfs)}") + return + + else: + logger.warning("❌ No valid summary data found to aggregate") + print(f"\n⚠️ No valid summary files found. All {len(skipped_files)} files were skipped.") + + # Optionally, create an empty master file with headers if you know the expected structure + try: + # Create empty master file with basic structure + empty_df = pd.DataFrame(columns=['symbol', 'total_trades', 'profitable_trades', 'loss_trades', + 'win_rate', 'gross_pnl', 'max_drawdown', 'source_file', 'folder_path']) + master_output_path = os.path.join(base_output_dir, output_filename) + empty_df.to_csv(master_output_path, index=False) + logger.info(f"πŸ“„ Created empty master summary file: {master_output_path}") + except Exception as e: + logger.error(f"Error creating empty master summary: {e}") + + + def print_aggregate_totals_1(summary_df, title='PERFORMANCE SUMMARY'): + """ + Print Excel-like table with perfect alignment between headers and data rows + """ + if not isinstance(summary_df, pd.DataFrame) or summary_df.empty: + print(f"{Fore.RED}❌ No valid summary data") + return + + try: + # Create display copy + display_df = summary_df.copy() + + # Format numeric columns + def format_currency(x): + if pd.isna(x): return "N/A" + x = float(x) + if abs(x) >= 1000000: return f"β‚Ή{x/1000000:.1f}M" + if abs(x) >= 1000: return f"β‚Ή{x/1000:.1f}K" + return f"β‚Ή{x:.0f}" + + currency_cols = ['gross_pnl', 'tax_amount', 'brokerage', 'net_pnl__', 'avg_pnl__', 'max_dd__', 'avg_dd_'] + for col in currency_cols: + display_df[col] = display_df[col].apply(format_currency) + + # Get terminal width + try: + terminal_width = os.get_terminal_size().columns + except: + terminal_width = 80 + + # Calculate column widths (content + header) + col_widths = {} + for col in display_df.columns: + content_width = max(display_df[col].astype(str).apply(len).max(), len(col)) + col_widths[col] = min(content_width + 2, 20) # Max 20 chars per column + + # Adjust to fit terminal + while sum(col_widths.values()) + len(col_widths) + 1 > terminal_width: + max_col = max(col_widths, key=col_widths.get) + if col_widths[max_col] > 8: # Never go below 8 chars + col_widths[max_col] -= 1 + else: + break # Can't shrink further + + # Build horizontal border + border = '+' + '+'.join(['-' * (col_widths[col]) for col in display_df.columns]) + '+' + + # Print header + print(f"\n{Style.BRIGHT}{Fore.BLUE}πŸ“Š {title}") + print(border) + + # Print column headers + header_cells = [] + for col in display_df.columns: + header = f" {col.upper().replace('_', ' ')}" + header = header.ljust(col_widths[col]-1) + header_cells.append(f"{Style.BRIGHT}{header}{Style.RESET_ALL}") + print('|' + '|'.join(header_cells) + '|') + print(border) + + # Print data rows + for _, row in display_df.iterrows(): + cells = [] + for col in display_df.columns: + cell_content = str(row[col])[:col_widths[col]-2] + if len(str(row[col])) > col_widths[col]-2: + cell_content = cell_content[:-1] + '…' + cells.append(f" {cell_content.ljust(col_widths[col]-1)}") + print('|' + 
'|'.join(cells) + '|') + + # Print footer + print(border) + + # Print summary + if 'net_pnl__' in summary_df.columns: + total_net = summary_df['net_pnl__'].sum() + total_trades = summary_df['tot_trades'].sum() + avg_net = total_net / total_trades + status = (f"{Fore.GREEN}↑PROFIT" if total_net > 0 else + f"{Fore.RED}↓LOSS" if total_net < 0 else + f"{Fore.YELLOW}βž”BREAKEVEN") + print(f"| {status}{Style.RESET_ALL} Net: {format_currency(total_net)} " + f"Avg P/L: {format_currency(avg_net)} " + f"Trades: {total_trades:,} " + f"Win%: {summary_df['proftrades'].sum()/total_trades*100:.1f}%".ljust(len(border)-1) + "|") + print(border + Style.RESET_ALL) + + except Exception as e: + print(f"{Fore.RED}❌ Error displaying table: {e}") + + def print_aggregate_totals_2(summary_df, title='PERFORMANCE SUMMARY'): + """ + Print Excel-like table with perfect alignment between headers and data rows + """ + if not isinstance(summary_df, pd.DataFrame) or summary_df.empty: + print(f"{Fore.RED}❌ No valid summary data") + return + + try: + # Create display copy + display_df = summary_df.copy() + + # Format numeric columns + def format_currency(x): + if pd.isna(x): return "N/A" + x = float(x) + if abs(x) >= 1000000: return f"β‚Ή{x/1000000:.1f}M" + if abs(x) >= 1000: return f"β‚Ή{x/1000:.1f}K" + return f"β‚Ή{x:.0f}" + + currency_cols = ['gross_pnl', 'tax_amt', 'brokerage', 'net_pnl_', 'avg_pnl_', 'max_dd__', 'avg_dd__'] + for col in currency_cols: + display_df[col] = display_df[col].apply(format_currency) + + # Get terminal width + try: + terminal_width = os.get_terminal_size().columns + except: + terminal_width = 80 + + # Calculate column widths (content + header) + col_widths = {} + for col in display_df.columns: + content_width = max(display_df[col].astype(str).apply(len).max(), len(col)) + col_widths[col] = min(content_width + 2, 20) # Max 20 chars per column + + # Adjust to fit terminal + while sum(col_widths.values()) + len(col_widths) + 1 > terminal_width: + max_col = max(col_widths, key=col_widths.get) + if col_widths[max_col] > 8: # Never go below 8 chars + col_widths[max_col] -= 1 + else: + break # Can't shrink further + + # Build horizontal border + border = '+' + '+'.join(['-' * (col_widths[col]) for col in display_df.columns]) + '+' + + # Print header + print(f"\n{Style.BRIGHT}{Fore.BLUE}πŸ“Š {title}") + print(border) + + # Print column headers + header_cells = [] + for col in display_df.columns: + header = f" {col.upper().replace('_', ' ')}" + header = header.ljust(col_widths[col]-1) + header_cells.append(f"{Style.BRIGHT}{header}{Style.RESET_ALL}") + print('|' + '|'.join(header_cells) + '|') + print(border) + + # Print data rows + for _, row in display_df.iterrows(): + cells = [] + for col in display_df.columns: + cell_content = str(row[col])[:col_widths[col]-2] + if len(str(row[col])) > col_widths[col]-2: + cell_content = cell_content[:-1] + '…' + cells.append(f" {cell_content.ljust(col_widths[col]-1)}") + print('|' + '|'.join(cells) + '|') + + # Print footer + print(border) + + # Print summary + if 'net_pnl_' in summary_df.columns: + total_net = summary_df['net_pnl_'].sum() + status = (f"{Fore.GREEN}↑PROFIT" if total_net > 0 else + f"{Fore.RED}↓LOSS" if total_net < 0 else + f"{Fore.YELLOW}βž”BREAKEVEN") + print(f"| {status}{Style.RESET_ALL} Net: {format_currency(total_net)} " + f"Trades: {summary_df['tottrades'].sum():,} " + f"Win%: {summary_df['p_trades'].sum()/summary_df['tottrades'].sum()*100:.1f}%".ljust(len(border)-1) + "|") + print(border + Style.RESET_ALL) + + except Exception as e: + 
print(f"{Fore.RED}❌ Error displaying table: {e}") + + def process_trades_with_constraints(trades_df): + # Sort trades by entry time (datetime conversion already done in aggregate_all_summaries) + trades_df = trades_df.sort_values('entry_time') + + # Initialize tracking variables + active_positions = [] + strategy_counts = defaultdict(int) + daily_strategy_tracker = defaultdict(set) # {date: {strategies_used}} + filtered_trades = [] + + for _, trade in trades_df.iterrows(): + trade_date = trade['entry_time'].date() + strategy = str(trade['strategy']) + symbol = trade['symbol'] + + # Check if we've already used this strategy today + if strategy in daily_strategy_tracker.get(trade_date, set()): + continue + + # Check if we have capacity for new positions (max 3) + if len(active_positions) >= 3: + # Find the earliest exit time among active positions + earliest_exit = min(pos['exit_time'] for pos in active_positions) + if trade['entry_time'] < earliest_exit: + # Can't take this trade as all 3 positions would still be open + continue + + # For same-time entries, we need to check alphabetical priority + # Get all trades at the same entry time for the same strategy + same_time_trades = trades_df[ + (trades_df['entry_time'] == trade['entry_time']) & + (trades_df['strategy'] == trade['strategy'])] + + if len(same_time_trades) > 1: + # Sort by symbol alphabetically and take the first one + same_time_trades = same_time_trades.sort_values('symbol') + if symbol != same_time_trades.iloc[0]['symbol']: + continue + + # If we get here, the trade passes all constraints + filtered_trades.append(trade) + + # Update tracking + daily_strategy_tracker[trade_date].add(strategy) + + # Add to active positions + active_positions.append({ + 'symbol': symbol, + 'strategy': strategy, + 'entry_time': trade['entry_time'], + 'exit_time': trade['exit_time'] + }) + + # Remove any positions that have exited + active_positions = [pos for pos in active_positions + if pos['exit_time'] > trade['entry_time']] + + # Create new DataFrame with filtered trades + filtered_df = pd.DataFrame(filtered_trades) + + return filtered_df + + # Run backtests in parallel + # Create connection pool once + db_config = { + "user": processor.db_manager.user, + "password": processor.db_manager.password, + "host": processor.db_manager.host, + "port": processor.db_manager.port, + "dbname": processor.db_manager.dbname + } + + connection_pool = psycopg2.pool.ThreadedConnectionPool( + minconn=1, maxconn=8, # Adjust based on your needs + user=db_config['user'], + password=db_config['password'], + host=db_config['host'], + port=db_config['port'], + dbname=db_config['dbname'] + ) + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [] + for symbol in symbol_list: + futures.append(executor.submit(run_backtest_for_symbol, symbol, connection_pool, from_date, to_date, base_output_dir)) + try: + for i, future in enumerate(futures): + try: + future.result() + logger.info(f"βœ… Future {i+1}/{len(futures)} completed successfully") + except Exception as e: + logger.error(f"❌ Future {i+1}/{len(futures)} failed with error: {e}") + logger.error(f"Error type: {type(e).__name__}") + import traceback + logger.error(f"Traceback: {traceback.format_exc()}") + raise # Re-raise the error + except KeyboardInterrupt: + print("Interrupted by user. 
Cancelling all futures.") + for future in futures: + future.cancel() + executor.shutdown(wait=False, cancel_futures=True) + + # Aggregate summaries from the organized folder structure + aggregate_all_summaries(base_output_dir, "master_summary.csv") + + # End the timer + end_time = time.time() + elapsed_time = end_time - start_time + logger.info(f"Elapsed time: {elapsed_time} seconds") + + except Exception as e: + logger.error(f"Fatal error: {e}") + processor.shutdown() diff --git a/websocket_proxy/base_adapter.py b/websocket_proxy/base_adapter.py index ccdcea0d..26e4ff74 100644 --- a/websocket_proxy/base_adapter.py +++ b/websocket_proxy/base_adapter.py @@ -1,11 +1,20 @@ import json import threading import zmq +import os import random import socket -import os from abc import ABC, abstractmethod from utils.logging import get_logger +from datetime import datetime + +# RedPanda/Kafka imports +try: + from kafka import KafkaProducer, KafkaConsumer + from kafka.errors import KafkaError + KAFKA_AVAILABLE = True +except ImportError: + KAFKA_AVAILABLE = False # Initialize logger logger = get_logger(__name__) @@ -67,13 +76,66 @@ def find_free_zmq_port(start_port=5556, max_attempts=50): logger.error("Failed to find an available port after maximum attempts") return None + +def find_free_redpanda_port(start_port=9092, max_attempts=50): + """ + Find an available port starting from start_port for RedPanda broker + + Args: + start_port (int): Port number to start the search from (default 9092 for Kafka) + max_attempts (int): Maximum number of attempts to find a free port + + Returns: + int: Available port number, or None if no port is available + """ + # Create logger here instead of using self.logger because this is a standalone function + logger = get_logger("redpanda_port_finder") + + # First check if any ports in the bound_ports set are actually free now + # This handles cases where the process that had the port died without cleanup + with BaseBrokerWebSocketAdapter._port_lock: + ports_to_remove = [] + for port in list(BaseBrokerWebSocketAdapter._bound_redpanda_ports): + if is_port_available(port): + logger.info(f"RedPanda port {port} in registry is actually free now, removing from bound ports") + ports_to_remove.append(port) + + # Remove ports that are actually available now + for port in ports_to_remove: + BaseBrokerWebSocketAdapter._bound_redpanda_ports.remove(port) + + # Now find a new free port + for _ in range(max_attempts): + # Try a sequential port first, then random if that fails + current_port = start_port + + # Check if this port is available and not in our bound_ports set + if (current_port not in BaseBrokerWebSocketAdapter._bound_redpanda_ports and + is_port_available(current_port)): + return current_port + + # Try a random port between start_port and 65535 + current_port = random.randint(start_port, 65535) + if (current_port not in BaseBrokerWebSocketAdapter._bound_redpanda_ports and + is_port_available(current_port)): + return current_port + + # Increment start_port for next sequential try + start_port = min(start_port + 1, 65000) # Cap at 65000 to stay in safe range + + # If we get here, we couldn't find an available port + logger.error("Failed to find an available RedPanda port after maximum attempts") + return None + + class BaseBrokerWebSocketAdapter(ABC): """ Base class for all broker-specific WebSocket adapters that implements common functionality and defines the interface for broker-specific implementations. 
""" - # Class variable to track bound ports across instances + # Class variables to track bound ports across instances _bound_ports = set() + _bound_redpanda_ports = set() _port_lock = threading.Lock() _shared_context = None _context_lock = threading.Lock() @@ -81,46 +143,169 @@ class BaseBrokerWebSocketAdapter(ABC): def __init__(self): self.logger = get_logger("broker_adapter") self.logger.info("BaseBrokerWebSocketAdapter initializing") + # ZeroMQ publisher setup for internal message distribution + self.context = zmq.Context() + self.socket = self.context.socket(zmq.PUB) + + # ZMQ publishing control + self.zmq_enabled = self._should_enable_zmq() - try: - # Initialize shared ZeroMQ context - self._initialize_shared_context() - - # Create socket and bind to port - self.socket = self._create_socket() + # Find an available port for ZMQ only if enabled + self.logger = get_logger("broker_adapter") + if self.zmq_enabled: self.zmq_port = self._bind_to_available_port() + self.logger.info(f"ZeroMQ socket bound to port {self.zmq_port}") + # Updating used ZMQ_PORT in environment variable. + # We must use os.environ (not os.getenv) for setting environment variables os.environ["ZMQ_PORT"] = str(self.zmq_port) + else: + self.zmq_port = None + self.logger.info("ZeroMQ publishing disabled") + + # RedPanda/Kafka setup + self.redpanda_enabled = self._should_enable_redpanda() + self.kafka_producer = None + self.kafka_consumer = None + self.redpanda_config = None + self.redpanda_topics = {} + + if self.redpanda_enabled: + self._setup_redpanda() + + # Subscription tracking + self.subscriptions = {} + self.connected = False + + def _should_enable_zmq(self): + """ + Check if ZeroMQ should be enabled based on environment variables + """ + # Check environment variable + enable_zmq = os.getenv('ENABLE_ZMQ_PUBLISH', 'true').lower() in ('true', '1', 'yes', 'on') + + if enable_zmq: + self.logger.info("ZeroMQ publishing enabled via environment variable") + #else: + # self.logger.info("ZeroMQ publishing disabled. Set ENABLE_ZMQ_PUBLISH=true to enable") - # Initialize instance variables - self.subscriptions = {} - self.connected = False + return enable_zmq + + def _should_enable_redpanda(self): + """ + Check if RedPanda should be enabled based on environment variables and availability + """ + if not KAFKA_AVAILABLE: + self.logger.warning("Kafka library not available. Install with: pip install kafka-python") + return False - self.logger.info(f"BaseBrokerWebSocketAdapter initialized on port {self.zmq_port}") + # Check environment variable + enable_redpanda = os.getenv('ENABLE_REDPANDA', 'false').lower() in ('true', '1', 'yes', 'on') + + if enable_redpanda: + self.logger.info("RedPanda streaming enabled via environment variable") + else: + self.logger.info("RedPanda streaming disabled. 
Set ENABLE_REDPANDA=true to enable") - except Exception as e: - self.logger.error(f"Error in BaseBrokerWebSocketAdapter init: {e}") - raise - - def _initialize_shared_context(self): + return enable_redpanda + + def _setup_redpanda(self): """ - Initialize shared ZeroMQ context if not already created + Setup RedPanda/Kafka configuration and connections """ - with self._context_lock: - if not BaseBrokerWebSocketAdapter._shared_context: - self.logger.info("Creating shared ZMQ context") - BaseBrokerWebSocketAdapter._shared_context = zmq.Context() - - self.context = BaseBrokerWebSocketAdapter._shared_context - - def _create_socket(self): + try: + # Get RedPanda configuration from environment + self.redpanda_config = { + 'bootstrap_servers': os.getenv('REDPANDA_BROKERS', 'localhost:9092'), + 'topic_prefix': os.getenv('REDPANDA_TOPIC_PREFIX', 'openalgo'), + 'compression_type': os.getenv('REDPANDA_COMPRESSION', 'snappy'), + 'batch_size': int(os.getenv('REDPANDA_BATCH_SIZE', '16384')), + 'buffer_memory': int(os.getenv('REDPANDA_BUFFER_MEMORY', '33554432')), + 'linger_ms': int(os.getenv('REDPANDA_LINGER_MS', '10')), + 'acks': os.getenv('REDPANDA_ACKS', 'all'), + 'retries': int(os.getenv('REDPANDA_RETRIES', '3')), + 'client_id': os.getenv('REDPANDA_CLIENT_ID', f'openalgo-{os.getpid()}'), + 'max_in_flight_requests_per_connection': int(os.getenv('REDPANDA_MAX_IN_FLIGHT_REQUESTS', '1')) + } + + # Define topics for different data types + self.redpanda_topics = { + 'tick_data': f"{self.redpanda_config['topic_prefix']}.tick.raw" + } + + # Initialize Kafka producer + self._init_kafka_producer() + + # Store RedPanda port in environment + redpanda_port = self.redpanda_config['bootstrap_servers'].split(':')[-1] + os.environ["REDPANDA_PORT"] = str(redpanda_port) + + self.logger.info(f"RedPanda configuration initialized: {self.redpanda_config['bootstrap_servers']}") + + + except Exception as e: + self.logger.error(f"Failed to setup RedPanda configuration: {e}") + self.redpanda_enabled = False + + def _init_kafka_producer(self): """ - Create and configure ZeroMQ socket + Initialize Kafka producer for publishing market data """ - with self._context_lock: - socket = self.context.socket(zmq.PUB) - socket.setsockopt(zmq.LINGER, 1000) # 1 second linger - socket.setsockopt(zmq.SNDHWM, 1000) # High water mark - return socket + try: + producer_config = { + 'bootstrap_servers': self.redpanda_config['bootstrap_servers'], + 'client_id': self.redpanda_config['client_id'], + 'value_serializer': lambda v: json.dumps(v).encode('utf-8'), + 'key_serializer': lambda k: str(k).encode('utf-8') if k else None, + 'acks': self.redpanda_config['acks'], + 'retries': self.redpanda_config['retries'], + 'batch_size': self.redpanda_config['batch_size'], + 'buffer_memory': self.redpanda_config['buffer_memory'], + 'linger_ms': self.redpanda_config['linger_ms'], + 'compression_type': self.redpanda_config['compression_type'], + 'max_in_flight_requests_per_connection': self.redpanda_config['max_in_flight_requests_per_connection'], + 'enable_idempotence': True + } + + self.kafka_producer = KafkaProducer(**producer_config) + self.logger.info("Kafka producer initialized successfully") + + except Exception as e: + self.logger.error(f"Failed to initialize Kafka producer: {e}") + self.redpanda_enabled = False + + # def _init_kafka_consumer(self, topics, group_id=None): + # """ + # Initialize Kafka consumer for reading market data + + # Args: + # topics: List of topics to subscribe to + # group_id: Consumer group ID + # """ + # try: + # if not group_id: 
+ # group_id = f"openalgo-consumer-{os.getpid()}" + + # consumer_config = { + # 'bootstrap_servers': self.redpanda_config['bootstrap_servers'], + # 'client_id': self.redpanda_config['client_id'], + # 'group_id': group_id, + # 'value_deserializer': lambda m: json.loads(m.decode('utf-8')), + # 'key_deserializer': lambda k: k.decode('utf-8') if k else None, + # 'auto_offset_reset': 'latest', + # 'enable_auto_commit': True, + # 'auto_commit_interval_ms': 1000, + # 'session_timeout_ms': 30000, + # 'heartbeat_interval_ms': 3000 + # } + + # self.kafka_consumer = KafkaConsumer(*topics, **consumer_config) + # self.logger.info(f"Kafka consumer initialized for topics: {topics}") + # return self.kafka_consumer + + # except Exception as e: + # self.logger.error(f"Failed to initialize Kafka consumer: {e}") + # return None + def _bind_to_available_port(self): """ diff --git a/websocket_proxy/server.py b/websocket_proxy/server.py index 8dfb42fe..06f846c8 100644 --- a/websocket_proxy/server.py +++ b/websocket_proxy/server.py @@ -12,6 +12,14 @@ from typing import Dict, Set, Any, Optional from dotenv import load_dotenv +# RedPanda/Kafka imports +try: + from kafka import KafkaProducer, KafkaConsumer + from kafka.errors import KafkaError + KAFKA_AVAILABLE = True +except ImportError: + KAFKA_AVAILABLE = False + from .port_check import is_port_in_use, find_available_port from database.auth_db import get_broker_name from sqlalchemy import text @@ -27,6 +35,7 @@ class WebSocketProxy: WebSocket Proxy Server that handles client connections and authentication, manages subscriptions, and routes market data from broker adapters to clients. Supports dynamic broker selection based on user configuration. + Enhanced with Kafka/RedPanda integration for scalable message distribution. 
""" def __init__(self, host: str = "127.0.0.1", port: int = 8765): @@ -71,20 +80,107 @@ def __init__(self, host: str = "127.0.0.1", port: int = 8765): # Set up ZeroMQ subscriber to receive all messages self.socket.setsockopt(zmq.SUBSCRIBE, b"") # Subscribe to all topics + + # Kafka/RedPanda integration + #self.kafka_enabled = os.getenv('ENABLE_REDPANDA', 'false').lower() == 'true' and KAFKA_AVAILABLE + #self.kafka_consumer = None + #self.kafka_consumer_task = None + self.kafka_subscriptions = {} # Maps client_id to set of kafka topics + + #if self.kafka_enabled: + # logger.info("Kafka integration enabled for WebSocket proxy") + # self._setup_kafka() + #else: + # logger.info("Kafka integration disabled - using ZMQ only") + + def _setup_kafka(self): + """Setup Kafka consumer for market data distribution""" + try: + kafka_brokers = os.getenv('REDPANDA_BROKERS', 'localhost:9092') + consumer_group = os.getenv('REDPANDA_CONSUMER_GROUP_ID', 'websocket_proxy_group') + + # Parse and validate broker addresses + broker_list = [] + for broker in kafka_brokers.split(','): + broker = broker.strip() + if ':' in broker: + host, port = broker.rsplit(':', 1) + try: + port = int(port) # Convert port to integer + broker_list.append(f"{host}:{port}") + except ValueError: + logger.error(f"Invalid port in broker address: {broker}") + broker_list.append(broker) # Use as-is if port conversion fails + else: + broker_list.append(broker) + + # Kafka consumer configuration + consumer_config = { + 'bootstrap_servers': broker_list, # Use the parsed broker list + 'group_id': consumer_group, + 'key_deserializer': lambda k: k.decode('utf-8') if k else None, + 'value_deserializer': lambda v: json.loads(v.decode('utf-8')), + 'auto_offset_reset': os.getenv('REDPANDA_AUTO_OFFSET_RESET', 'latest'), + 'enable_auto_commit': os.getenv('REDPANDA_ENABLE_AUTO_COMMIT', 'true').lower() == 'true', + 'auto_commit_interval_ms': int(os.getenv('REDPANDA_AUTO_COMMIT_INTERVAL_MS', '1000')), + 'session_timeout_ms': int(os.getenv('REDPANDA_SESSION_TIMEOUT_MS', '30000')), + 'heartbeat_interval_ms': int(os.getenv('REDPANDA_HEARTBEAT_INTERVAL_MS', '3000')), + 'max_poll_records': int(os.getenv('REDPANDA_MAX_POLL_RECORDS', '500')), + 'fetch_min_bytes': int(os.getenv('REDPANDA_FETCH_MIN_BYTES', '1')), + 'fetch_max_wait_ms': int(os.getenv('REDPANDA_FETCH_MAX_WAIT_MS', '500')), + } + + # Security configuration if needed + security_protocol = os.getenv('REDPANDA_SECURITY_PROTOCOL', 'PLAINTEXT') + if security_protocol != 'PLAINTEXT': + consumer_config['security_protocol'] = security_protocol + if security_protocol in ['SASL_PLAINTEXT', 'SASL_SSL']: + consumer_config['sasl_mechanism'] = os.getenv('REDPANDA_SASL_MECHANISM', 'PLAIN') + consumer_config['sasl_plain_username'] = os.getenv('REDPANDA_SASL_USERNAME', '') + consumer_config['sasl_plain_password'] = os.getenv('REDPANDA_SASL_PASSWORD', '') + if security_protocol in ['SSL', 'SASL_SSL']: + ssl_cafile = os.getenv('REDPANDA_SSL_CAFILE') + ssl_certfile = os.getenv('REDPANDA_SSL_CERTFILE') + ssl_keyfile = os.getenv('REDPANDA_SSL_KEYFILE') + if ssl_cafile: + consumer_config['ssl_cafile'] = ssl_cafile + if ssl_certfile: + consumer_config['ssl_certfile'] = ssl_certfile + if ssl_keyfile: + consumer_config['ssl_keyfile'] = ssl_keyfile + + # Add connection timeout to prevent hanging + consumer_config['request_timeout_ms'] = int(os.getenv('REDPANDA_REQUEST_TIMEOUT_MS', '30000')) + consumer_config['connections_max_idle_ms'] = int(os.getenv('REDPANDA_CONNECTIONS_MAX_IDLE_MS', '540000')) + + logger.info(f"Attempting 
to connect to Kafka brokers: {broker_list}") + self.kafka_consumer = KafkaConsumer(**consumer_config) + logger.info(f"Kafka consumer initialized successfully with brokers: {broker_list}") + + except Exception as e: + logger.error(f"Failed to setup Kafka consumer: {e}") + logger.info("Kafka integration will be disabled, falling back to ZMQ only") + self.kafka_enabled = False + self.kafka_consumer = None async def start(self): - """Start the WebSocket server and ZeroMQ listener""" + """Start the WebSocket server, ZeroMQ listener, and Kafka consumer""" self.running = True try: # Start ZeroMQ listener - logger.info("Initializing ZeroMQ listener task") + # logger.info("Initializing ZeroMQ listener task") # Get the current event loop loop = aio.get_running_loop() # Create the ZMQ listener task - zmq_task = loop.create_task(self.zmq_listener()) + # zmq_task = loop.create_task(self.zmq_listener()) + + # Start Kafka consumer if enabled + # if self.kafka_enabled and self.kafka_consumer: + # logger.info("Starting Kafka consumer task") + # self.kafka_consumer_task = loop.create_task(self.kafka_listener()) # Start WebSocket server stop = aio.Future() # Used to stop the server @@ -235,6 +331,7 @@ async def handle_client(self, websocket): client_id = id(websocket) self.clients[client_id] = websocket self.subscriptions[client_id] = set() + self.kafka_subscriptions[client_id] = set() # Get path info from websocket if available path = getattr(websocket, 'path', '/unknown') @@ -1002,6 +1099,7 @@ async def main(): # Continue with ZeroMQ listener even if signal handlers fail if proxy: await proxy.zmq_listener() + else: logger.error(f"Runtime error: {e}") raise
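
Illustrative sketch (not part of the diff above): a minimal standalone consumer for spot-checking the new end-to-end path, adapter publish_market_data() → Kafka/RedPanda topic "tick_data" → consumer. It assumes kafka-python and the producer serializers configured in BaseBrokerWebSocketAdapter._init_kafka_producer() (JSON-encoded values, stringified keys); the topic name "tick_data" and the REDPANDA_BROKERS default of localhost:9092 come from the diff, while the consumer group id "tick-data-inspector" is made up for this example.

import json
import os

from kafka import KafkaConsumer  # kafka-python, the same client library used by the adapter


def main() -> None:
    # Connect to the same broker the adapter publishes to.
    consumer = KafkaConsumer(
        "tick_data",
        bootstrap_servers=os.getenv("REDPANDA_BROKERS", "localhost:9092"),
        group_id="tick-data-inspector",          # illustrative group id
        auto_offset_reset="latest",
        key_deserializer=lambda k: k.decode("utf-8") if k else None,
        value_deserializer=lambda v: json.loads(v.decode("utf-8")),
    )
    try:
        for message in consumer:
            # key is the original ZMQ topic string (e.g. "NSE_RELIANCE_LTP"),
            # value is the normalized tick dict published by the broker adapter.
            print(f"{message.key}: {message.value}")
    except KeyboardInterrupt:
        pass
    finally:
        consumer.close()


if __name__ == "__main__":
    main()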