Load in-progress bids only when unlocked.
This commit is contained in:
		
							parent
							
								
									3234e3fba3
								
							
						
					
					
						commit
						2922b171a6
					
				@ -1,6 +1,6 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2019-2022 tecnovert
 | 
			
		||||
# Copyright (c) 2019-2023 tecnovert
 | 
			
		||||
# Distributed under the MIT software license, see the accompanying
 | 
			
		||||
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
 | 
			
		||||
 | 
			
		||||
@ -92,7 +92,7 @@ class BaseApp:
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return {}
 | 
			
		||||
 | 
			
		||||
    def setDaemonPID(self, name, pid):
 | 
			
		||||
    def setDaemonPID(self, name, pid) -> None:
 | 
			
		||||
        if isinstance(name, Coins):
 | 
			
		||||
            self.coin_clients[name]['pid'] = pid
 | 
			
		||||
            return
 | 
			
		||||
@ -100,12 +100,12 @@ class BaseApp:
 | 
			
		||||
            if v['name'] == name:
 | 
			
		||||
                v['pid'] = pid
 | 
			
		||||
 | 
			
		||||
    def getChainDatadirPath(self, coin):
 | 
			
		||||
    def getChainDatadirPath(self, coin) -> str:
 | 
			
		||||
        datadir = self.coin_clients[coin]['datadir']
 | 
			
		||||
        testnet_name = '' if self.chain == 'mainnet' else chainparams[coin][self.chain].get('name', self.chain)
 | 
			
		||||
        return os.path.join(datadir, testnet_name)
 | 
			
		||||
 | 
			
		||||
    def getCoinIdFromName(self, coin_name):
 | 
			
		||||
    def getCoinIdFromName(self, coin_name: str):
 | 
			
		||||
        for c, params in chainparams.items():
 | 
			
		||||
            if coin_name.lower() == params['name'].lower():
 | 
			
		||||
                return c
 | 
			
		||||
@ -146,7 +146,7 @@ class BaseApp:
 | 
			
		||||
            raise ValueError('CLI error ' + str(out[1]))
 | 
			
		||||
        return out[0].decode('utf-8').strip()
 | 
			
		||||
 | 
			
		||||
    def is_transient_error(self, ex):
 | 
			
		||||
    def is_transient_error(self, ex) -> bool:
 | 
			
		||||
        if isinstance(ex, TemporaryError):
 | 
			
		||||
            return True
 | 
			
		||||
        str_error = str(ex).lower()
 | 
			
		||||
@ -164,13 +164,13 @@ class BaseApp:
 | 
			
		||||
 | 
			
		||||
        socket.setdefaulttimeout(timeout)
 | 
			
		||||
 | 
			
		||||
    def popConnectionParameters(self):
 | 
			
		||||
    def popConnectionParameters(self) -> None:
 | 
			
		||||
        if self.use_tor_proxy:
 | 
			
		||||
            socket.socket = self.default_socket
 | 
			
		||||
            socket.getaddrinfo = self.default_socket_getaddrinfo
 | 
			
		||||
        socket.setdefaulttimeout(self.default_socket_timeout)
 | 
			
		||||
 | 
			
		||||
    def logException(self, message):
 | 
			
		||||
    def logException(self, message) -> None:
 | 
			
		||||
        self.log.error(message)
 | 
			
		||||
        if self.debug:
 | 
			
		||||
            self.log.error(traceback.format_exc())
 | 
			
		||||
 | 
			
		||||
@ -216,6 +216,7 @@ class WatchedTransaction():
 | 
			
		||||
 | 
			
		||||
class BasicSwap(BaseApp):
 | 
			
		||||
    ws_server = None
 | 
			
		||||
    _read_zmq_queue: bool = True
 | 
			
		||||
    protocolInterfaces = {
 | 
			
		||||
        SwapTypes.SELLER_FIRST: atomic_swap_1.AtomicSwapInterface(),
 | 
			
		||||
        SwapTypes.XMR_SWAP: xmr_swap_1.XmrSwapInterface(),
 | 
			
		||||
@ -696,7 +697,41 @@ class BasicSwap(BaseApp):
 | 
			
		||||
            self._network = bsn.Network(self.settings['p2p_host'], self.settings['p2p_port'], network_key, self)
 | 
			
		||||
            self._network.startNetwork()
 | 
			
		||||
 | 
			
		||||
        self.initialise()
 | 
			
		||||
        self.log.debug('network_key %s\nnetwork_pubkey %s\nnetwork_addr %s',
 | 
			
		||||
                       self.network_key, self.network_pubkey, self.network_addr)
 | 
			
		||||
 | 
			
		||||
        ro = self.callrpc('smsglocalkeys')
 | 
			
		||||
        found = False
 | 
			
		||||
        for k in ro['smsg_keys']:
 | 
			
		||||
            if k['address'] == self.network_addr:
 | 
			
		||||
                found = True
 | 
			
		||||
                break
 | 
			
		||||
        if not found:
 | 
			
		||||
            self.log.info('Importing network key to SMSG')
 | 
			
		||||
            self.callrpc('smsgimportprivkey', [self.network_key, 'basicswap offers'])
 | 
			
		||||
            ro = self.callrpc('smsglocalkeys', ['anon', '-', self.network_addr])
 | 
			
		||||
            ensure(ro['result'] == 'Success.', 'smsglocalkeys failed')
 | 
			
		||||
 | 
			
		||||
        # TODO: Ensure smsg is enabled for the active wallet.
 | 
			
		||||
 | 
			
		||||
        # Initialise locked state
 | 
			
		||||
        _, _ = self.getLockedState()
 | 
			
		||||
 | 
			
		||||
        # Re-load in-progress bids
 | 
			
		||||
        self.loadFromDB()
 | 
			
		||||
 | 
			
		||||
        # Scan inbox
 | 
			
		||||
        # TODO: Redundant? small window for zmq messages to go unnoticed during startup?
 | 
			
		||||
        # options = {'encoding': 'hex'}
 | 
			
		||||
        options = {'encoding': 'none'}
 | 
			
		||||
        ro = self.callrpc('smsginbox', ['unread', '', options])
 | 
			
		||||
        nm = 0
 | 
			
		||||
        for msg in ro['messages']:
 | 
			
		||||
            # TODO: Remove workaround for smsginbox bug
 | 
			
		||||
            get_msg = self.callrpc('smsg', [msg['msgid'], {'encoding': 'hex', 'setread': True}])
 | 
			
		||||
            self.processMsg(get_msg)
 | 
			
		||||
            nm += 1
 | 
			
		||||
        self.log.info('Scanned %d unread messages.', nm)
 | 
			
		||||
 | 
			
		||||
    def stopDaemon(self, coin):
 | 
			
		||||
        if coin == Coins.XMR:
 | 
			
		||||
@ -757,6 +792,11 @@ class BasicSwap(BaseApp):
 | 
			
		||||
            if synced < 1.0:
 | 
			
		||||
                raise ValueError('{} chain is still syncing, currently at {}.'.format(self.coin_clients[c]['name'], synced))
 | 
			
		||||
 | 
			
		||||
    def isSystemUnlocked(self):
 | 
			
		||||
        # TODO - Check all active coins
 | 
			
		||||
        ci = self.ci(Coins.PART)
 | 
			
		||||
        return not ci.isWalletLocked()
 | 
			
		||||
 | 
			
		||||
    def checkSystemStatus(self):
 | 
			
		||||
        ci = self.ci(Coins.PART)
 | 
			
		||||
        if ci.isWalletLocked():
 | 
			
		||||
@ -801,6 +841,7 @@ class BasicSwap(BaseApp):
 | 
			
		||||
            self._is_encrypted, self._is_locked = self.ci(Coins.PART).isWalletEncryptedLocked()
 | 
			
		||||
 | 
			
		||||
    def unlockWallets(self, password, coin=None):
 | 
			
		||||
        self._read_zmq_queue = False
 | 
			
		||||
        for c in self.activeCoins():
 | 
			
		||||
            if coin and c != coin:
 | 
			
		||||
                continue
 | 
			
		||||
@ -808,13 +849,20 @@ class BasicSwap(BaseApp):
 | 
			
		||||
            if c == Coins.PART:
 | 
			
		||||
                self._is_locked = False
 | 
			
		||||
 | 
			
		||||
        self.loadFromDB()
 | 
			
		||||
        self._read_zmq_queue = True
 | 
			
		||||
 | 
			
		||||
    def lockWallets(self, coin=None):
 | 
			
		||||
        self._read_zmq_queue = False
 | 
			
		||||
        self.swaps_in_progress.clear()
 | 
			
		||||
 | 
			
		||||
        for c in self.activeCoins():
 | 
			
		||||
            if coin and c != coin:
 | 
			
		||||
                continue
 | 
			
		||||
            self.ci(c).lockWallet()
 | 
			
		||||
            if c == Coins.PART:
 | 
			
		||||
                self._is_locked = True
 | 
			
		||||
        self._read_zmq_queue = True
 | 
			
		||||
 | 
			
		||||
    def initialiseWallet(self, coin_type, raise_errors=False):
 | 
			
		||||
        if coin_type == Coins.PART:
 | 
			
		||||
@ -929,7 +977,7 @@ class BasicSwap(BaseApp):
 | 
			
		||||
        with self.mxDB:
 | 
			
		||||
            try:
 | 
			
		||||
                session = scoped_session(self.session_factory)
 | 
			
		||||
                session.execute('DELETE FROM kv_string WHERE key = "{}" '.format(str_key))
 | 
			
		||||
                session.execute('DELETE FROM kv_string WHERE key = :key', {'key': str_key})
 | 
			
		||||
                session.commit()
 | 
			
		||||
            finally:
 | 
			
		||||
                session.close()
 | 
			
		||||
@ -1037,7 +1085,10 @@ class BasicSwap(BaseApp):
 | 
			
		||||
            if session is None:
 | 
			
		||||
                self.closeSession(use_session)
 | 
			
		||||
 | 
			
		||||
    def loadFromDB(self):
 | 
			
		||||
    def loadFromDB(self) -> None:
 | 
			
		||||
        if self.isSystemUnlocked() is False:
 | 
			
		||||
            self.log.info('Not loading from db.  System is locked.')
 | 
			
		||||
            return
 | 
			
		||||
        self.log.info('Loading data from db')
 | 
			
		||||
        self.mxDB.acquire()
 | 
			
		||||
        self.swaps_in_progress.clear()
 | 
			
		||||
@ -1061,39 +1112,6 @@ class BasicSwap(BaseApp):
 | 
			
		||||
            session.remove()
 | 
			
		||||
            self.mxDB.release()
 | 
			
		||||
 | 
			
		||||
    def initialise(self):
 | 
			
		||||
        self.log.debug('network_key %s\nnetwork_pubkey %s\nnetwork_addr %s',
 | 
			
		||||
                       self.network_key, self.network_pubkey, self.network_addr)
 | 
			
		||||
 | 
			
		||||
        ro = self.callrpc('smsglocalkeys')
 | 
			
		||||
        found = False
 | 
			
		||||
        for k in ro['smsg_keys']:
 | 
			
		||||
            if k['address'] == self.network_addr:
 | 
			
		||||
                found = True
 | 
			
		||||
                break
 | 
			
		||||
        if not found:
 | 
			
		||||
            self.log.info('Importing network key to SMSG')
 | 
			
		||||
            self.callrpc('smsgimportprivkey', [self.network_key, 'basicswap offers'])
 | 
			
		||||
            ro = self.callrpc('smsglocalkeys', ['anon', '-', self.network_addr])
 | 
			
		||||
            ensure(ro['result'] == 'Success.', 'smsglocalkeys failed')
 | 
			
		||||
 | 
			
		||||
        # TODO: Ensure smsg is enabled for the active wallet.
 | 
			
		||||
 | 
			
		||||
        self.loadFromDB()
 | 
			
		||||
 | 
			
		||||
        # Scan inbox
 | 
			
		||||
        # TODO: Redundant? small window for zmq messages to go unnoticed during startup?
 | 
			
		||||
        # options = {'encoding': 'hex'}
 | 
			
		||||
        options = {'encoding': 'none'}
 | 
			
		||||
        ro = self.callrpc('smsginbox', ['unread', '', options])
 | 
			
		||||
        nm = 0
 | 
			
		||||
        for msg in ro['messages']:
 | 
			
		||||
            # TODO: Remove workaround for smsginbox bug
 | 
			
		||||
            get_msg = self.callrpc('smsg', [msg['msgid'], {'encoding': 'hex', 'setread': True}])
 | 
			
		||||
            self.processMsg(get_msg)
 | 
			
		||||
            nm += 1
 | 
			
		||||
        self.log.info('Scanned %d unread messages.', nm)
 | 
			
		||||
 | 
			
		||||
    def getActiveBidMsgValidTime(self):
 | 
			
		||||
        return self.SMSG_SECONDS_IN_HOUR * 48
 | 
			
		||||
 | 
			
		||||
@ -1882,7 +1900,7 @@ class BasicSwap(BaseApp):
 | 
			
		||||
        try:
 | 
			
		||||
            self._contract_count += 1
 | 
			
		||||
            session = scoped_session(self.session_factory)
 | 
			
		||||
            session.execute('UPDATE kv_int SET value = {} WHERE KEY="contract_count"'.format(self._contract_count))
 | 
			
		||||
            session.execute('UPDATE kv_int SET value = :value WHERE KEY="contract_count"', {'value': self._contract_count})
 | 
			
		||||
            session.commit()
 | 
			
		||||
        finally:
 | 
			
		||||
            session.close()
 | 
			
		||||
@ -3870,7 +3888,11 @@ class BasicSwap(BaseApp):
 | 
			
		||||
                c['last_height_checked'] = last_height_checked
 | 
			
		||||
                self.setIntKV('last_height_checked_' + chainparams[coin_type]['name'], last_height_checked)
 | 
			
		||||
 | 
			
		||||
    def expireMessages(self):
 | 
			
		||||
    def expireMessages(self) -> None:
 | 
			
		||||
        if self._is_locked is True:
 | 
			
		||||
            self.log.debug('Not expiring messages while system locked')
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        self.mxDB.acquire()
 | 
			
		||||
        rpc_conn = None
 | 
			
		||||
        try:
 | 
			
		||||
@ -3947,9 +3969,9 @@ class BasicSwap(BaseApp):
 | 
			
		||||
                    self.logException(f'checkQueuedActions failed: {ex}')
 | 
			
		||||
 | 
			
		||||
            if self.debug:
 | 
			
		||||
                session.execute('UPDATE actions SET active_ind = 2 WHERE trigger_at <= {}'.format(now))
 | 
			
		||||
                session.execute('UPDATE actions SET active_ind = 2 WHERE trigger_at <= :now', {'now': now})
 | 
			
		||||
            else:
 | 
			
		||||
                session.execute('DELETE FROM actions WHERE trigger_at <= {}'.format(now))
 | 
			
		||||
                session.execute('DELETE FROM actions WHERE trigger_at <= :now', {'now': now})
 | 
			
		||||
 | 
			
		||||
            session.commit()
 | 
			
		||||
        except Exception as ex:
 | 
			
		||||
@ -5014,7 +5036,7 @@ class BasicSwap(BaseApp):
 | 
			
		||||
 | 
			
		||||
            if coin_to == Coins.XMR:
 | 
			
		||||
                address_to = self.getCachedMainWalletAddress(ci_to)
 | 
			
		||||
            elif coin_to == Coins.PART_BLIND:
 | 
			
		||||
            elif coin_to in (Coins.PART_BLIND, Coins.PART_ANON):
 | 
			
		||||
                address_to = self.getCachedStealthAddressForCoin(coin_to)
 | 
			
		||||
            else:
 | 
			
		||||
                address_to = self.getReceiveAddressFromPool(coin_to, bid_id, TxTypes.XMR_SWAP_B_LOCK_SPEND)
 | 
			
		||||
@ -5323,6 +5345,9 @@ class BasicSwap(BaseApp):
 | 
			
		||||
            rv = None
 | 
			
		||||
            if msg_type == MessageTypes.OFFER:
 | 
			
		||||
                self.processOffer(msg)
 | 
			
		||||
            elif msg_type == MessageTypes.OFFER_REVOKE:
 | 
			
		||||
                self.processOfferRevoke(msg)
 | 
			
		||||
            # TODO: When changing from wallet keys (encrypted/locked) handle swap messages while locked
 | 
			
		||||
            elif msg_type == MessageTypes.BID:
 | 
			
		||||
                self.processBid(msg)
 | 
			
		||||
            elif msg_type == MessageTypes.BID_ACCEPT:
 | 
			
		||||
@ -5339,8 +5364,6 @@ class BasicSwap(BaseApp):
 | 
			
		||||
                self.processXmrSplitMessage(msg)
 | 
			
		||||
            elif msg_type == MessageTypes.XMR_BID_LOCK_RELEASE_LF:
 | 
			
		||||
                self.processXmrLockReleaseMessage(msg)
 | 
			
		||||
            if msg_type == MessageTypes.OFFER_REVOKE:
 | 
			
		||||
                self.processOfferRevoke(msg)
 | 
			
		||||
 | 
			
		||||
        except InactiveCoin as ex:
 | 
			
		||||
            self.log.info('Ignoring message involving inactive coin {}, type {}'.format(Coins(ex.coinid).name, MessageTypes(msg_type).name))
 | 
			
		||||
@ -5381,10 +5404,10 @@ class BasicSwap(BaseApp):
 | 
			
		||||
 | 
			
		||||
    def update(self):
 | 
			
		||||
        try:
 | 
			
		||||
            # while True:
 | 
			
		||||
            message = self.zmqSubscriber.recv(flags=zmq.NOBLOCK)
 | 
			
		||||
            if message == b'smsg':
 | 
			
		||||
                self.processZmqSmsg()
 | 
			
		||||
            if self._read_zmq_queue:
 | 
			
		||||
                message = self.zmqSubscriber.recv(flags=zmq.NOBLOCK)
 | 
			
		||||
                if message == b'smsg':
 | 
			
		||||
                    self.processZmqSmsg()
 | 
			
		||||
        except zmq.Again as ex:
 | 
			
		||||
            pass
 | 
			
		||||
        except Exception as ex:
 | 
			
		||||
@ -6178,6 +6201,7 @@ class BasicSwap(BaseApp):
 | 
			
		||||
 | 
			
		||||
            addr_info = self.callrpc('getaddressinfo', [new_addr])
 | 
			
		||||
            self.callrpc('smsgaddlocaladdress', [new_addr])  # Enable receiving smsgs
 | 
			
		||||
            self.callrpc('smsglocalkeys', ['anon', '-', new_addr])
 | 
			
		||||
 | 
			
		||||
            use_session.add(SmsgAddress(addr=new_addr, use_type=use_type, active_ind=1, created_at=now, note=addressnote, pubkey=addr_info['pubkey']))
 | 
			
		||||
            return new_addr, addr_info['pubkey']
 | 
			
		||||
@ -6193,6 +6217,7 @@ class BasicSwap(BaseApp):
 | 
			
		||||
            ci = self.ci(Coins.PART)
 | 
			
		||||
            add_addr = ci.pubkey_to_address(bytes.fromhex(pubkey_hex))
 | 
			
		||||
            self.callrpc('smsgaddaddress', [add_addr, pubkey_hex])
 | 
			
		||||
            self.callrpc('smsglocalkeys', ['anon', '-', add_addr])
 | 
			
		||||
 | 
			
		||||
            session.add(SmsgAddress(addr=add_addr, use_type=AddressTypes.SEND_OFFER, active_ind=1, created_at=now, note=addressnote, pubkey=pubkey_hex))
 | 
			
		||||
            session.commit()
 | 
			
		||||
@ -6209,7 +6234,7 @@ class BasicSwap(BaseApp):
 | 
			
		||||
            mode = '-' if active_ind == 0 else '+'
 | 
			
		||||
            self.callrpc('smsglocalkeys', ['recv', mode, address])
 | 
			
		||||
 | 
			
		||||
            session.execute('UPDATE smsgaddresses SET active_ind = {}, note = "{}" WHERE addr = "{}"'.format(active_ind, addressnote, address))
 | 
			
		||||
            session.execute('UPDATE smsgaddresses SET active_ind = :active_ind, note = :note WHERE addr = :addr', {'active_ind': active_ind, 'note': addressnote, 'addr': address})
 | 
			
		||||
            session.commit()
 | 
			
		||||
        finally:
 | 
			
		||||
            session.close()
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2019-2022 tecnovert
 | 
			
		||||
# Copyright (c) 2019-2023 tecnovert
 | 
			
		||||
# Distributed under the MIT software license, see the accompanying
 | 
			
		||||
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
 | 
			
		||||
 | 
			
		||||
@ -374,29 +374,29 @@ class CoinInterface:
 | 
			
		||||
            ticker = 'rt' + ticker
 | 
			
		||||
        return ticker
 | 
			
		||||
 | 
			
		||||
    def getExchangeTicker(self, exchange_name):
 | 
			
		||||
    def getExchangeTicker(self, exchange_name: str) -> str:
 | 
			
		||||
        return chainparams[self.coin_type()]['ticker']
 | 
			
		||||
 | 
			
		||||
    def getExchangeName(self, exchange_name):
 | 
			
		||||
    def getExchangeName(self, exchange_name: str) -> str:
 | 
			
		||||
        return chainparams[self.coin_type()]['name']
 | 
			
		||||
 | 
			
		||||
    def ticker_mainnet(self):
 | 
			
		||||
    def ticker_mainnet(self) -> str:
 | 
			
		||||
        ticker = chainparams[self.coin_type()]['ticker']
 | 
			
		||||
        return ticker
 | 
			
		||||
 | 
			
		||||
    def min_amount(self):
 | 
			
		||||
    def min_amount(self) -> int:
 | 
			
		||||
        return chainparams[self.coin_type()][self._network]['min_amount']
 | 
			
		||||
 | 
			
		||||
    def max_amount(self):
 | 
			
		||||
    def max_amount(self) -> int:
 | 
			
		||||
        return chainparams[self.coin_type()][self._network]['max_amount']
 | 
			
		||||
 | 
			
		||||
    def setWalletSeedWarning(self, value):
 | 
			
		||||
    def setWalletSeedWarning(self, value: bool) -> None:
 | 
			
		||||
        self._unknown_wallet_seed = value
 | 
			
		||||
 | 
			
		||||
    def setWalletRestoreHeight(self, value):
 | 
			
		||||
    def setWalletRestoreHeight(self, value: int) -> None:
 | 
			
		||||
        self._restore_height = value
 | 
			
		||||
 | 
			
		||||
    def knownWalletSeed(self):
 | 
			
		||||
    def knownWalletSeed(self) -> bool:
 | 
			
		||||
        return not self._unknown_wallet_seed
 | 
			
		||||
 | 
			
		||||
    def chainparams(self):
 | 
			
		||||
@ -408,13 +408,13 @@ class CoinInterface:
 | 
			
		||||
    def has_segwit(self) -> bool:
 | 
			
		||||
        return chainparams[self.coin_type()].get('has_segwit', True)
 | 
			
		||||
 | 
			
		||||
    def is_transient_error(self, ex):
 | 
			
		||||
    def is_transient_error(self, ex) -> bool:
 | 
			
		||||
        if isinstance(ex, TemporaryError):
 | 
			
		||||
            return True
 | 
			
		||||
        str_error = str(ex).lower()
 | 
			
		||||
        str_error: str = str(ex).lower()
 | 
			
		||||
        if 'not enough unlocked money' in str_error:
 | 
			
		||||
            return True
 | 
			
		||||
        if 'No unlocked balance' in str_error:
 | 
			
		||||
        if 'no unlocked balance' in str_error:
 | 
			
		||||
            return True
 | 
			
		||||
        if 'transaction was rejected by daemon' in str_error:
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
@ -423,6 +423,7 @@ class KnownIdentity(Base):
 | 
			
		||||
    num_recv_bids_failed = sa.Column(sa.Integer)
 | 
			
		||||
    automation_override = sa.Column(sa.Integer)  # AutomationOverrideOptions
 | 
			
		||||
    visibility_override = sa.Column(sa.Integer)  # VisibilityOverrideOptions
 | 
			
		||||
    data = sa.Column(sa.LargeBinary)
 | 
			
		||||
    note = sa.Column(sa.String)
 | 
			
		||||
    updated_at = sa.Column(sa.BigInteger)
 | 
			
		||||
    created_at = sa.Column(sa.BigInteger)
 | 
			
		||||
 | 
			
		||||
@ -238,10 +238,11 @@ def upgradeDatabase(self, db_version):
 | 
			
		||||
                    tx_data BLOB,
 | 
			
		||||
                    used_by BLOB,
 | 
			
		||||
                    PRIMARY KEY (record_id))''')
 | 
			
		||||
        elif current_version == 16:
 | 
			
		||||
        elif current_version == 17:
 | 
			
		||||
            db_version += 1
 | 
			
		||||
            session.execute('ALTER TABLE knownidentities ADD COLUMN automation_override INTEGER')
 | 
			
		||||
            session.execute('ALTER TABLE knownidentities ADD COLUMN visibility_override INTEGER')
 | 
			
		||||
            session.execute('ALTER TABLE knownidentities ADD COLUMN data BLOB')
 | 
			
		||||
            session.execute('UPDATE knownidentities SET active_ind = 1')
 | 
			
		||||
 | 
			
		||||
        if current_version != db_version:
 | 
			
		||||
 | 
			
		||||
@ -684,7 +684,7 @@ class PARTInterfaceBlind(PARTInterface):
 | 
			
		||||
                return -1
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def spendBLockTx(self, chain_b_lock_txid, address_to, kbv, kbs, cb_swap_value, b_fee, restore_height, spend_actual_balance=False):
 | 
			
		||||
    def spendBLockTx(self, chain_b_lock_txid: bytes, address_to: str, kbv: bytes, kbs: bytes, cb_swap_value: int, b_fee: int, restore_height: int, spend_actual_balance: bool = False) -> bytes:
 | 
			
		||||
        Kbv = self.getPubkey(kbv)
 | 
			
		||||
        Kbs = self.getPubkey(kbs)
 | 
			
		||||
        sx_addr = self.formatStealthAddress(Kbv, Kbs)
 | 
			
		||||
@ -813,7 +813,7 @@ class PARTInterfaceAnon(PARTInterface):
 | 
			
		||||
                return -1
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def spendBLockTx(self, chain_b_lock_txid, address_to, kbv, kbs, cb_swap_value, b_fee, restore_height, spend_actual_balance=False):
 | 
			
		||||
    def spendBLockTx(self, chain_b_lock_txid: bytes, address_to: str, kbv: bytes, kbs: bytes, cb_swap_value: int, b_fee: int, restore_height: int, spend_actual_balance: bool = False) -> bytes:
 | 
			
		||||
        Kbv = self.getPubkey(kbv)
 | 
			
		||||
        Kbs = self.getPubkey(kbs)
 | 
			
		||||
        sx_addr = self.formatStealthAddress(Kbv, Kbs)
 | 
			
		||||
 | 
			
		||||
@ -417,7 +417,7 @@ class XMRInterface(CoinInterface):
 | 
			
		||||
 | 
			
		||||
            return bytes.fromhex(rv['tx_hash_list'][0])
 | 
			
		||||
 | 
			
		||||
    def withdrawCoin(self, value, addr_to, subfee):
 | 
			
		||||
    def withdrawCoin(self, value: int, addr_to: str, subfee: bool) -> str:
 | 
			
		||||
        with self._mx_wallet:
 | 
			
		||||
            value_sats = make_int(value, self.exp())
 | 
			
		||||
 | 
			
		||||
@ -427,7 +427,7 @@ class XMRInterface(CoinInterface):
 | 
			
		||||
            if subfee:
 | 
			
		||||
                balance = self.rpc_wallet_cb('get_balance')
 | 
			
		||||
                diff = balance['unlocked_balance'] - value_sats
 | 
			
		||||
                if diff > 0 and diff <= 10:
 | 
			
		||||
                if diff >= 0 and diff <= 10:
 | 
			
		||||
                    self._log.info('subfee enabled and value close to total, using sweep_all.')
 | 
			
		||||
                    params = {'address': addr_to}
 | 
			
		||||
                    if self._fee_priority > 0:
 | 
			
		||||
 | 
			
		||||
@ -344,7 +344,7 @@ def js_bids(self, url_split, post_string: str, is_json: bool) -> bytes:
 | 
			
		||||
        data = describeBid(swap_client, bid, xmr_swap, offer, xmr_offer, events, edit_bid, show_txns, for_api=True)
 | 
			
		||||
        return bytes(json.dumps(data), 'UTF-8')
 | 
			
		||||
 | 
			
		||||
    post_data = getFormData(post_string, is_json)
 | 
			
		||||
    post_data = {} if post_string == '' else getFormData(post_string, is_json)
 | 
			
		||||
    offer_id, filters = parseBidFilters(post_data)
 | 
			
		||||
 | 
			
		||||
    bids = swap_client.listBids(offer_id=offer_id, filters=filters)
 | 
			
		||||
 | 
			
		||||
@ -66,7 +66,7 @@ def dumpje(jin):
 | 
			
		||||
    return json.dumps(jin, default=jsonDecimal).replace('"', '\\"')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def SerialiseNum(n):
 | 
			
		||||
def SerialiseNum(n: int) -> bytes:
 | 
			
		||||
    if n == 0:
 | 
			
		||||
        return bytes((0x00,))
 | 
			
		||||
    if n > 0 and n <= 16:
 | 
			
		||||
@ -84,7 +84,7 @@ def SerialiseNum(n):
 | 
			
		||||
    return bytes((len(rv),)) + rv
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def DeserialiseNum(b, o=0) -> int:
 | 
			
		||||
def DeserialiseNum(b: bytes, o: int = 0) -> int:
 | 
			
		||||
    if b[o] == 0:
 | 
			
		||||
        return 0
 | 
			
		||||
    if b[o] > 0x50 and b[o] <= 0x50 + 16:
 | 
			
		||||
@ -100,13 +100,13 @@ def DeserialiseNum(b, o=0) -> int:
 | 
			
		||||
    return v
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def float_to_str(f):
 | 
			
		||||
def float_to_str(f: float) -> str:
 | 
			
		||||
    # stackoverflow.com/questions/38847690
 | 
			
		||||
    d1 = decimal_ctx.create_decimal(repr(f))
 | 
			
		||||
    return format(d1, 'f')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_int(v, scale=8, r=0):  # r = 0, no rounding, fail, r > 0 round up, r < 0 floor
 | 
			
		||||
def make_int(v, scale=8, r=0) -> int:  # r = 0, no rounding, fail, r > 0 round up, r < 0 floor
 | 
			
		||||
    if type(v) == float:
 | 
			
		||||
        v = float_to_str(v)
 | 
			
		||||
    elif type(v) == int:
 | 
			
		||||
@ -177,7 +177,7 @@ def format_amount(i, display_scale, scale=None):
 | 
			
		||||
    return rv
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def format_timestamp(value: int, with_seconds=False) -> str:
 | 
			
		||||
def format_timestamp(value: int, with_seconds: bool = False) -> str:
 | 
			
		||||
    str_format = '%Y-%m-%d %H:%M'
 | 
			
		||||
    if with_seconds:
 | 
			
		||||
        str_format += ':%S'
 | 
			
		||||
@ -185,7 +185,7 @@ def format_timestamp(value: int, with_seconds=False) -> str:
 | 
			
		||||
    return time.strftime(str_format, time.localtime(value))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def b2i(b) -> int:
 | 
			
		||||
def b2i(b: bytes) -> int:
 | 
			
		||||
    # bytes32ToInt
 | 
			
		||||
    return int.from_bytes(b, byteorder='big')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2022 tecnovert
 | 
			
		||||
# Copyright (c) 2022-2023 tecnovert
 | 
			
		||||
# Distributed under the MIT software license, see the accompanying
 | 
			
		||||
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
 | 
			
		||||
 | 
			
		||||
@ -59,7 +59,7 @@ def b58encode(v):
 | 
			
		||||
    return (__b58chars[0] * nPad) + result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def encodeStealthAddress(prefix_byte, scan_pubkey, spend_pubkey):
 | 
			
		||||
def encodeStealthAddress(prefix_byte: int, scan_pubkey: bytes, spend_pubkey: bytes) -> str:
 | 
			
		||||
    data = bytes((0x00,))
 | 
			
		||||
    data += scan_pubkey
 | 
			
		||||
    data += bytes((0x01,))
 | 
			
		||||
@ -72,14 +72,14 @@ def encodeStealthAddress(prefix_byte, scan_pubkey, spend_pubkey):
 | 
			
		||||
    return b58encode(b)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def decodeWif(encoded_key):
 | 
			
		||||
def decodeWif(encoded_key: str) -> bytes:
 | 
			
		||||
    key = b58decode(encoded_key)[1:-4]
 | 
			
		||||
    if len(key) == 33:
 | 
			
		||||
        return key[:-1]
 | 
			
		||||
    return key
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def toWIF(prefix_byte, b, compressed=True):
 | 
			
		||||
def toWIF(prefix_byte: int, b: bytes, compressed: bool = True) -> str:
 | 
			
		||||
    b = bytes((prefix_byte,)) + b
 | 
			
		||||
    if compressed:
 | 
			
		||||
        b += bytes((0x01,))
 | 
			
		||||
@ -87,9 +87,9 @@ def toWIF(prefix_byte, b, compressed=True):
 | 
			
		||||
    return b58encode(b)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def getKeyID(bytes):
 | 
			
		||||
    data = hashlib.sha256(bytes).digest()
 | 
			
		||||
    return ripemd160(data)
 | 
			
		||||
def getKeyID(key_data: bytes) -> str:
 | 
			
		||||
    sha256_hash = hashlib.sha256(key_data).digest()
 | 
			
		||||
    return ripemd160(sha256_hash)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bech32Decode(hrp, addr):
 | 
			
		||||
@ -109,7 +109,7 @@ def bech32Encode(hrp, data):
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def decodeAddress(address_str):
 | 
			
		||||
def decodeAddress(address_str: str):
 | 
			
		||||
    b58_addr = b58decode(address_str)
 | 
			
		||||
    if b58_addr is not None:
 | 
			
		||||
        address = b58_addr[:-4]
 | 
			
		||||
@ -119,10 +119,10 @@ def decodeAddress(address_str):
 | 
			
		||||
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def encodeAddress(address):
 | 
			
		||||
def encodeAddress(address: bytes) -> str:
 | 
			
		||||
    checksum = hashlib.sha256(hashlib.sha256(address).digest()).digest()
 | 
			
		||||
    return b58encode(address + checksum[0:4])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pubkeyToAddress(prefix, pubkey):
 | 
			
		||||
def pubkeyToAddress(prefix: int, pubkey: bytes) -> str:
 | 
			
		||||
    return encodeAddress(bytes((prefix,)) + getKeyID(pubkey))
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										160
									
								
								tests/basicswap/extended/test_encrypted_xmr_reload.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										160
									
								
								tests/basicswap/extended/test_encrypted_xmr_reload.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,160 @@
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2020-2023 tecnovert
 | 
			
		||||
# Distributed under the MIT software license, see the accompanying
 | 
			
		||||
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
export TEST_PATH=/tmp/test_basicswap
 | 
			
		||||
mkdir -p ${TEST_PATH}/bin
 | 
			
		||||
cp -r ~/tmp/basicswap_bin/* ${TEST_PATH}/bin
 | 
			
		||||
export PYTHONPATH=$(pwd)
 | 
			
		||||
python tests/basicswap/extended/test_encrypted_xmr_reload.py
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
import logging
 | 
			
		||||
import unittest
 | 
			
		||||
import multiprocessing
 | 
			
		||||
 | 
			
		||||
from tests.basicswap.util import (
 | 
			
		||||
    read_json_api,
 | 
			
		||||
    post_json_api,
 | 
			
		||||
    waitForServer,
 | 
			
		||||
)
 | 
			
		||||
from tests.basicswap.common import (
 | 
			
		||||
    waitForNumOffers,
 | 
			
		||||
    waitForNumBids,
 | 
			
		||||
    waitForNumSwapping,
 | 
			
		||||
)
 | 
			
		||||
from tests.basicswap.common_xmr import (
 | 
			
		||||
    XmrTestBase,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger()
 | 
			
		||||
logger.level = logging.DEBUG
 | 
			
		||||
if not len(logger.handlers):
 | 
			
		||||
    logger.addHandler(logging.StreamHandler(sys.stdout))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Test(XmrTestBase):
 | 
			
		||||
 | 
			
		||||
    def test_reload(self):
 | 
			
		||||
        self.start_processes()
 | 
			
		||||
 | 
			
		||||
        waitForServer(self.delay_event, 12700)
 | 
			
		||||
        waitForServer(self.delay_event, 12701)
 | 
			
		||||
        wallets1 = read_json_api(12701, 'wallets')
 | 
			
		||||
        assert (float(wallets1['XMR']['balance']) > 0.0)
 | 
			
		||||
 | 
			
		||||
        node1_password: str = 'notapassword123'
 | 
			
		||||
        logger.info('Encrypting node 1 wallets')
 | 
			
		||||
        rv = read_json_api(12701, 'setpassword', {'oldpassword': '', 'newpassword': node1_password})
 | 
			
		||||
        assert ('success' in rv)
 | 
			
		||||
        rv = read_json_api(12701, 'unlock', {'password': node1_password})
 | 
			
		||||
        assert ('success' in rv)
 | 
			
		||||
 | 
			
		||||
        data = {
 | 
			
		||||
            'addr_from': '-1',
 | 
			
		||||
            'coin_from': 'part',
 | 
			
		||||
            'coin_to': 'xmr',
 | 
			
		||||
            'amt_from': '1',
 | 
			
		||||
            'amt_to': '1',
 | 
			
		||||
            'lockhrs': '24'}
 | 
			
		||||
 | 
			
		||||
        offer_id = post_json_api(12700, 'offers/new', data)['offer_id']
 | 
			
		||||
        summary = read_json_api(12700)
 | 
			
		||||
        assert (summary['num_sent_offers'] == 1)
 | 
			
		||||
 | 
			
		||||
        logger.info('Waiting for offer')
 | 
			
		||||
        waitForNumOffers(self.delay_event, 12701, 1)
 | 
			
		||||
 | 
			
		||||
        offers = read_json_api(12701, 'offers')
 | 
			
		||||
        offer = offers[0]
 | 
			
		||||
 | 
			
		||||
        data = {
 | 
			
		||||
            'offer_id': offer['offer_id'],
 | 
			
		||||
            'amount_from': offer['amount_from']}
 | 
			
		||||
 | 
			
		||||
        data['valid_for_seconds'] = 24 * 60 * 60 + 1
 | 
			
		||||
        bid = post_json_api(12701, 'bids/new', data)
 | 
			
		||||
        assert (bid['error'] == 'Bid TTL too high')
 | 
			
		||||
        del data['valid_for_seconds']
 | 
			
		||||
        data['validmins'] = 24 * 60 + 1
 | 
			
		||||
        bid = post_json_api(12701, 'bids/new', data)
 | 
			
		||||
        assert (bid['error'] == 'Bid TTL too high')
 | 
			
		||||
 | 
			
		||||
        del data['validmins']
 | 
			
		||||
        data['valid_for_seconds'] = 10
 | 
			
		||||
        bid = post_json_api(12701, 'bids/new', data)
 | 
			
		||||
        assert (bid['error'] == 'Bid TTL too low')
 | 
			
		||||
        del data['valid_for_seconds']
 | 
			
		||||
        data['validmins'] = 1
 | 
			
		||||
        bid = post_json_api(12701, 'bids/new', data)
 | 
			
		||||
        assert (bid['error'] == 'Bid TTL too low')
 | 
			
		||||
 | 
			
		||||
        data['validmins'] = 60
 | 
			
		||||
        bid_id = post_json_api(12701, 'bids/new', data)
 | 
			
		||||
 | 
			
		||||
        waitForNumBids(self.delay_event, 12700, 1)
 | 
			
		||||
 | 
			
		||||
        for i in range(10):
 | 
			
		||||
            bids = read_json_api(12700, 'bids')
 | 
			
		||||
            bid = bids[0]
 | 
			
		||||
            if bid['bid_state'] == 'Received':
 | 
			
		||||
                break
 | 
			
		||||
            self.delay_event.wait(1)
 | 
			
		||||
        assert (bid['expire_at'] == bid['created_at'] + data['validmins'] * 60)
 | 
			
		||||
 | 
			
		||||
        data = {
 | 
			
		||||
            'accept': True
 | 
			
		||||
        }
 | 
			
		||||
        rv = post_json_api(12700, 'bids/{}'.format(bid['bid_id']), data)
 | 
			
		||||
        assert (rv['bid_state'] == 'Accepted')
 | 
			
		||||
 | 
			
		||||
        waitForNumSwapping(self.delay_event, 12701, 1)
 | 
			
		||||
 | 
			
		||||
        logger.info('Restarting node 1')
 | 
			
		||||
        c1 = self.processes[1]
 | 
			
		||||
        c1.terminate()
 | 
			
		||||
        c1.join()
 | 
			
		||||
        self.processes[1] = multiprocessing.Process(target=self.run_thread, args=(1,))
 | 
			
		||||
        self.processes[1].start()
 | 
			
		||||
 | 
			
		||||
        waitForServer(self.delay_event, 12701)
 | 
			
		||||
        rv = read_json_api(12701)
 | 
			
		||||
        assert ('error' in rv)
 | 
			
		||||
 | 
			
		||||
        logger.info('Unlocking node 1')
 | 
			
		||||
        rv = read_json_api(12701, 'unlock', {'password': node1_password})
 | 
			
		||||
        assert ('success' in rv)
 | 
			
		||||
        rv = read_json_api(12701)
 | 
			
		||||
        assert (rv['num_swapping'] == 1)
 | 
			
		||||
 | 
			
		||||
        rv = read_json_api(12700, 'revokeoffer/{}'.format(offer_id))
 | 
			
		||||
        assert (rv['revoked_offer'] == offer_id)
 | 
			
		||||
 | 
			
		||||
        logger.info('Completing swap')
 | 
			
		||||
        for i in range(240):
 | 
			
		||||
            if self.delay_event.is_set():
 | 
			
		||||
                raise ValueError('Test stopped.')
 | 
			
		||||
            self.delay_event.wait(4)
 | 
			
		||||
 | 
			
		||||
            rv = read_json_api(12700, 'bids/{}'.format(bid['bid_id']))
 | 
			
		||||
            if rv['bid_state'] == 'Completed':
 | 
			
		||||
                break
 | 
			
		||||
        assert (rv['bid_state'] == 'Completed')
 | 
			
		||||
 | 
			
		||||
        # Ensure offer was revoked
 | 
			
		||||
        summary = read_json_api(12700)
 | 
			
		||||
        assert (summary['num_network_offers'] == 0)
 | 
			
		||||
 | 
			
		||||
        # Wait for bid to be removed from in-progress
 | 
			
		||||
        waitForNumBids(self.delay_event, 12700, 0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    unittest.main()
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user