basicswap_miserver/basicswap/network.py

604 lines
20 KiB
Python
Raw Normal View History

2020-11-30 14:29:40 +00:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2020 tecnovert
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
'''
    Message 2 bytes msg_class, 4 bytes length, [ 2 bytes msg_type, payload ]

    Handshake procedure:
        node0 connecting to node1
        node0 send_handshake
        node1 process_handshake
        node1 send_ping - With a version field
        node0 recv_ping
        Both nodes are initialised

    XChaCha20_Poly1305 mac is 16bytes
'''
2020-12-15 18:00:44 +00:00
import time
import queue
import random
import select
import socket
2020-12-15 18:00:44 +00:00
import struct
import hashlib
import logging
2020-12-15 18:00:44 +00:00
import secrets
import threading
2020-12-15 18:00:44 +00:00
import traceback
from enum import IntEnum, auto
from collections import OrderedDict
from Crypto.Cipher import ChaCha20_Poly1305 # TODO: Add to libsecp256k1/coincurve fork
from coincurve.keys import PrivateKey, PublicKey
2020-12-22 11:21:25 +00:00
from basicswap.contrib.rfc6979 import (
2020-12-15 18:00:44 +00:00
rfc6979_hmac_sha256_initialize,
rfc6979_hmac_sha256_generate)
# Marker value that prefixes every framed message on the wire.
START_TOKEN = 0xabcd
# START_TOKEN packed as 2 big-endian bytes; prepended by send_msg and checked by receive_bytes.
MSG_START_TOKEN = struct.pack('>H', START_TOKEN)
# Largest message length receive_bytes accepts from the 4-byte length field.
MSG_MAX_SIZE = 0x200000 # 2MB
# Minimum bytes needed to begin parsing a message: 2 (start token) + 4 (length) + 2 (msg_type).
MSG_HEADER_LEN = 8
# Upper bound on remembered handshake ephemeral keys; the oldest entries are evicted first.
MAX_SEEN_EPHEM_KEYS = 1000
# Maximum allowed difference, in seconds, between a handshake timestamp and local time.
TIMESTAMP_LEEWAY = 8
2020-12-15 18:00:44 +00:00
class NetMessageTypes(IntEnum):
    """Identifiers carried in the 2-byte msg_type field of every message.

    Values are assigned sequentially by ``auto()``, starting at 1.
    """

    HANDSHAKE = auto()
    PING = auto()
    PONG = auto()
    DATA = auto()
    ONION_PACKET = auto()

    @classmethod
    def has_value(cls, value):
        """Return True when *value* is the numeric value of one of the members."""
        known_values = cls._value2member_map_
        return value in known_values
'''
class NetMessage:
    def __init__(self):
        self._msg_class = None  # 2 bytes
        self._len = None  # 4 bytes
        self._msg_type = None  # 2 bytes
'''
# Ensure handshake keys are not reused by including the time in the msg, mac and key hash
# Verify timestamp is not too old
# Add keys to db to catch concurrent attempts, records can be cleared periodically, the timestamp should catch older replay attempts
class MsgHandshake:
    """Encoder/decoder for the handshake message.

    Wire layout (after the framing header):
        2 bytes  msg_type (NetMessageTypes.HANDSHAKE)
        8 bytes  timestamp, big-endian unsigned
        33 bytes ephemeral public key
        n bytes  ciphertext
        16 bytes mac
    """
    __slots__ = ('_timestamp', '_ephem_pk', '_ct', '_mac')

    # Fixed field widths, in bytes.
    TYPE_LEN = 2
    TIMESTAMP_LEN = 8
    PUBKEY_LEN = 33
    MAC_LEN = 16
    # Shortest decodable message: every fixed-width field plus an empty ciphertext.
    MIN_LEN = TYPE_LEN + TIMESTAMP_LEN + PUBKEY_LEN + MAC_LEN

    def __init__(self):
        # Fields are populated by the caller (sending) or by decode() (receiving).
        pass

    def encode_aad(self):  # Additional Authenticated Data
        """Return msg_type + timestamp + ephemeral pubkey, the authenticated-but-unencrypted prefix."""
        return struct.pack('>HQ', NetMessageTypes.HANDSHAKE, self._timestamp) + self._ephem_pk

    def encode(self):
        """Return the full message bytes: aad, ciphertext, then mac."""
        return self.encode_aad() + self._ct + self._mac

    def decode(self, msg_mv):
        """Populate the fields from a received message.

        msg_mv: bytes-like view of the message, starting at the msg_type field.

        Raises ValueError when the message is too short to contain all fixed-width
        fields; without this check the slices below would silently overlap.
        """
        if len(msg_mv) < self.MIN_LEN:
            raise ValueError('Handshake message too short')

        o = self.TYPE_LEN  # Skip msg_type, already inspected by the caller
        self._timestamp = struct.unpack('>Q', msg_mv[o: o + self.TIMESTAMP_LEN])[0]
        o += self.TIMESTAMP_LEN
        self._ephem_pk = bytes(msg_mv[o: o + self.PUBKEY_LEN])
        o += self.PUBKEY_LEN
        self._ct = bytes(msg_mv[o: -self.MAC_LEN])
        self._mac = bytes(msg_mv[-self.MAC_LEN:])
2020-11-30 14:29:40 +00:00
class Peer:
    """Per-connection state for one remote node.

    The session fields ('_ke', '_km', '_sent_nonce', '_recv_nonce') and '_dir'
    are declared as slots but are not assigned here.
    """
    __slots__ = (
        '_mx', '_pubkey', '_address', '_socket', '_version', '_ready', '_incoming',
        '_connected_at', '_last_received_at', '_bytes_sent', '_bytes_received',
        '_receiving_length', '_receiving_buffer', '_recv_messages', '_misbehaving_score',
        '_ke', '_km', '_dir', '_sent_nonce', '_recv_nonce', '_last_handshake_at',
        '_ping_nonce', '_last_ping_at', '_last_ping_rtt')

    def __init__(self, address, socket, pubkey):
        # Identity and transport
        self._address = address
        self._socket = socket
        self._pubkey = pubkey
        self._version = None
        self._mx = threading.Lock()

        # Connection status
        self._incoming = False
        self._ready = False  # True when handshake is complete
        self._misbehaving_score = 0  # TODO: Must be persistent - save to db

        # Timestamps (seconds)
        self._connected_at = time.time()
        self._last_handshake_at = 0
        self._last_received_at = 0

        # Traffic counters
        self._bytes_received = 0
        self._bytes_sent = 0

        # Reassembly state for the message currently being received
        self._receiving_buffer = None
        self._receiving_length = 0
        self._recv_messages = queue.Queue()  # Built in mutex

        # Ping bookkeeping; the time values are set from time.time_ns() // 1000, i.e. microseconds
        self._ping_nonce = 0
        self._last_ping_at = 0
        self._last_ping_rtt = 0

    def close(self):
        """Close the underlying socket."""
        self._socket.close()
def listen_thread(cls):
    """Socket polling loop, run in its own thread until ``cls._running`` is cleared.

    cls: the Network instance (passed as the thread argument).

    Each pass waits up to one second in select(), then, holding ``cls._mx``:
      - accepts new inbound connections on the listening socket,
      - reads available bytes from peer sockets and hands them to ``cls.receive_bytes``,
      - disconnects peers whose socket failed or was closed by the remote end.
    """
    timeout = 1.0
    max_bytes = 0x10000

    while cls._running:
        # logging.info('[rm] network loop %d', cls._running)
        readable, writable, errored = select.select(cls._read_sockets, cls._write_sockets, cls._error_sockets, timeout)
        with cls._mx:
            disconnected_peers = []
            for s in readable:
                if s == cls._socket:
                    peer_socket, address = cls._socket.accept()
                    logging.info('Connection from %s', address)
                    new_peer = Peer(address, peer_socket, None)
                    new_peer._incoming = True
                    cls._peers.append(new_peer)
                    cls._error_sockets.append(peer_socket)
                    cls._read_sockets.append(peer_socket)
                    continue

                for peer in cls._peers:
                    if peer._socket != s:
                        continue
                    try:
                        bytes_recv = s.recv(max_bytes, socket.MSG_DONTWAIT)
                    except socket.error as se:
                        if se.args[0] not in (socket.EWOULDBLOCK, ):
                            logging.error('Receive error %s', str(se))
                            disconnected_peers.append(peer)
                        # Nothing was read (would-block or failure): never fall
                        # through to process an unbound or stale bytes_recv.
                        continue
                    except Exception as e:
                        logging.error('Receive error %s', str(e))
                        disconnected_peers.append(peer)
                        continue

                    if len(bytes_recv) < 1:
                        # Zero-length read: the remote end closed the connection
                        disconnected_peers.append(peer)
                        continue
                    cls.receive_bytes(peer, bytes_recv)

            for s in errored:
                logging.warning('Socket error')

            # Disconnect after the scan so cls._peers is not modified while iterating it
            for peer in disconnected_peers:
                cls.disconnect(peer)
def msg_thread(cls):
    """Message pump, run in its own thread until ``cls._running`` is cleared.

    cls: the Network instance (passed as the thread argument).

    On every pass, while holding ``cls._mx``, each peer is pinged if it is ready
    and its last ping is at least 5 seconds old, then at most one queued message
    is taken from the peer and processed. The thread sleeps briefly when no
    message was processed during the pass.
    """
    idle_sleep = 0.1
    ping_interval_us = 5000000  # 5 seconds TODO: Make variable

    while cls._running:
        handled_any = False

        with cls._mx:
            for peer in cls._peers:
                try:
                    now_us = time.time_ns() // 1000
                    if peer._ready is True:
                        elapsed_us = now_us - peer._last_ping_at
                        if elapsed_us >= ping_interval_us:
                            cls.send_ping(peer)
                    next_msg = peer._recv_messages.get_nowait()
                    cls.process_message(peer, next_msg)
                    handled_any = True
                except queue.Empty:
                    # Nothing waiting for this peer
                    pass
                except Exception as e:
                    logging.warning('process message error %s', str(e))
                    if cls._sc.debug:
                        traceback.print_exc()

        if handled_any is False:
            time.sleep(idle_sleep)
class Network:
    """TCP peer-to-peer transport with an authenticated, encrypted handshake.

    Owns the listening socket, the list of connected peers and the two worker
    threads (``listen_thread`` reads sockets, ``msg_thread`` processes queued
    messages and sends pings). ``_mx`` guards the peer and socket lists; each
    Peer has its own ``_mx`` guarding its session state.
    """
    __slots__ = (
        '_p2p_host', '_p2p_port', '_network_key', '_network_pubkey',
        '_sc', '_peers', '_max_connections', '_running', '_network_thread', '_msg_thread',
        '_mx', '_socket', '_read_sockets', '_write_sockets', '_error_sockets', '_csprng', '_seen_ephem_keys')

    def __init__(self, p2p_host, p2p_port, network_key, swap_client):
        """Store the configuration; no sockets or threads are created until startNetwork().

        p2p_host, p2p_port: address the listening socket binds to.
        network_key: secret key bytes for this node; the public key is derived from it.
        swap_client: owner object, used here for its ``log``, ``debug`` and ``_version`` attributes.
        """
        self._p2p_host = p2p_host
        self._p2p_port = p2p_port
        self._network_key = network_key
        self._network_pubkey = PublicKey.from_secret(network_key).format()
        self._sc = swap_client
        self._peers = []
        self._max_connections = 10
        self._running = False

        self._network_thread = None
        self._msg_thread = None
        self._mx = threading.Lock()
        self._socket = None
        self._read_sockets = []
        self._write_sockets = []
        self._error_sockets = []  # Check for error events

        # ephem_pk -> (peer address, timestamp), oldest first
        self._seen_ephem_keys = OrderedDict()

    def startNetwork(self):
        """Seed the key generator, bind the listening socket and start both worker threads."""
        with self._mx:
            self._csprng = rfc6979_hmac_sha256_initialize(secrets.token_bytes(32))

            self._running = True

            self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self._socket.bind((self._p2p_host, self._p2p_port))
            self._socket.listen(self._max_connections)
            self._read_sockets.append(self._socket)

            self._network_thread = threading.Thread(target=listen_thread, args=(self,))
            self._network_thread.start()

            self._msg_thread = threading.Thread(target=msg_thread, args=(self,))
            self._msg_thread.start()

    def stopNetwork(self):
        """Signal the worker threads to stop, wait for them, then close all sockets."""
        with self._mx:
            self._running = False

        # Join outside the lock: both threads acquire _mx inside their loops
        if self._network_thread:
            self._network_thread.join()
        if self._msg_thread:
            self._msg_thread.join()

        with self._mx:
            if self._socket:
                self._socket.close()

            for peer in self._peers:
                peer.close()

    def add_connection(self, host, port, peer_pubkey):
        """Open an outbound connection to a peer and start the handshake.

        host, port: address to connect to.
        peer_pubkey: the peer's public key bytes, used for the handshake key agreement.

        Any exception raised by connect() propagates; the socket is closed first
        so a failed attempt does not leak a file descriptor.
        """
        self._sc.log.info('Connecting from %s to %s at %s %d', self._network_pubkey.hex(), peer_pubkey.hex(), host, port)
        with self._mx:
            address = (host, port)
            peer_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                peer_socket.connect(address)
            except Exception:
                peer_socket.close()
                raise
            peer = Peer(address, peer_socket, peer_pubkey)
            self._peers.append(peer)
            self._error_sockets.append(peer_socket)
            self._read_sockets.append(peer_socket)

        self.send_handshake(peer)

    def disconnect(self, peer):
        """Close a peer's socket and forget the peer.

        Does not take ``_mx`` itself; listen_thread calls it with the lock already held.
        Raises ValueError if the peer or its socket is not registered.
        """
        self._sc.log.info('Closing peer socket %s', peer._address)
        self._read_sockets.remove(peer._socket)
        self._error_sockets.remove(peer._socket)
        peer.close()
        self._peers.remove(peer)

    def check_handshake_ephem_key(self, peer, timestamp, ephem_pk, direction=1):
        """Record a handshake ephemeral key, rejecting one that was already seen.

        direction: 1 when the key is being sent to the peer, 2 when received from it
        (only affects the wording of the error).

        Raises ValueError if ephem_pk is already recorded. At most
        MAX_SEEN_EPHEM_KEYS keys are kept; the oldest are dropped first.
        """
        # assert ._mx.acquire() ?
        used = self._seen_ephem_keys.get(ephem_pk)
        if used:
            raise ValueError('Handshake ephem_pk reused {} peer {}'.format('for' if direction == 1 else 'by', used[0]))

        self._seen_ephem_keys[ephem_pk] = (peer._address, timestamp)

        while len(self._seen_ephem_keys) > MAX_SEEN_EPHEM_KEYS:
            self._seen_ephem_keys.popitem(last=False)

    def _drain_recv_queue(self, peer, context):
        """Discard every message queued for the peer, logging a warning if any were present."""
        if peer._recv_messages.empty():
            return
        self._sc.log.warning('%s %s - Receive queue dumped.', context, peer._address)
        try:
            while True:
                peer._recv_messages.get(False)
        except queue.Empty:
            pass

    def send_handshake(self, peer):
        """Build and send a handshake message, deriving fresh session keys for the peer.

        A new ephemeral key is combined with the peer's public key via ECDH; the
        sha512 of the shared secret and timestamp is split into the encryption key
        (_ke) and mac key material (_km). The payload is this node's version
        followed by a recoverable signature over _km made with the network key.
        """
        self._sc.log.debug('send_handshake %s', peer._address)
        with peer._mx:
            # Messages queued under the previous session are discarded
            self._drain_recv_queue(peer, 'send_handshake')

            msg = MsgHandshake()
            msg._timestamp = int(time.time())
            key_r = rfc6979_hmac_sha256_generate(self._csprng, 32)
            k = PrivateKey(key_r)
            msg._ephem_pk = PublicKey.from_secret(key_r).format()
            self.check_handshake_ephem_key(peer, msg._timestamp, msg._ephem_pk)

            ss = k.ecdh(peer._pubkey)

            hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest()
            peer._ke = hashed[:32]
            peer._km = hashed[32:]

            # _km is 32 bytes, so this slice is its final 8 bytes.
            # NOTE(review): later messages use 24-byte nonces - confirm the 8-byte nonce here is intended.
            nonce = peer._km[24:]

            nk = PrivateKey(self._network_key)
            sig = nk.sign_recoverable(peer._km)
            payload = self._sc._version + sig

            aad = msg.encode_aad()
            aad += nonce
            cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
            cipher.update(aad)
            msg._ct, msg._mac = cipher.encrypt_and_digest(payload)

            peer._sent_nonce = hashlib.sha256(nonce + msg._mac).digest()
            peer._recv_nonce = hashlib.sha256(peer._km).digest()  # Init nonce

            peer._last_handshake_at = msg._timestamp
            peer._ready = False  # Wait for peer to complete handshake

            self.send_msg(peer, msg)

    def process_handshake(self, peer, msg_mv):
        """Handle a received handshake: derive the session keys and learn the peer's pubkey.

        msg_mv: view of the message, starting at the msg_type field.

        Decoding errors propagate to the caller. Validation and crypto failures
        (handshake rate, timestamp outside TIMESTAMP_LEEWAY, reused ephemeral key,
        mac mismatch) are logged at debug level and leave the peer not ready.
        """
        self._sc.log.debug('process_handshake %s', peer._address)

        # Messages queued under the previous session are discarded
        self._drain_recv_queue(peer, 'process_handshake')

        msg = MsgHandshake()
        msg.decode(msg_mv)

        try:
            now = int(time.time())
            if now - peer._last_handshake_at < 30:
                raise ValueError('Too many handshakes from peer {}'.format(peer._address))

            if abs(msg._timestamp - now) > TIMESTAMP_LEEWAY:
                raise ValueError('Bad handshake timestamp from peer {}'.format(peer._address))

            self.check_handshake_ephem_key(peer, msg._timestamp, msg._ephem_pk, direction=2)

            nk = PrivateKey(self._network_key)
            ss = nk.ecdh(msg._ephem_pk)

            hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest()
            peer._ke = hashed[:32]
            peer._km = hashed[32:]

            nonce = peer._km[24:]  # Final 8 bytes of _km, mirrors send_handshake

            aad = msg.encode_aad()
            aad += nonce
            cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
            cipher.update(aad)
            plaintext = cipher.decrypt_and_verify(msg._ct, msg._mac)  # Will raise error if mac doesn't match

            peer._version = plaintext[:6]
            sig = plaintext[6:]

            pk_peer = PublicKey.from_signature_and_message(sig, peer._km)
            # TODO: Should pk_peer be linked to public data?

            peer._pubkey = pk_peer.format()
            # Mirror image of the sender's nonces: its sent chain is our recv chain
            peer._recv_nonce = hashlib.sha256(nonce + msg._mac).digest()
            peer._sent_nonce = hashlib.sha256(peer._km).digest()  # Init nonce

            peer._last_handshake_at = msg._timestamp
            peer._ready = True
            # Schedule a ping to complete the handshake, TODO: Send here?
            peer._last_ping_at = 0

        except Exception as e:
            # TODO: misbehaving
            self._sc.log.debug('[rm] process_handshake %s', str(e))

    def _seal_msg(self, peer, msg_type, payload):
        """Encrypt payload for the peer and advance the send nonce.

        Returns msg_type (2 bytes) + ciphertext + mac. The msg_type and the
        nonce are authenticated as associated data.
        """
        header = struct.pack('>H', msg_type)
        nonce = peer._sent_nonce[:24]

        cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
        cipher.update(header)
        cipher.update(nonce)
        ct, mac = cipher.encrypt_and_digest(payload)

        peer._sent_nonce = hashlib.sha256(nonce + mac).digest()
        return header + ct + mac

    def _open_msg(self, peer, msg_mv):
        """Decrypt and verify a message from the peer.

        Returns (plaintext, next_recv_nonce). The caller stores next_recv_nonce
        once it has finished handling the plaintext. Raises if the mac does not match.
        """
        nonce = peer._recv_nonce[:24]

        cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
        cipher.update(msg_mv[0: 2])
        cipher.update(nonce)

        mac = msg_mv[-16:]
        plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac)
        return plaintext, hashlib.sha256(nonce + mac).digest()

    def process_ping(self, peer, msg_mv):
        """Handle a ping: optionally learn the peer's version, then reply with a pong."""
        plaintext, next_nonce = self._open_msg(peer, msg_mv)

        ping_nonce = struct.unpack('>I', plaintext[:4])[0]
        # Version is added to a ping following a handshake message
        if len(plaintext) >= 10:
            peer._ready = True
            version = plaintext[4: 10]
            if peer._version is None:
                peer._version = version
                self._sc.log.debug('Set version from ping %s, %s', peer._pubkey.hex(), peer._version.hex())

        peer._recv_nonce = next_nonce

        self.send_pong(peer, ping_nonce)

    def process_pong(self, peer, msg_mv):
        """Handle a pong: record the round-trip time when it answers the latest ping."""
        plaintext, next_nonce = self._open_msg(peer, msg_mv)

        pong_nonce = struct.unpack('>I', plaintext[:4])[0]

        if pong_nonce == peer._ping_nonce:
            peer._last_ping_rtt = (time.time_ns() // 1000) - peer._last_ping_at
        else:
            self._sc.log.debug('Pong received out of order %s', peer._address)

        peer._recv_nonce = next_nonce

    def send_ping(self, peer):
        """Send a ping carrying a random 32-bit nonce.

        The first ping after a handshake (``_last_ping_at`` is 0) also carries this
        node's version.
        """
        ping_nonce = random.getrandbits(32)

        payload = struct.pack('>I', ping_nonce)
        if peer._last_ping_at == 0:
            payload += self._sc._version
        msg_bytes = self._seal_msg(peer, NetMessageTypes.PING, payload)

        peer._last_ping_at = time.time_ns() // 1000  # microseconds
        peer._ping_nonce = ping_nonce

        self.send_msg(peer, msg_bytes)

    def send_pong(self, peer, ping_nonce):
        """Reply to a ping by echoing its nonce."""
        payload = struct.pack('>I', ping_nonce)
        msg_bytes = self._seal_msg(peer, NetMessageTypes.PONG, payload)

        self.send_msg(peer, msg_bytes)

    def send_msg(self, peer, msg):
        """Frame a message and write it to the peer's socket.

        msg: bytes, or an object with an ``encode()`` method returning bytes.
        The frame is the start token, a 4-byte big-endian length, then the message.
        ``_bytes_sent`` counts the message bytes only, not the 6-byte frame prefix.
        """
        msg_encoded = msg if isinstance(msg, bytes) else msg.encode()
        len_encoded = len(msg_encoded)

        msg_packed = bytearray(MSG_START_TOKEN) + struct.pack('>I', len_encoded) + msg_encoded
        peer._socket.sendall(msg_packed)

        peer._bytes_sent += len_encoded

    def process_message(self, peer, msg_bytes):
        """Dispatch one complete message to its handler based on the leading msg_type."""
        logging.info('[rm] process_message %s len %d', peer._address, len(msg_bytes))

        with peer._mx:
            mv = memoryview(msg_bytes)
            msg_type = struct.unpack('>H', mv[0: 2])[0]
            if msg_type == NetMessageTypes.HANDSHAKE:
                self.process_handshake(peer, mv)
            elif msg_type == NetMessageTypes.PING:
                self.process_ping(peer, mv)
            elif msg_type == NetMessageTypes.PONG:
                self.process_pong(peer, mv)
            else:
                self._sc.log.debug('Unknown message type %d', msg_type)

    def receive_bytes(self, peer, bytes_recv):
        """Feed bytes read from the peer's socket into its message reassembly state.

        A chunk may hold part of a message, exactly one, or several. Each completed
        message (msg_type + payload, frame prefix removed) is put on
        ``peer._recv_messages`` as bytes. A message split over several chunks is
        accumulated in ``peer._receiving_buffer`` until complete.

        Framing errors are logged when debug is enabled; the rest of the chunk and
        any partial message are dropped.
        """
        # logging.info('[rm] receive_bytes %s %s', peer._address, bytes_recv)
        len_received = len(bytes_recv)
        peer._last_received_at = time.time()
        peer._bytes_received += len_received

        mv = memoryview(bytes_recv)
        o = 0
        try:
            while o < len_received:
                if peer._receiving_length == 0:
                    # Start of a new message: the whole header must be in what is left of this chunk
                    if len_received - o < MSG_HEADER_LEN:
                        raise ValueError('Msg too short')

                    if mv[o: o + 2] != MSG_START_TOKEN:
                        raise ValueError('Invalid start token')
                    o += 2

                    msg_len = struct.unpack('>I', mv[o: o + 4])[0]
                    o += 4
                    if msg_len < 2 or msg_len > MSG_MAX_SIZE:
                        raise ValueError('Invalid data length')

                    # Precheck msg_type
                    msg_type = struct.unpack('>H', mv[o: o + 2])[0]
                    # o += 2  # Don't inc offset, msg includes type
                    if not NetMessageTypes.has_value(msg_type):
                        raise ValueError('Invalid msg type')

                    peer._receiving_length = msg_len
                    peer._receiving_buffer = bytearray()

                # Append (never replace) so bytes from earlier chunks are kept
                len_to_go = peer._receiving_length - len(peer._receiving_buffer)
                nc = min(len_to_go, len_received - o)
                peer._receiving_buffer += mv[o: o + nc]
                o += nc

                if len(peer._receiving_buffer) == peer._receiving_length:
                    peer._recv_messages.put(bytes(peer._receiving_buffer))
                    peer._receiving_length = 0
                    peer._receiving_buffer = None

        except Exception as e:
            # Drop the partial message so the next chunk is parsed from a header again
            peer._receiving_length = 0
            peer._receiving_buffer = None
            if self._sc.debug:
                self._sc.log.error('Invalid message received from %s %s', peer._address, str(e))
                # TODO: misbehaving

    def test_onion(self, path):
        """Placeholder for onion packet testing; only logs."""
        self._sc.log.debug('test_onion packet')

    def get_info(self):
        """Return ``{'peers': [...]}`` with one summary dict per connected peer."""
        with self._mx:
            peers = [
                {
                    'pubkey': 'Unknown' if not peer._pubkey else peer._pubkey.hex(),
                    'address': '{}:{}'.format(peer._address[0], peer._address[1]),
                    'bytessent': peer._bytes_sent,
                    'bytesrecv': peer._bytes_received,
                    'ready': peer._ready,
                    'incoming': peer._incoming,
                }
                for peer in self._peers
            ]

        return {'peers': peers}