refactor: Replace struct.pack/unpack.
This commit is contained in:
parent
7c9504e0cd
commit
8318961f0b
@ -24,7 +24,6 @@ import queue
|
||||
import random
|
||||
import select
|
||||
import socket
|
||||
import struct
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
@ -41,7 +40,7 @@ from basicswap.contrib.rfc6979 import (
|
||||
|
||||
|
||||
START_TOKEN = 0xabcd
|
||||
MSG_START_TOKEN = struct.pack('>H', START_TOKEN)
|
||||
MSG_START_TOKEN = START_TOKEN.to_bytes(2, 'big')
|
||||
|
||||
MSG_MAX_SIZE = 0x200000 # 2MB
|
||||
|
||||
@ -83,8 +82,8 @@ class MsgHandshake:
|
||||
pass
|
||||
|
||||
def encode_aad(self): # Additional Authenticated Data
|
||||
return struct.pack('>H', NetMessageTypes.HANDSHAKE) + \
|
||||
struct.pack('>Q', self._timestamp) + \
|
||||
return int(NetMessageTypes.HANDSHAKE).to_bytes(2, 'big') + \
|
||||
self._timestamp.to_bytes(8, 'big') + \
|
||||
self._ephem_pk
|
||||
|
||||
def encode(self):
|
||||
@ -92,7 +91,7 @@ class MsgHandshake:
|
||||
|
||||
def decode(self, msg_mv):
|
||||
o = 2
|
||||
self._timestamp = struct.unpack('>Q', msg_mv[o: o + 8])[0]
|
||||
self._timestamp = int.from_bytes(msg_mv[o: o + 8], 'big')
|
||||
o += 8
|
||||
self._ephem_pk = bytes(msg_mv[o: o + 33])
|
||||
o += 33
|
||||
@ -333,7 +332,7 @@ class Network:
|
||||
|
||||
ss = k.ecdh(peer._pubkey)
|
||||
|
||||
hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest()
|
||||
hashed = hashlib.sha512(ss + msg._timestamp.to_bytes(8, 'big')).digest()
|
||||
peer._ke = hashed[:32]
|
||||
peer._km = hashed[32:]
|
||||
|
||||
@ -386,7 +385,7 @@ class Network:
|
||||
nk = PrivateKey(self._network_key)
|
||||
ss = nk.ecdh(msg._ephem_pk)
|
||||
|
||||
hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest()
|
||||
hashed = hashlib.sha512(ss + msg._timestamp.to_bytes(8, 'big')).digest()
|
||||
peer._ke = hashed[:32]
|
||||
peer._km = hashed[32:]
|
||||
|
||||
@ -427,7 +426,7 @@ class Network:
|
||||
mac = msg_mv[-16:]
|
||||
plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac)
|
||||
|
||||
ping_nonce = struct.unpack('>I', plaintext[:4])[0]
|
||||
ping_nonce = int.from_bytes(plaintext[:4], 'big')
|
||||
# Version is added to a ping following a handshake message
|
||||
if len(plaintext) >= 10:
|
||||
peer._ready = True
|
||||
@ -450,7 +449,7 @@ class Network:
|
||||
mac = msg_mv[-16:]
|
||||
plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac)
|
||||
|
||||
pong_nonce = struct.unpack('>I', plaintext[:4])[0]
|
||||
pong_nonce = int.from_bytes(plaintext[:4], 'big')
|
||||
|
||||
if pong_nonce == peer._ping_nonce:
|
||||
peer._last_ping_rtt = (time.time_ns() // 1000) - peer._last_ping_at
|
||||
@ -462,14 +461,14 @@ class Network:
|
||||
def send_ping(self, peer):
|
||||
ping_nonce = random.getrandbits(32)
|
||||
|
||||
msg_bytes = struct.pack('>H', NetMessageTypes.PING)
|
||||
msg_bytes = int(NetMessageTypes.PING).to_bytes(2, 'big')
|
||||
nonce = peer._sent_nonce[:24]
|
||||
|
||||
cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
|
||||
cipher.update(msg_bytes)
|
||||
cipher.update(nonce)
|
||||
|
||||
payload = struct.pack('>I', ping_nonce)
|
||||
payload = ping_nonce.to_bytes(4, 'big')
|
||||
if peer._last_ping_at == 0:
|
||||
payload += self._sc._version
|
||||
ct, mac = cipher.encrypt_and_digest(payload)
|
||||
@ -484,14 +483,14 @@ class Network:
|
||||
self.send_msg(peer, msg_bytes)
|
||||
|
||||
def send_pong(self, peer, ping_nonce):
|
||||
msg_bytes = struct.pack('>H', NetMessageTypes.PONG)
|
||||
msg_bytes = int(NetMessageTypes.PONG).to_bytes(2, 'big')
|
||||
nonce = peer._sent_nonce[:24]
|
||||
|
||||
cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
|
||||
cipher.update(msg_bytes)
|
||||
cipher.update(nonce)
|
||||
|
||||
payload = struct.pack('>I', ping_nonce)
|
||||
payload = ping_nonce.to_bytes(4, 'big')
|
||||
ct, mac = cipher.encrypt_and_digest(payload)
|
||||
msg_bytes += ct + mac
|
||||
|
||||
@ -503,7 +502,7 @@ class Network:
|
||||
msg_encoded = msg if isinstance(msg, bytes) else msg.encode()
|
||||
len_encoded = len(msg_encoded)
|
||||
|
||||
msg_packed = bytearray(MSG_START_TOKEN) + struct.pack('>I', len_encoded) + msg_encoded
|
||||
msg_packed = bytearray(MSG_START_TOKEN) + len_encoded.to_bytes(4, 'big') + msg_encoded
|
||||
peer._socket.sendall(msg_packed)
|
||||
|
||||
peer._bytes_sent += len_encoded
|
||||
@ -515,7 +514,7 @@ class Network:
|
||||
try:
|
||||
mv = memoryview(msg_bytes)
|
||||
o = 0
|
||||
msg_type = struct.unpack('>H', mv[o: o + 2])[0]
|
||||
msg_type = int.from_bytes(mv[o: o + 2], 'big')
|
||||
if msg_type == NetMessageTypes.HANDSHAKE:
|
||||
self.process_handshake(peer, mv)
|
||||
elif msg_type == NetMessageTypes.PING:
|
||||
@ -548,13 +547,13 @@ class Network:
|
||||
raise ValueError('Invalid start token')
|
||||
o += 2
|
||||
|
||||
msg_len = struct.unpack('>I', mv[o: o + 4])[0]
|
||||
msg_len = int.from_bytes(mv[o: o + 4], 'big')
|
||||
o += 4
|
||||
if msg_len < 2 or msg_len > MSG_MAX_SIZE:
|
||||
raise ValueError('Invalid data length')
|
||||
|
||||
# Precheck msg_type
|
||||
msg_type = struct.unpack('>H', mv[o: o + 2])[0]
|
||||
msg_type = int.from_bytes(mv[o: o + 2], 'big')
|
||||
# o += 2 # Don't inc offset, msg includes type
|
||||
if not NetMessageTypes.has_value(msg_type):
|
||||
raise ValueError('Invalid msg type')
|
||||
|
Loading…
Reference in New Issue
Block a user