diff --git a/src/dexorder/accounting.py b/src/dexorder/accounting.py index 29152ed..14a39d6 100644 --- a/src/dexorder/accounting.py +++ b/src/dexorder/accounting.py @@ -33,32 +33,40 @@ class ReconciliationException(Exception): pass +def accounting_lock(): + """ + This must be called before accounting_*() calls are made. + """ + db.session.execute(text("LOCK TABLE account, accounting, reconciliation IN EXCLUSIVE MODE")) + + async def initialize_accounting(): global accounting_initialized if not accounting_initialized: - await initialize_mark_to_market() # set up mark-to-market first, so accounts can value their initial balances - await initialize_accounts() + accounting_lock() + await _initialize_mark_to_market() # set up mark-to-market first, so accounts can value their initial balances + await _initialize_accounts() accounting_initialized = True log.info(f'accounting initialized\n\tstablecoins: {config.stablecoins}\n\tquotecoins: {config.quotecoins}\n\tnativecoin: {config.nativecoin}') -async def initialize_accounts(): +async def _initialize_accounts(): # Since this is called by top-level main functions outside the Runner, we trigger an explicit db commit/rollback try: # noinspection PyStatementEffect - await initialize_accounts_2() + await _initialize_accounts_2() db.session.commit() except: db.session.rollback() raise -async def initialize_accounts_2(): +async def _initialize_accounts_2(): fm = await FeeManager.get() - of_account = ensure_account(fm.order_fee_account_addr, AccountKind.OrderFee) - gf_account = ensure_account(fm.gas_fee_account_addr, AccountKind.GasFee) - ff_account = ensure_account(fm.fill_fee_account_addr, AccountKind.FillFee) - exe_accounts = [ensure_account(account.address, AccountKind.Execution) for account in Account.all()] + of_account = _ensure_account(fm.order_fee_account_addr, AccountKind.OrderFee) + gf_account = _ensure_account(fm.gas_fee_account_addr, AccountKind.GasFee) + ff_account = _ensure_account(fm.fill_fee_account_addr, AccountKind.FillFee) + exe_accounts = [_ensure_account(account.address, AccountKind.Execution) for account in Account.all()] if current_chain.get().id in [1337, 31337]: log.debug('adjusting debug account balances') await asyncio.gather( @@ -68,7 +76,7 @@ async def initialize_accounts_2(): _tracked_addrs.add(db_account.address) -async def initialize_mark_to_market(): +async def _initialize_mark_to_market(): quotes.clear() quotes.extend(config.stablecoins) quotes.extend(config.quotecoins) @@ -113,22 +121,7 @@ async def initialize_mark_to_market(): add_mark_pool(addr, pool['base'], pool['quote'], pool['fee']) -async def handle_feeaccountschanged(fee_accounts: EventData): - try: - order_fee_account_addr = fee_accounts['args']['orderFeeAccount'] - gas_fee_account_addr = fee_accounts['args']['gasFeeAccount'] - fill_fee_account_addr = fee_accounts['args']['fillFeeAccount'] - except KeyError: - log.warning(f'Could not parse FeeAccountsChanged {fee_accounts}') - return - fm = await FeeManager.get() - fm.order_fee_account_addr = order_fee_account_addr - fm.gas_fee_account_addr = gas_fee_account_addr - fm.fill_fee_account_addr = fill_fee_account_addr - await initialize_accounts_2() - - -def ensure_account(addr: str, kind: AccountKind) -> DbAccount: +def _ensure_account(addr: str, kind: AccountKind) -> DbAccount: chain = current_chain.get() found = db.session.get(DbAccount, (chain, addr)) if found: @@ -144,6 +137,21 @@ def ensure_account(addr: str, kind: AccountKind) -> DbAccount: return found +async def handle_feeaccountschanged(fee_accounts: EventData): + try: + order_fee_account_addr = fee_accounts['args']['orderFeeAccount'] + gas_fee_account_addr = fee_accounts['args']['gasFeeAccount'] + fill_fee_account_addr = fee_accounts['args']['fillFeeAccount'] + except KeyError: + log.warning(f'Could not parse FeeAccountsChanged {fee_accounts}') + return + fm = await FeeManager.get() + fm.order_fee_account_addr = order_fee_account_addr + fm.gas_fee_account_addr = gas_fee_account_addr + fm.fill_fee_account_addr = fill_fee_account_addr + await _initialize_accounts_2() + + async def accounting_transfer(receipt: TransactionReceiptDict, token: str, sender: str, receiver: str, amount: Union[dec,int], adjust_decimals=True): block_hash = hexstr(receipt['blockHash']) @@ -224,10 +232,7 @@ async def adjust_balance(account: DbAccount, token=NATIVE_TOKEN, subcategory=Acc await add_accounting_row(account.address, None, None, AccountingCategory.Special, subcategory, NATIVE_TOKEN, amount, note, adjust_decimals=False) -async def reconcile(account: DbAccount, block_id: Optional[str] = None, last_accounting_row_id: Optional[int] = None): - # First we lock all the relevant tables to ensure consistency - db.session.execute(text("LOCK TABLE account, accounting, reconciliation IN EXCLUSIVE MODE")) - +async def accounting_reconcile(account: DbAccount, block_id: Optional[str] = None, last_accounting_row_id: Optional[int] = None): # Fetch the latest reconciliation for the account latest_recon = db.session.execute( select(Reconciliation).where( diff --git a/src/dexorder/bin/reconcile.py b/src/dexorder/bin/reconcile.py index e12b478..ec98ea9 100644 --- a/src/dexorder/bin/reconcile.py +++ b/src/dexorder/bin/reconcile.py @@ -3,7 +3,7 @@ import logging from sqlalchemy import select from dexorder import db, blockchain -from dexorder.accounting import reconcile +from dexorder.accounting import accounting_reconcile, accounting_lock from dexorder.bin.executable import execute from dexorder.blocks import fetch_latest_block, current_block from dexorder.database.model import DbAccount @@ -15,10 +15,11 @@ async def main(): db.connect() block = await fetch_latest_block() current_block.set(block) + accounting_lock() try: accounts = db.session.execute(select(DbAccount)).scalars().all() for account in accounts: - await reconcile(account) + await accounting_reconcile(account) db.session.commit() log.info('Reconciliation complete') except: diff --git a/src/dexorder/event_handler.py b/src/dexorder/event_handler.py index 229275c..d1a9e3a 100644 --- a/src/dexorder/event_handler.py +++ b/src/dexorder/event_handler.py @@ -4,7 +4,8 @@ import logging from web3.types import EventData from dexorder import db, metric -from dexorder.accounting import accounting_fill, accounting_placement, accounting_transfer, is_tracked_address +from dexorder.accounting import accounting_fill, accounting_placement, accounting_transfer, is_tracked_address, \ + accounting_lock from dexorder.base.chain import current_chain from dexorder.base.order import TrancheKey, OrderKey from dexorder.base.orderlib import SwapOrderState @@ -31,6 +32,7 @@ def dump_log(eventlog): def init(): new_pool_prices.clear() start_trigger_updates() + accounting_lock() async def handle_order_placed(event: EventData):