refactor: Make db mutex non-recursive.

master^2
tecnovert 7 months ago
parent ae1df0b556
commit 5f6819afcb
  1. 2
      basicswap/base.py
  2. 601
      basicswap/basicswap.py
  3. 4
      basicswap/db_upgrades.py
  4. 2
      basicswap/protocols/atomic_swap_1.py
  5. 9
      basicswap/protocols/xmr_swap_1.py
  6. 5
      tests/basicswap/extended/test_dcr.py
  7. 10
      tests/basicswap/test_xmr.py

@ -46,7 +46,7 @@ class BaseApp:
self.settings = settings self.settings = settings
self.coin_clients = {} self.coin_clients = {}
self.coin_interfaces = {} self.coin_interfaces = {}
self.mxDB = threading.RLock() self.mxDB = threading.Lock()
self.debug = self.settings.get('debug', False) self.debug = self.settings.get('debug', False)
self.delay_event = threading.Event() self.delay_event = threading.Event()
self.chainstate_delay_event = threading.Event() self.chainstate_delay_event = threading.Event()

File diff suppressed because it is too large Load Diff

@ -93,7 +93,7 @@ def upgradeDatabaseData(self, data_version):
created_at=now)) created_at=now))
self.db_data_version = CURRENT_DB_DATA_VERSION self.db_data_version = CURRENT_DB_DATA_VERSION
self.setIntKVInSession('db_data_version', self.db_data_version, session) self.setIntKV('db_data_version', self.db_data_version, session)
session.commit() session.commit()
self.log.info('Upgraded database records to version {}'.format(self.db_data_version)) self.log.info('Upgraded database records to version {}'.format(self.db_data_version))
finally: finally:
@ -314,7 +314,7 @@ def upgradeDatabase(self, db_version):
session.execute('ALTER TABLE bids ADD COLUMN pkhash_buyer_to BLOB') session.execute('ALTER TABLE bids ADD COLUMN pkhash_buyer_to BLOB')
if current_version != db_version: if current_version != db_version:
self.db_version = db_version self.db_version = db_version
self.setIntKVInSession('db_version', db_version, session) self.setIntKV('db_version', db_version, session)
session.commit() session.commit()
session.close() session.close()
session.remove() session.remove()

@ -105,7 +105,7 @@ def redeemITx(self, bid_id: bytes, session):
bid, offer = self.getBidAndOffer(bid_id, session) bid, offer = self.getBidAndOffer(bid_id, session)
ci_from = self.ci(offer.coin_from) ci_from = self.ci(offer.coin_from)
txn = self.createRedeemTxn(ci_from.coin_type(), bid, for_txn_type='initiate') txn = self.createRedeemTxn(ci_from.coin_type(), bid, for_txn_type='initiate', session=session)
txid = ci_from.publishTx(bytes.fromhex(txn)) txid = ci_from.publishTx(bytes.fromhex(txn))
bid.initiate_tx.spend_txid = bytes.fromhex(txid) bid.initiate_tx.spend_txid = bytes.fromhex(txid)

@ -1,11 +1,9 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright (c) 2020-2023 tecnovert # Copyright (c) 2020-2024 tecnovert
# Distributed under the MIT software license, see the accompanying # Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php. # file LICENSE or http://www.opensource.org/licenses/mit-license.php.
from sqlalchemy.orm import scoped_session
from basicswap.util import ( from basicswap.util import (
ensure, ensure,
) )
@ -45,7 +43,7 @@ def addLockRefundSigs(self, xmr_swap, ci):
def recoverNoScriptTxnWithKey(self, bid_id: bytes, encoded_key): def recoverNoScriptTxnWithKey(self, bid_id: bytes, encoded_key):
self.log.info('Manually recovering %s', bid_id.hex()) self.log.info('Manually recovering %s', bid_id.hex())
# Manually recover txn if other key is known # Manually recover txn if other key is known
session = scoped_session(self.session_factory) session = self.openSession()
try: try:
bid, xmr_swap = self.getXmrBidFromSession(session, bid_id) bid, xmr_swap = self.getXmrBidFromSession(session, bid_id)
ensure(bid, 'Bid not found: {}.'.format(bid_id.hex())) ensure(bid, 'Bid not found: {}.'.format(bid_id.hex()))
@ -86,8 +84,7 @@ def recoverNoScriptTxnWithKey(self, bid_id: bytes, encoded_key):
return txid return txid
finally: finally:
session.close() self.closeSession(session, commit=False)
session.remove()
def getChainBSplitKey(swap_client, bid, xmr_swap, offer): def getChainBSplitKey(swap_client, bid, xmr_swap, offer):

@ -345,9 +345,8 @@ def run_test_ads_both_refund(self, coin_from: Coins, coin_to: Coins, lock_value:
ci_from = swap_clients[id_offerer].ci(coin_from) ci_from = swap_clients[id_offerer].ci(coin_from)
ci_to = swap_clients[id_offerer].ci(coin_to) ci_to = swap_clients[id_offerer].ci(coin_to)
if reverse_bid: self.prepare_balance(coin_to, 100.0, 1801, 1800)
self.prepare_balance(coin_to, 100.0, 1801, 1800) self.prepare_balance(coin_from, 100.0, 1800, 1801)
self.prepare_balance(coin_from, 100.0, 1800, 1801)
id_leader: int = id_bidder if reverse_bid else id_offerer id_leader: int = id_bidder if reverse_bid else id_offerer
id_follower: int = id_offerer if reverse_bid else id_bidder id_follower: int = id_offerer if reverse_bid else id_bidder

@ -1388,9 +1388,13 @@ class Test(BaseTest):
js_0 = read_json_api(1800, 'wallets/part') js_0 = read_json_api(1800, 'wallets/part')
node0_blind_before = js_0['blind_balance'] + js_0['blind_unconfirmed'] node0_blind_before = js_0['blind_balance'] + js_0['blind_unconfirmed']
amt_swap = make_int(random.uniform(0.1, 2.0), scale=8, r=1) coin_from = Coins.PART_BLIND
rate_swap = make_int(random.uniform(2.0, 20.0), scale=8, r=1) coin_to = Coins.XMR
offer_id = swap_clients[0].postOffer(Coins.PART_BLIND, Coins.XMR, amt_swap, rate_swap, amt_swap, SwapTypes.XMR_SWAP) ci_from = swap_clients[0].ci(coin_from)
ci_to = swap_clients[0].ci(coin_to)
amt_swap = ci_from.make_int(random.uniform(0.1, 2.0), r=1)
rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1)
offer_id = swap_clients[0].postOffer(coin_from, coin_to, amt_swap, rate_swap, amt_swap, SwapTypes.XMR_SWAP)
wait_for_offer(test_delay_event, swap_clients[1], offer_id) wait_for_offer(test_delay_event, swap_clients[1], offer_id)
offers = swap_clients[0].listOffers(filters={'offer_id': offer_id}) offers = swap_clients[0].listOffers(filters={'offer_id': offer_id})
offer = offers[0] offer = offers[0]

Loading…
Cancel
Save