diff --git a/app/__init__.py b/app/__init__.py index 029ae7d..dd6bbfa 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -22,7 +22,6 @@ sqlite3.register_adapter(decimal.Decimal, lambda x: str(x)) sqlite3.register_converter("DECTEXT", lambda x: decimal.Decimal(x.decode())) - def create_app(): from flask.config import Config @@ -47,13 +46,12 @@ def __dir__(self): from . import db db.init_app(app) + key_type = "only_read" if app.config.READ_MODE else "onetime" + + rows = db.query_db2(f'SELECT public FROM keys WHERE type = "{key_type}"') + accounts = [row["public"] for row in rows] - block_scanner.BlockScanner.set_watched_accounts( - [ - row["public"] - for row in db.query_db2('select public from keys where type = "onetime"') - ] - ) + block_scanner.BlockScanner.set_watched_accounts(accounts) from . import utils diff --git a/app/api/services/__init__.py b/app/api/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/services/address_service.py b/app/api/services/address_service.py new file mode 100644 index 0000000..20e7641 --- /dev/null +++ b/app/api/services/address_service.py @@ -0,0 +1,38 @@ +from bip_utils import Bip32PublicKey, Bip44, Bip44Coins, Bip44Changes +from Crypto.Hash import keccak +import base58 +from ...db import get_db +from ...logging import logger +from ...wallet_encryption import wallet_encryption +from ...block_scanner import BlockScanner + + +def generate_address_from_xpub(symbol, xpub_str): + db = get_db() + bip44_acc = Bip44.FromExtendedKey(xpub_str, Bip44Coins.TRON) + bip44_change = bip44_acc.Change(Bip44Changes.CHAIN_EXT) + count = db.execute("SELECT COUNT(*) FROM keys").fetchone()[0] + address = bip44_change.AddressIndex(count).PublicKey().ToAddress() + db.execute( + "INSERT INTO keys (symbol, public, private, type) VALUES (?, ?, ?, 'only_read')", + (symbol, address, ''), + ) + db.commit() + BlockScanner.add_watched_account(address) + return address + + +def generate_address_with_private_key(symbol, client): + db = get_db() + addresses = client.generate_address() + public_address = addresses["base58check_address"] + encrypted_priv = wallet_encryption.encrypt(addresses["private_key"]) + + db.execute( + "INSERT INTO keys (symbol, public, private, type) VALUES (?, ?, ?, 'onetime')", + (symbol, public_address, encrypted_priv), + ) + db.commit() + BlockScanner.add_watched_account(public_address) + + return public_address \ No newline at end of file diff --git a/app/api/views.py b/app/api/views.py index 5d51367..ae674a2 100644 --- a/app/api/views.py +++ b/app/api/views.py @@ -5,6 +5,7 @@ import tronpy.exceptions from flask import current_app, g +from flask import request from tronpy import Tron from ..db import get_db, query_db @@ -15,31 +16,22 @@ from ..connection_manager import ConnectionManager from . import api from ..wallet_encryption import wallet_encryption +from ..config import config +from .services.address_service import generate_address_from_xpub, generate_address_with_private_key @api.post("/generate-address") def generate_new_address(): - client = Tron() - addresses = client.generate_address() - - db = get_db() - db.execute( - "INSERT INTO keys (symbol, public, private, type) VALUES (?, ?, ?, 'onetime')", - ( - g.symbol, - addresses["base58check_address"], - wallet_encryption.encrypt(addresses["private_key"]), - ), - ) - db.commit() - - BlockScanner.add_watched_account(addresses["base58check_address"]) - - return { - "status": "success", - "base58check_address": addresses["base58check_address"], - } - + symbol = g.symbol + if config.READ_MODE: + data = request.get_json(silent=True) or {} + xpub_str = data.get("xpub") + address = generate_address_from_xpub(symbol, xpub_str) + return {"status": "success", "base58check_address": address} + else: + client = Tron() + address = generate_address_with_private_key(symbol, client) + return {"status": "success", "base58check_address": address} @api.post("/balance") def get_balance(): diff --git a/app/block_scanner.py b/app/block_scanner.py index af0207a..123cc28 100644 --- a/app/block_scanner.py +++ b/app/block_scanner.py @@ -162,6 +162,18 @@ def download_tx_info_by_block_num(self, n): result["id"]: result for result in transaction_results if "log" in result } + def mark_key_as_finished(self, dst_addr): + row = query_db2( + 'SELECT 1 FROM keys WHERE public = ? AND type = "only_read"', + (dst_addr,), + one=True + ) + if row: + query_db2( + 'UPDATE keys SET type = "only_read_finished" WHERE public = ? AND type = "only_read"', + (dst_addr,) + ) + def notify_shkeeper(self, symbol, txid): url = f"http://{config.SHKEEPER_HOST}/api/v1/walletnotify/{symbol}/{txid}" headers = {"X-Shkeeper-Backend-Key": config.SHKEEPER_BACKEND_KEY} @@ -279,6 +291,7 @@ def scan(self, block_num: int) -> bool: if tron_tx.dst_addr in valid_addresses: if tron_tx.status == "SUCCESS": logger.info(f"Sending notification for {tron_tx}") + self.mark_key_as_finished(tron_tx.dst_addr) self.notify_shkeeper(tron_tx.symbol.value, tron_tx.txid) # Send funds to main account if tron_tx.is_trc20: diff --git a/app/config.py b/app/config.py index 6d52d1a..a7376eb 100644 --- a/app/config.py +++ b/app/config.py @@ -1,9 +1,11 @@ from decimal import Decimal from functools import cache from typing import List +import os -from pydantic import Field, Json, field_validator +from pydantic import Field, Json, field_validator, validator from pydantic_settings import BaseSettings, SettingsConfigDict +from typing import ClassVar from .custom.aml.schemas import ExternalDrain from .schemas import TronFullnode, TronNetwork, Token, TronSymbol @@ -54,6 +56,13 @@ class Settings(BaseSettings): AML_RESULT_UPDATE_PERIOD: int = 120 AML_SWEEP_ACCOUNTS_PERIOD: int = 3600 AML_WAIT_BEFORE_API_CALL: int = 320 + READ_MODE: bool = False + + @validator("READ_MODE", pre=True) + def parse_read_mode(cls, v): + if isinstance(v, str): + return v.lower() in ("true", "1", "yes", "enabled") + return bool(v) TOKENS: List[Token] = [ Token( diff --git a/app/schemas.py b/app/schemas.py index 555bad4..e607c1e 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -14,6 +14,8 @@ class KeyType(str, Enum): fee_deposit = "fee_deposit" onetime = "onetime" + only_read = "only_read" + only_read_finished = "only_read_finished" class TronNetwork(str, Enum): diff --git a/app/tasks.py b/app/tasks.py index 402863a..5e27e61 100644 --- a/app/tasks.py +++ b/app/tasks.py @@ -12,6 +12,7 @@ from celery.schedules import crontab from tronpy.keys import PrivateKey from tronpy.tron import current_timestamp +from tronpy import Tron import tronpy.exceptions import requests from sqlmodel import Session, select @@ -28,6 +29,8 @@ @celery.task() def prepare_payout(dest, amount, symbol): + if config.READ_MODE: + return if (balance := Wallet(symbol).balance) < amount: raise Exception( f"Wallet balance is less than payout amount: {balance} < {amount}" @@ -74,6 +77,8 @@ def payout(steps, symbol): @celery.task() def transfer_trc20_from(onetime_publ_key, symbol): + if config.READ_MODE: + return """ Transfers TRC20 from onetime to main account """ @@ -155,6 +160,8 @@ def transfer_trc20_from(onetime_publ_key, symbol): @celery.task() def transfer_trx_from(onetime_publ_key): + if config.READ_MODE: + return """ Transfers TRX from onetime to main account """ @@ -374,9 +381,87 @@ def precision_of(symbol): stats["exception_num"] += 1 return stats +@celery.task(bind=True) +@skip_if_running +def scan_ballance(self, *args, **kwargs): + """ + Scans accounts balances (TRX and TRC20), + """ + from .db import engine + from .models import Balance + from tronpy import Tron + client = Tron() + with Session(engine) as session: + stats = { + "balances": collections.defaultdict(Decimal), + "exception_num": 0, + } + + accounts = [ + row["public"] + for row in query_db('SELECT public FROM keys WHERE type = "only_read_finished"') + ] + + for index, account in enumerate(accounts, start=1): + try: + # === TRX BALANCE === + trx_balance = client.get_account_balance(account) + stats["balances"]["TRX"] += trx_balance + logger.debug(f"[TRX] {account} -> {trx_balance} TRX") + + if config.READ_MODE: + acc_balance = session.exec( + select(Balance).where(Balance.account == account, Balance.symbol == "TRX") + ).first() + if acc_balance: + acc_balance.balance = trx_balance + else: + acc_balance = Balance( + account=account, + symbol="TRX", + balance=trx_balance, + ) + session.add(acc_balance) + session.commit() + + # === TRC20 TOKENS === + for token in config.get_tokens(): + symbol = token.symbol + contract_address = token.contract_address + contract = client.get_contract(contract_address) + + # balanceOf returns balance in raw token units + raw_balance = contract.functions.balanceOf(account) + trc20_balance = Decimal(raw_balance) / (10 ** token.decimal) + + stats["balances"][symbol] += trc20_balance + logger.debug(f"[{symbol}] {account} -> {trc20_balance} tokens") + + if config.READ_MODE: + acc_balance = session.exec( + select(Balance).where(Balance.account == account, Balance.symbol == symbol) + ).first() + if acc_balance: + acc_balance.balance = trc20_balance + else: + acc_balance = Balance( + account=account, + symbol=symbol, + balance=trc20_balance, + ) + session.add(acc_balance) + session.commit() + + except Exception as e: + stats["exception_num"] += 1 + logger.warning(f"[ERROR] {account} scan error: {e}") + + return stats @celery.on_after_configure.connect def setup_periodic_tasks(sender, **kwargs): + if config.READ_MODE: + sender.add_periodic_task(config.BALANCES_RESCAN_PERIOD, scan_ballance.s()) if config.EXTERNAL_DRAIN_CONFIG: from .custom.aml.tasks import sweep_accounts, recheck_transactions diff --git a/app/wallet.py b/app/wallet.py index c91cf37..f09efa7 100644 --- a/app/wallet.py +++ b/app/wallet.py @@ -3,8 +3,12 @@ import tronpy.exceptions from tronpy.keys import PrivateKey +from sqlalchemy import func +from sqlmodel import Session, select +from .models import Balance from .config import config from .db import query_db2 +from .db import engine from .logging import logger from .connection_manager import ConnectionManager from .wallet_encryption import wallet_encryption @@ -38,7 +42,14 @@ def get_contract(self, contract_address=None): @property def balance(self): - return self.balance_of(self.main_account["public"]) + if config.READ_MODE: + with Session(engine) as session: + total_sum = session.exec(select(func.sum(Balance.balance))).one() + raw_balance = total_sum if total_sum is not None else Decimal(0) + overall_balance = float(round(raw_balance, 6)) + return overall_balance + else: + return self.balance_of(self.main_account["public"]) def balance_of(self, address): if self.symbol == "TRX": diff --git a/requirements.txt b/requirements.txt index 4a43e07..2e428e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ pymysql==1.1.1 redis==5.2.1 requests==2.32.3 sqlmodel==0.0.22 +bip_utils==2.8.0 tronpy==0.5.0 \ No newline at end of file diff --git a/run.py b/run.py index 83dc345..e8d2804 100644 --- a/run.py +++ b/run.py @@ -2,7 +2,6 @@ import app - # # Wallet encryption #