From 14ada3e9647b1e535e30a01fcba180f93b86ee27 Mon Sep 17 00:00:00 2001 From: oleh starosvitskyi Date: Mon, 9 Jun 2025 09:15:44 +0200 Subject: [PATCH 1/2] add_xpub --- app/__init__.py | 20 ++++++++++----- app/api/services/__init__.py | 0 app/api/services/address_service.py | 38 +++++++++++++++++++++++++++++ app/api/views.py | 34 ++++++++++---------------- app/block_scanner.py | 13 ++++++++++ app/config.py | 11 ++++++++- app/schemas.py | 2 ++ app/tasks.py | 15 +++++++++--- requirements.txt | 1 + run.py | 11 ++++----- 10 files changed, 107 insertions(+), 38 deletions(-) create mode 100644 app/api/services/__init__.py create mode 100644 app/api/services/address_service.py diff --git a/app/__init__.py b/app/__init__.py index 029ae7d..97a2470 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -22,6 +22,14 @@ sqlite3.register_adapter(decimal.Decimal, lambda x: str(x)) sqlite3.register_converter("DECTEXT", lambda x: decimal.Decimal(x.decode())) +def init_settings_table(db): + create_table_sql = """ + CREATE TABLE IF NOT EXISTS settings ( + name TEXT PRIMARY KEY, + value TEXT + ); + """ + db.query_db2(create_table_sql) def create_app(): @@ -47,13 +55,13 @@ def __dir__(self): from . import db db.init_app(app) + init_settings_table(db) + key_type = "only_read" if app.config.READ_MODE else "onetime" - block_scanner.BlockScanner.set_watched_accounts( - [ - row["public"] - for row in db.query_db2('select public from keys where type = "onetime"') - ] - ) + rows = db.query_db2(f'SELECT public FROM keys WHERE type = "{key_type}"') + accounts = [row["public"] for row in rows] + + block_scanner.BlockScanner.set_watched_accounts(accounts) from . import utils diff --git a/app/api/services/__init__.py b/app/api/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/services/address_service.py b/app/api/services/address_service.py new file mode 100644 index 0000000..20e7641 --- /dev/null +++ b/app/api/services/address_service.py @@ -0,0 +1,38 @@ +from bip_utils import Bip32PublicKey, Bip44, Bip44Coins, Bip44Changes +from Crypto.Hash import keccak +import base58 +from ...db import get_db +from ...logging import logger +from ...wallet_encryption import wallet_encryption +from ...block_scanner import BlockScanner + + +def generate_address_from_xpub(symbol, xpub_str): + db = get_db() + bip44_acc = Bip44.FromExtendedKey(xpub_str, Bip44Coins.TRON) + bip44_change = bip44_acc.Change(Bip44Changes.CHAIN_EXT) + count = db.execute("SELECT COUNT(*) FROM keys").fetchone()[0] + address = bip44_change.AddressIndex(count).PublicKey().ToAddress() + db.execute( + "INSERT INTO keys (symbol, public, private, type) VALUES (?, ?, ?, 'only_read')", + (symbol, address, ''), + ) + db.commit() + BlockScanner.add_watched_account(address) + return address + + +def generate_address_with_private_key(symbol, client): + db = get_db() + addresses = client.generate_address() + public_address = addresses["base58check_address"] + encrypted_priv = wallet_encryption.encrypt(addresses["private_key"]) + + db.execute( + "INSERT INTO keys (symbol, public, private, type) VALUES (?, ?, ?, 'onetime')", + (symbol, public_address, encrypted_priv), + ) + db.commit() + BlockScanner.add_watched_account(public_address) + + return public_address \ No newline at end of file diff --git a/app/api/views.py b/app/api/views.py index 5d51367..ae674a2 100644 --- a/app/api/views.py +++ b/app/api/views.py @@ -5,6 +5,7 @@ import tronpy.exceptions from flask import current_app, g +from flask import request from tronpy import Tron from ..db import get_db, query_db @@ -15,31 +16,22 @@ from ..connection_manager import ConnectionManager from . import api from ..wallet_encryption import wallet_encryption +from ..config import config +from .services.address_service import generate_address_from_xpub, generate_address_with_private_key @api.post("/generate-address") def generate_new_address(): - client = Tron() - addresses = client.generate_address() - - db = get_db() - db.execute( - "INSERT INTO keys (symbol, public, private, type) VALUES (?, ?, ?, 'onetime')", - ( - g.symbol, - addresses["base58check_address"], - wallet_encryption.encrypt(addresses["private_key"]), - ), - ) - db.commit() - - BlockScanner.add_watched_account(addresses["base58check_address"]) - - return { - "status": "success", - "base58check_address": addresses["base58check_address"], - } - + symbol = g.symbol + if config.READ_MODE: + data = request.get_json(silent=True) or {} + xpub_str = data.get("xpub") + address = generate_address_from_xpub(symbol, xpub_str) + return {"status": "success", "base58check_address": address} + else: + client = Tron() + address = generate_address_with_private_key(symbol, client) + return {"status": "success", "base58check_address": address} @api.post("/balance") def get_balance(): diff --git a/app/block_scanner.py b/app/block_scanner.py index af0207a..123cc28 100644 --- a/app/block_scanner.py +++ b/app/block_scanner.py @@ -162,6 +162,18 @@ def download_tx_info_by_block_num(self, n): result["id"]: result for result in transaction_results if "log" in result } + def mark_key_as_finished(self, dst_addr): + row = query_db2( + 'SELECT 1 FROM keys WHERE public = ? AND type = "only_read"', + (dst_addr,), + one=True + ) + if row: + query_db2( + 'UPDATE keys SET type = "only_read_finished" WHERE public = ? AND type = "only_read"', + (dst_addr,) + ) + def notify_shkeeper(self, symbol, txid): url = f"http://{config.SHKEEPER_HOST}/api/v1/walletnotify/{symbol}/{txid}" headers = {"X-Shkeeper-Backend-Key": config.SHKEEPER_BACKEND_KEY} @@ -279,6 +291,7 @@ def scan(self, block_num: int) -> bool: if tron_tx.dst_addr in valid_addresses: if tron_tx.status == "SUCCESS": logger.info(f"Sending notification for {tron_tx}") + self.mark_key_as_finished(tron_tx.dst_addr) self.notify_shkeeper(tron_tx.symbol.value, tron_tx.txid) # Send funds to main account if tron_tx.is_trc20: diff --git a/app/config.py b/app/config.py index 6d52d1a..a7376eb 100644 --- a/app/config.py +++ b/app/config.py @@ -1,9 +1,11 @@ from decimal import Decimal from functools import cache from typing import List +import os -from pydantic import Field, Json, field_validator +from pydantic import Field, Json, field_validator, validator from pydantic_settings import BaseSettings, SettingsConfigDict +from typing import ClassVar from .custom.aml.schemas import ExternalDrain from .schemas import TronFullnode, TronNetwork, Token, TronSymbol @@ -54,6 +56,13 @@ class Settings(BaseSettings): AML_RESULT_UPDATE_PERIOD: int = 120 AML_SWEEP_ACCOUNTS_PERIOD: int = 3600 AML_WAIT_BEFORE_API_CALL: int = 320 + READ_MODE: bool = False + + @validator("READ_MODE", pre=True) + def parse_read_mode(cls, v): + if isinstance(v, str): + return v.lower() in ("true", "1", "yes", "enabled") + return bool(v) TOKENS: List[Token] = [ Token( diff --git a/app/schemas.py b/app/schemas.py index 555bad4..e607c1e 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -14,6 +14,8 @@ class KeyType(str, Enum): fee_deposit = "fee_deposit" onetime = "onetime" + only_read = "only_read" + only_read_finished = "only_read_finished" class TronNetwork(str, Enum): diff --git a/app/tasks.py b/app/tasks.py index 402863a..49a0812 100644 --- a/app/tasks.py +++ b/app/tasks.py @@ -28,6 +28,8 @@ @celery.task() def prepare_payout(dest, amount, symbol): + if config.READ_MODE: + return if (balance := Wallet(symbol).balance) < amount: raise Exception( f"Wallet balance is less than payout amount: {balance} < {amount}" @@ -74,6 +76,8 @@ def payout(steps, symbol): @celery.task() def transfer_trc20_from(onetime_publ_key, symbol): + if config.READ_MODE: + return """ Transfers TRC20 from onetime to main account """ @@ -155,6 +159,8 @@ def transfer_trc20_from(onetime_publ_key, symbol): @celery.task() def transfer_trx_from(onetime_publ_key): + if config.READ_MODE: + return """ Transfers TRX from onetime to main account """ @@ -257,10 +263,11 @@ def precision_of(symbol): .functions.decimals() ) - accounts = [ - row["public"] - for row in query_db('SELECT public FROM keys WHERE type = "onetime"') - ] + query = 'SELECT public FROM keys WHERE type = ?' + accounts_onetime = [row["public"] for row in query_db(query, ("onetime",))] + accounts_read_mode = [row["public"] for row in query_db(query, ("only_read",))] + accounts = accounts_read_mode if config.READ_MODE else accounts_onetime + for index, account in enumerate(accounts, start=1): try: # diff --git a/requirements.txt b/requirements.txt index 4a43e07..2e428e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ pymysql==1.1.1 redis==5.2.1 requests==2.32.3 sqlmodel==0.0.22 +bip_utils==2.8.0 tronpy==0.5.0 \ No newline at end of file diff --git a/run.py b/run.py index 83dc345..4764513 100644 --- a/run.py +++ b/run.py @@ -2,7 +2,6 @@ import app - # # Wallet encryption # @@ -12,6 +11,11 @@ # # Refresh best server # +# +# Flask +# + +server = app.create_app() refresh_best_server_thread = threading.Thread( daemon=True, @@ -20,11 +24,6 @@ ) refresh_best_server_thread.start() -# -# Flask -# - -server = app.create_app() # # Block scanner From 37c2f1c936994d6773c7833d6bd7a040d34bbf94 Mon Sep 17 00:00:00 2001 From: oleh starosvitskyi Date: Thu, 12 Jun 2025 11:07:31 +0200 Subject: [PATCH 2/2] fix calcualte balance --- app/__init__.py | 10 ------ app/tasks.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++--- app/wallet.py | 13 +++++++- run.py | 10 +++--- 4 files changed, 100 insertions(+), 21 deletions(-) diff --git a/app/__init__.py b/app/__init__.py index 97a2470..dd6bbfa 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -22,15 +22,6 @@ sqlite3.register_adapter(decimal.Decimal, lambda x: str(x)) sqlite3.register_converter("DECTEXT", lambda x: decimal.Decimal(x.decode())) -def init_settings_table(db): - create_table_sql = """ - CREATE TABLE IF NOT EXISTS settings ( - name TEXT PRIMARY KEY, - value TEXT - ); - """ - db.query_db2(create_table_sql) - def create_app(): from flask.config import Config @@ -55,7 +46,6 @@ def __dir__(self): from . import db db.init_app(app) - init_settings_table(db) key_type = "only_read" if app.config.READ_MODE else "onetime" rows = db.query_db2(f'SELECT public FROM keys WHERE type = "{key_type}"') diff --git a/app/tasks.py b/app/tasks.py index 49a0812..5e27e61 100644 --- a/app/tasks.py +++ b/app/tasks.py @@ -12,6 +12,7 @@ from celery.schedules import crontab from tronpy.keys import PrivateKey from tronpy.tron import current_timestamp +from tronpy import Tron import tronpy.exceptions import requests from sqlmodel import Session, select @@ -263,11 +264,10 @@ def precision_of(symbol): .functions.decimals() ) - query = 'SELECT public FROM keys WHERE type = ?' - accounts_onetime = [row["public"] for row in query_db(query, ("onetime",))] - accounts_read_mode = [row["public"] for row in query_db(query, ("only_read",))] - accounts = accounts_read_mode if config.READ_MODE else accounts_onetime - + accounts = [ + row["public"] + for row in query_db('SELECT public FROM keys WHERE type = "onetime"') + ] for index, account in enumerate(accounts, start=1): try: # @@ -381,9 +381,87 @@ def precision_of(symbol): stats["exception_num"] += 1 return stats +@celery.task(bind=True) +@skip_if_running +def scan_ballance(self, *args, **kwargs): + """ + Scans accounts balances (TRX and TRC20), + """ + from .db import engine + from .models import Balance + from tronpy import Tron + client = Tron() + with Session(engine) as session: + stats = { + "balances": collections.defaultdict(Decimal), + "exception_num": 0, + } + + accounts = [ + row["public"] + for row in query_db('SELECT public FROM keys WHERE type = "only_read_finished"') + ] + + for index, account in enumerate(accounts, start=1): + try: + # === TRX BALANCE === + trx_balance = client.get_account_balance(account) + stats["balances"]["TRX"] += trx_balance + logger.debug(f"[TRX] {account} -> {trx_balance} TRX") + + if config.READ_MODE: + acc_balance = session.exec( + select(Balance).where(Balance.account == account, Balance.symbol == "TRX") + ).first() + if acc_balance: + acc_balance.balance = trx_balance + else: + acc_balance = Balance( + account=account, + symbol="TRX", + balance=trx_balance, + ) + session.add(acc_balance) + session.commit() + + # === TRC20 TOKENS === + for token in config.get_tokens(): + symbol = token.symbol + contract_address = token.contract_address + contract = client.get_contract(contract_address) + + # balanceOf returns balance in raw token units + raw_balance = contract.functions.balanceOf(account) + trc20_balance = Decimal(raw_balance) / (10 ** token.decimal) + + stats["balances"][symbol] += trc20_balance + logger.debug(f"[{symbol}] {account} -> {trc20_balance} tokens") + + if config.READ_MODE: + acc_balance = session.exec( + select(Balance).where(Balance.account == account, Balance.symbol == symbol) + ).first() + if acc_balance: + acc_balance.balance = trc20_balance + else: + acc_balance = Balance( + account=account, + symbol=symbol, + balance=trc20_balance, + ) + session.add(acc_balance) + session.commit() + + except Exception as e: + stats["exception_num"] += 1 + logger.warning(f"[ERROR] {account} scan error: {e}") + + return stats @celery.on_after_configure.connect def setup_periodic_tasks(sender, **kwargs): + if config.READ_MODE: + sender.add_periodic_task(config.BALANCES_RESCAN_PERIOD, scan_ballance.s()) if config.EXTERNAL_DRAIN_CONFIG: from .custom.aml.tasks import sweep_accounts, recheck_transactions diff --git a/app/wallet.py b/app/wallet.py index c91cf37..f09efa7 100644 --- a/app/wallet.py +++ b/app/wallet.py @@ -3,8 +3,12 @@ import tronpy.exceptions from tronpy.keys import PrivateKey +from sqlalchemy import func +from sqlmodel import Session, select +from .models import Balance from .config import config from .db import query_db2 +from .db import engine from .logging import logger from .connection_manager import ConnectionManager from .wallet_encryption import wallet_encryption @@ -38,7 +42,14 @@ def get_contract(self, contract_address=None): @property def balance(self): - return self.balance_of(self.main_account["public"]) + if config.READ_MODE: + with Session(engine) as session: + total_sum = session.exec(select(func.sum(Balance.balance))).one() + raw_balance = total_sum if total_sum is not None else Decimal(0) + overall_balance = float(round(raw_balance, 6)) + return overall_balance + else: + return self.balance_of(self.main_account["public"]) def balance_of(self, address): if self.symbol == "TRX": diff --git a/run.py b/run.py index 4764513..e8d2804 100644 --- a/run.py +++ b/run.py @@ -11,11 +11,6 @@ # # Refresh best server # -# -# Flask -# - -server = app.create_app() refresh_best_server_thread = threading.Thread( daemon=True, @@ -24,6 +19,11 @@ ) refresh_best_server_thread.start() +# +# Flask +# + +server = app.create_app() # # Block scanner