Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
sqlite3.register_adapter(decimal.Decimal, lambda x: str(x))
sqlite3.register_converter("DECTEXT", lambda x: decimal.Decimal(x.decode()))


def create_app():

from flask.config import Config
Expand All @@ -47,13 +46,12 @@ def __dir__(self):
from . import db

db.init_app(app)
key_type = "only_read" if app.config.READ_MODE else "onetime"

rows = db.query_db2(f'SELECT public FROM keys WHERE type = "{key_type}"')
accounts = [row["public"] for row in rows]

block_scanner.BlockScanner.set_watched_accounts(
[
row["public"]
for row in db.query_db2('select public from keys where type = "onetime"')
]
)
block_scanner.BlockScanner.set_watched_accounts(accounts)

from . import utils

Expand Down
Empty file added app/api/services/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions app/api/services/address_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from bip_utils import Bip32PublicKey, Bip44, Bip44Coins, Bip44Changes
from Crypto.Hash import keccak
import base58
from ...db import get_db
from ...logging import logger
from ...wallet_encryption import wallet_encryption
from ...block_scanner import BlockScanner


def generate_address_from_xpub(symbol, xpub_str):
db = get_db()
bip44_acc = Bip44.FromExtendedKey(xpub_str, Bip44Coins.TRON)
bip44_change = bip44_acc.Change(Bip44Changes.CHAIN_EXT)
count = db.execute("SELECT COUNT(*) FROM keys").fetchone()[0]
address = bip44_change.AddressIndex(count).PublicKey().ToAddress()
db.execute(
"INSERT INTO keys (symbol, public, private, type) VALUES (?, ?, ?, 'only_read')",
(symbol, address, ''),
)
db.commit()
BlockScanner.add_watched_account(address)
return address


def generate_address_with_private_key(symbol, client):
db = get_db()
addresses = client.generate_address()
public_address = addresses["base58check_address"]
encrypted_priv = wallet_encryption.encrypt(addresses["private_key"])

db.execute(
"INSERT INTO keys (symbol, public, private, type) VALUES (?, ?, ?, 'onetime')",
(symbol, public_address, encrypted_priv),
)
db.commit()
BlockScanner.add_watched_account(public_address)

return public_address
34 changes: 13 additions & 21 deletions app/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import tronpy.exceptions
from flask import current_app, g
from flask import request
from tronpy import Tron

from ..db import get_db, query_db
Expand All @@ -15,31 +16,22 @@
from ..connection_manager import ConnectionManager
from . import api
from ..wallet_encryption import wallet_encryption
from ..config import config
from .services.address_service import generate_address_from_xpub, generate_address_with_private_key


@api.post("/generate-address")
def generate_new_address():
client = Tron()
addresses = client.generate_address()

db = get_db()
db.execute(
"INSERT INTO keys (symbol, public, private, type) VALUES (?, ?, ?, 'onetime')",
(
g.symbol,
addresses["base58check_address"],
wallet_encryption.encrypt(addresses["private_key"]),
),
)
db.commit()

BlockScanner.add_watched_account(addresses["base58check_address"])

return {
"status": "success",
"base58check_address": addresses["base58check_address"],
}

symbol = g.symbol
if config.READ_MODE:
data = request.get_json(silent=True) or {}
xpub_str = data.get("xpub")
address = generate_address_from_xpub(symbol, xpub_str)
return {"status": "success", "base58check_address": address}
else:
client = Tron()
address = generate_address_with_private_key(symbol, client)
return {"status": "success", "base58check_address": address}

@api.post("/balance")
def get_balance():
Expand Down
13 changes: 13 additions & 0 deletions app/block_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ def download_tx_info_by_block_num(self, n):
result["id"]: result for result in transaction_results if "log" in result
}

def mark_key_as_finished(self, dst_addr):
row = query_db2(
'SELECT 1 FROM keys WHERE public = ? AND type = "only_read"',
(dst_addr,),
one=True
)
if row:
query_db2(
'UPDATE keys SET type = "only_read_finished" WHERE public = ? AND type = "only_read"',
(dst_addr,)
)

def notify_shkeeper(self, symbol, txid):
url = f"http://{config.SHKEEPER_HOST}/api/v1/walletnotify/{symbol}/{txid}"
headers = {"X-Shkeeper-Backend-Key": config.SHKEEPER_BACKEND_KEY}
Expand Down Expand Up @@ -279,6 +291,7 @@ def scan(self, block_num: int) -> bool:
if tron_tx.dst_addr in valid_addresses:
if tron_tx.status == "SUCCESS":
logger.info(f"Sending notification for {tron_tx}")
self.mark_key_as_finished(tron_tx.dst_addr)
self.notify_shkeeper(tron_tx.symbol.value, tron_tx.txid)
# Send funds to main account
if tron_tx.is_trc20:
Expand Down
11 changes: 10 additions & 1 deletion app/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from decimal import Decimal
from functools import cache
from typing import List
import os

from pydantic import Field, Json, field_validator
from pydantic import Field, Json, field_validator, validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import ClassVar

from .custom.aml.schemas import ExternalDrain
from .schemas import TronFullnode, TronNetwork, Token, TronSymbol
Expand Down Expand Up @@ -54,6 +56,13 @@ class Settings(BaseSettings):
AML_RESULT_UPDATE_PERIOD: int = 120
AML_SWEEP_ACCOUNTS_PERIOD: int = 3600
AML_WAIT_BEFORE_API_CALL: int = 320
READ_MODE: bool = False

@validator("READ_MODE", pre=True)
def parse_read_mode(cls, v):
if isinstance(v, str):
return v.lower() in ("true", "1", "yes", "enabled")
return bool(v)

TOKENS: List[Token] = [
Token(
Expand Down
2 changes: 2 additions & 0 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
class KeyType(str, Enum):
fee_deposit = "fee_deposit"
onetime = "onetime"
only_read = "only_read"
only_read_finished = "only_read_finished"


class TronNetwork(str, Enum):
Expand Down
85 changes: 85 additions & 0 deletions app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from celery.schedules import crontab
from tronpy.keys import PrivateKey
from tronpy.tron import current_timestamp
from tronpy import Tron
import tronpy.exceptions
import requests
from sqlmodel import Session, select
Expand All @@ -28,6 +29,8 @@

@celery.task()
def prepare_payout(dest, amount, symbol):
if config.READ_MODE:
return
if (balance := Wallet(symbol).balance) < amount:
raise Exception(
f"Wallet balance is less than payout amount: {balance} < {amount}"
Expand Down Expand Up @@ -74,6 +77,8 @@ def payout(steps, symbol):

@celery.task()
def transfer_trc20_from(onetime_publ_key, symbol):
if config.READ_MODE:
return
"""
Transfers TRC20 from onetime to main account
"""
Expand Down Expand Up @@ -155,6 +160,8 @@ def transfer_trc20_from(onetime_publ_key, symbol):

@celery.task()
def transfer_trx_from(onetime_publ_key):
if config.READ_MODE:
return
"""
Transfers TRX from onetime to main account
"""
Expand Down Expand Up @@ -374,9 +381,87 @@ def precision_of(symbol):
stats["exception_num"] += 1
return stats

@celery.task(bind=True)
@skip_if_running
def scan_ballance(self, *args, **kwargs):
"""
Scans accounts balances (TRX and TRC20),
"""
from .db import engine
from .models import Balance
from tronpy import Tron
client = Tron()
with Session(engine) as session:
stats = {
"balances": collections.defaultdict(Decimal),
"exception_num": 0,
}

accounts = [
row["public"]
for row in query_db('SELECT public FROM keys WHERE type = "only_read_finished"')
]

for index, account in enumerate(accounts, start=1):
try:
# === TRX BALANCE ===
trx_balance = client.get_account_balance(account)
stats["balances"]["TRX"] += trx_balance
logger.debug(f"[TRX] {account} -> {trx_balance} TRX")

if config.READ_MODE:
acc_balance = session.exec(
select(Balance).where(Balance.account == account, Balance.symbol == "TRX")
).first()
if acc_balance:
acc_balance.balance = trx_balance
else:
acc_balance = Balance(
account=account,
symbol="TRX",
balance=trx_balance,
)
session.add(acc_balance)
session.commit()

# === TRC20 TOKENS ===
for token in config.get_tokens():
symbol = token.symbol
contract_address = token.contract_address
contract = client.get_contract(contract_address)

# balanceOf returns balance in raw token units
raw_balance = contract.functions.balanceOf(account)
trc20_balance = Decimal(raw_balance) / (10 ** token.decimal)

stats["balances"][symbol] += trc20_balance
logger.debug(f"[{symbol}] {account} -> {trc20_balance} tokens")

if config.READ_MODE:
acc_balance = session.exec(
select(Balance).where(Balance.account == account, Balance.symbol == symbol)
).first()
if acc_balance:
acc_balance.balance = trc20_balance
else:
acc_balance = Balance(
account=account,
symbol=symbol,
balance=trc20_balance,
)
session.add(acc_balance)
session.commit()

except Exception as e:
stats["exception_num"] += 1
logger.warning(f"[ERROR] {account} scan error: {e}")

return stats

@celery.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
if config.READ_MODE:
sender.add_periodic_task(config.BALANCES_RESCAN_PERIOD, scan_ballance.s())
if config.EXTERNAL_DRAIN_CONFIG:
from .custom.aml.tasks import sweep_accounts, recheck_transactions

Expand Down
13 changes: 12 additions & 1 deletion app/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
import tronpy.exceptions
from tronpy.keys import PrivateKey

from sqlalchemy import func
from sqlmodel import Session, select
from .models import Balance
from .config import config
from .db import query_db2
from .db import engine
from .logging import logger
from .connection_manager import ConnectionManager
from .wallet_encryption import wallet_encryption
Expand Down Expand Up @@ -38,7 +42,14 @@ def get_contract(self, contract_address=None):

@property
def balance(self):
return self.balance_of(self.main_account["public"])
if config.READ_MODE:
with Session(engine) as session:
total_sum = session.exec(select(func.sum(Balance.balance))).one()
raw_balance = total_sum if total_sum is not None else Decimal(0)
overall_balance = float(round(raw_balance, 6))
return overall_balance
else:
return self.balance_of(self.main_account["public"])

def balance_of(self, address):
if self.symbol == "TRX":
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ pymysql==1.1.1
redis==5.2.1
requests==2.32.3
sqlmodel==0.0.22
bip_utils==2.8.0
tronpy==0.5.0
1 change: 0 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import app


#
# Wallet encryption
#
Expand Down