refactor: Replace struct.pack/unpack.
This commit is contained in:
		
							parent
							
								
									7c9504e0cd
								
							
						
					
					
						commit
						8318961f0b
					
				@ -24,7 +24,6 @@ import queue
 | 
			
		||||
import random
 | 
			
		||||
import select
 | 
			
		||||
import socket
 | 
			
		||||
import struct
 | 
			
		||||
import hashlib
 | 
			
		||||
import logging
 | 
			
		||||
import secrets
 | 
			
		||||
@ -41,7 +40,7 @@ from basicswap.contrib.rfc6979 import (
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
START_TOKEN = 0xabcd
 | 
			
		||||
MSG_START_TOKEN = struct.pack('>H', START_TOKEN)
 | 
			
		||||
MSG_START_TOKEN = START_TOKEN.to_bytes(2, 'big')
 | 
			
		||||
 | 
			
		||||
MSG_MAX_SIZE = 0x200000  # 2MB
 | 
			
		||||
 | 
			
		||||
@ -83,8 +82,8 @@ class MsgHandshake:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def encode_aad(self):  # Additional Authenticated Data
 | 
			
		||||
        return struct.pack('>H', NetMessageTypes.HANDSHAKE) + \
 | 
			
		||||
            struct.pack('>Q', self._timestamp) + \
 | 
			
		||||
        return int(NetMessageTypes.HANDSHAKE).to_bytes(2, 'big') + \
 | 
			
		||||
            self._timestamp.to_bytes(8, 'big') + \
 | 
			
		||||
            self._ephem_pk
 | 
			
		||||
 | 
			
		||||
    def encode(self):
 | 
			
		||||
@ -92,7 +91,7 @@ class MsgHandshake:
 | 
			
		||||
 | 
			
		||||
    def decode(self, msg_mv):
 | 
			
		||||
        o = 2
 | 
			
		||||
        self._timestamp = struct.unpack('>Q', msg_mv[o: o + 8])[0]
 | 
			
		||||
        self._timestamp = int.from_bytes(msg_mv[o: o + 8], 'big')
 | 
			
		||||
        o += 8
 | 
			
		||||
        self._ephem_pk = bytes(msg_mv[o: o + 33])
 | 
			
		||||
        o += 33
 | 
			
		||||
@ -333,7 +332,7 @@ class Network:
 | 
			
		||||
 | 
			
		||||
            ss = k.ecdh(peer._pubkey)
 | 
			
		||||
 | 
			
		||||
            hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest()
 | 
			
		||||
            hashed = hashlib.sha512(ss + msg._timestamp.to_bytes(8, 'big')).digest()
 | 
			
		||||
            peer._ke = hashed[:32]
 | 
			
		||||
            peer._km = hashed[32:]
 | 
			
		||||
 | 
			
		||||
@ -386,7 +385,7 @@ class Network:
 | 
			
		||||
            nk = PrivateKey(self._network_key)
 | 
			
		||||
            ss = nk.ecdh(msg._ephem_pk)
 | 
			
		||||
 | 
			
		||||
            hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest()
 | 
			
		||||
            hashed = hashlib.sha512(ss + msg._timestamp.to_bytes(8, 'big')).digest()
 | 
			
		||||
            peer._ke = hashed[:32]
 | 
			
		||||
            peer._km = hashed[32:]
 | 
			
		||||
 | 
			
		||||
@ -427,7 +426,7 @@ class Network:
 | 
			
		||||
        mac = msg_mv[-16:]
 | 
			
		||||
        plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac)
 | 
			
		||||
 | 
			
		||||
        ping_nonce = struct.unpack('>I', plaintext[:4])[0]
 | 
			
		||||
        ping_nonce = int.from_bytes(plaintext[:4], 'big')
 | 
			
		||||
        # Version is added to a ping following a handshake message
 | 
			
		||||
        if len(plaintext) >= 10:
 | 
			
		||||
            peer._ready = True
 | 
			
		||||
@ -450,7 +449,7 @@ class Network:
 | 
			
		||||
        mac = msg_mv[-16:]
 | 
			
		||||
        plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac)
 | 
			
		||||
 | 
			
		||||
        pong_nonce = struct.unpack('>I', plaintext[:4])[0]
 | 
			
		||||
        pong_nonce = int.from_bytes(plaintext[:4], 'big')
 | 
			
		||||
 | 
			
		||||
        if pong_nonce == peer._ping_nonce:
 | 
			
		||||
            peer._last_ping_rtt = (time.time_ns() // 1000) - peer._last_ping_at
 | 
			
		||||
@ -462,14 +461,14 @@ class Network:
 | 
			
		||||
    def send_ping(self, peer):
 | 
			
		||||
        ping_nonce = random.getrandbits(32)
 | 
			
		||||
 | 
			
		||||
        msg_bytes = struct.pack('>H', NetMessageTypes.PING)
 | 
			
		||||
        msg_bytes = int(NetMessageTypes.PING).to_bytes(2, 'big')
 | 
			
		||||
        nonce = peer._sent_nonce[:24]
 | 
			
		||||
 | 
			
		||||
        cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
 | 
			
		||||
        cipher.update(msg_bytes)
 | 
			
		||||
        cipher.update(nonce)
 | 
			
		||||
 | 
			
		||||
        payload = struct.pack('>I', ping_nonce)
 | 
			
		||||
        payload = ping_nonce.to_bytes(4, 'big')
 | 
			
		||||
        if peer._last_ping_at == 0:
 | 
			
		||||
            payload += self._sc._version
 | 
			
		||||
        ct, mac = cipher.encrypt_and_digest(payload)
 | 
			
		||||
@ -484,14 +483,14 @@ class Network:
 | 
			
		||||
        self.send_msg(peer, msg_bytes)
 | 
			
		||||
 | 
			
		||||
    def send_pong(self, peer, ping_nonce):
 | 
			
		||||
        msg_bytes = struct.pack('>H', NetMessageTypes.PONG)
 | 
			
		||||
        msg_bytes = int(NetMessageTypes.PONG).to_bytes(2, 'big')
 | 
			
		||||
        nonce = peer._sent_nonce[:24]
 | 
			
		||||
 | 
			
		||||
        cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
 | 
			
		||||
        cipher.update(msg_bytes)
 | 
			
		||||
        cipher.update(nonce)
 | 
			
		||||
 | 
			
		||||
        payload = struct.pack('>I', ping_nonce)
 | 
			
		||||
        payload = ping_nonce.to_bytes(4, 'big')
 | 
			
		||||
        ct, mac = cipher.encrypt_and_digest(payload)
 | 
			
		||||
        msg_bytes += ct + mac
 | 
			
		||||
 | 
			
		||||
@ -503,7 +502,7 @@ class Network:
 | 
			
		||||
        msg_encoded = msg if isinstance(msg, bytes) else msg.encode()
 | 
			
		||||
        len_encoded = len(msg_encoded)
 | 
			
		||||
 | 
			
		||||
        msg_packed = bytearray(MSG_START_TOKEN) + struct.pack('>I', len_encoded) + msg_encoded
 | 
			
		||||
        msg_packed = bytearray(MSG_START_TOKEN) + len_encoded.to_bytes(4, 'big') + msg_encoded
 | 
			
		||||
        peer._socket.sendall(msg_packed)
 | 
			
		||||
 | 
			
		||||
        peer._bytes_sent += len_encoded
 | 
			
		||||
@ -515,7 +514,7 @@ class Network:
 | 
			
		||||
        try:
 | 
			
		||||
            mv = memoryview(msg_bytes)
 | 
			
		||||
            o = 0
 | 
			
		||||
            msg_type = struct.unpack('>H', mv[o: o + 2])[0]
 | 
			
		||||
            msg_type = int.from_bytes(mv[o: o + 2], 'big')
 | 
			
		||||
            if msg_type == NetMessageTypes.HANDSHAKE:
 | 
			
		||||
                self.process_handshake(peer, mv)
 | 
			
		||||
            elif msg_type == NetMessageTypes.PING:
 | 
			
		||||
@ -548,13 +547,13 @@ class Network:
 | 
			
		||||
                        raise ValueError('Invalid start token')
 | 
			
		||||
                    o += 2
 | 
			
		||||
 | 
			
		||||
                    msg_len = struct.unpack('>I', mv[o: o + 4])[0]
 | 
			
		||||
                    msg_len = int.from_bytes(mv[o: o + 4], 'big')
 | 
			
		||||
                    o += 4
 | 
			
		||||
                    if msg_len < 2 or msg_len > MSG_MAX_SIZE:
 | 
			
		||||
                        raise ValueError('Invalid data length')
 | 
			
		||||
 | 
			
		||||
                    # Precheck msg_type
 | 
			
		||||
                    msg_type = struct.unpack('>H', mv[o: o + 2])[0]
 | 
			
		||||
                    msg_type = int.from_bytes(mv[o: o + 2], 'big')
 | 
			
		||||
                    # o += 2  # Don't inc offset, msg includes type
 | 
			
		||||
                    if not NetMessageTypes.has_value(msg_type):
 | 
			
		||||
                        raise ValueError('Invalid msg type')
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user