From 9645e87961dc68effa8bb7b48444a84b35cc9b15 Mon Sep 17 00:00:00 2001
From: tecnovert <tecnovert@tecnovert.net>
Date: Thu, 11 May 2023 23:45:06 +0200
Subject: [PATCH] protocol: Sign for key halves when not swapping XMR

---
 basicswap/basicswap.py          | 74 ++++++++++++++++++++-------------
 basicswap/interface/__init__.py | 13 ++++++
 basicswap/interface/btc.py      | 17 ++++++++
 basicswap/interface/xmr.py      |  6 +++
 4 files changed, 82 insertions(+), 28 deletions(-)

diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py
index 389fb6a..4721d31 100644
--- a/basicswap/basicswap.py
+++ b/basicswap/basicswap.py
@@ -30,6 +30,7 @@ from typing import Optional
 from sqlalchemy.orm import sessionmaker, scoped_session
 from sqlalchemy.orm.session import close_all_sessions
 
+from .interface import Curves
 from .interface.part import PARTInterface, PARTInterfaceAnon, PARTInterfaceBlind
 from .interface.btc import BTCInterface
 from .interface.ltc import LTCInterface
@@ -2438,7 +2439,7 @@ class BasicSwap(BaseApp):
             xmr_swap.contract_count = self.getNewContractId()
             xmr_swap.dest_af = msg_buf.dest_af
 
-            for_ed25519 = True if coin_to == Coins.XMR else False
+            for_ed25519 = True if ci_to.curve_type() == Curves.ed25519 else False
             kbvf = self.getPathKey(coin_from, coin_to, bid_created_at, xmr_swap.contract_count, KeyTypes.KBVF, for_ed25519)
             kbsf = self.getPathKey(coin_from, coin_to, bid_created_at, xmr_swap.contract_count, KeyTypes.KBSF, for_ed25519)
 
@@ -2450,19 +2451,26 @@ class BasicSwap(BaseApp):
 
             xmr_swap.pkaf = ci_from.getPubkey(kaf)
 
-            if coin_to == Coins.XMR:
+            if ci_to.curve_type() == Curves.ed25519:
                 xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf)
+                xmr_swap.pkasf = xmr_swap.kbsf_dleag[0: 33]
+                msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[:16000]
+            elif ci_to.curve_type() == Curves.secp256k1:
+                for i in range(10):
+                    xmr_swap.kbsf_dleag = ci_to.signRecoverable(kbsf, 'proof kbsf owned for swap')
+                    pk_recovered = ci_to.verifySigAndRecover(xmr_swap.kbsf_dleag, 'proof kbsf owned for swap')
+                    if pk_recovered == xmr_swap.pkbsf:
+                        break
+                    self.log.debug('kbsl recovered pubkey mismatch, retrying.')
+                assert (pk_recovered == xmr_swap.pkbsf)
+                xmr_swap.pkasf = xmr_swap.pkbsf
+                msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag
             else:
-                xmr_swap.kbsf_dleag = xmr_swap.pkbsf
-            xmr_swap.pkasf = xmr_swap.kbsf_dleag[0: 33]
+                raise ValueError('Unknown curve')
             assert (xmr_swap.pkasf == ci_from.getPubkey(kbsf))
 
             msg_buf.pkaf = xmr_swap.pkaf
             msg_buf.kbvf = kbvf
-            if coin_to == Coins.XMR:
-                msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[:16000]
-            else:
-                msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag
 
             bid_bytes = msg_buf.SerializeToString()
             payload_hex = str.format('{:02x}', MessageTypes.XMR_BID_FL) + bid_bytes.hex()
@@ -2472,7 +2480,7 @@ class BasicSwap(BaseApp):
             msg_valid = max(self.SMSG_SECONDS_IN_HOUR * 1, valid_for_seconds)
             xmr_swap.bid_id = self.sendSmsg(bid_addr, offer.addr_from, payload_hex, msg_valid)
 
-            if coin_to == Coins.XMR:
+            if ci_to.curve_type() == Curves.ed25519:
                 msg_buf2 = XmrSplitMessage(
                     msg_id=xmr_swap.bid_id,
                     msg_type=XmrSplitMsgTypes.BID,
@@ -2562,7 +2570,7 @@ class BasicSwap(BaseApp):
             if xmr_swap.contract_count is None:
                 xmr_swap.contract_count = self.getNewContractId()
 
-            for_ed25519 = True if coin_to == Coins.XMR else False
+            for_ed25519 = True if ci_to.curve_type() == Curves.ed25519 else False
             kbvl = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KBVL, for_ed25519)
             kbsl = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KBSL, for_ed25519)
 
@@ -2579,11 +2587,6 @@ class BasicSwap(BaseApp):
 
             xmr_swap.pkal = ci_from.getPubkey(kal)
 
-            if coin_to == Coins.XMR:
-                xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl)
-            else:
-                xmr_swap.kbsl_dleag = xmr_swap.pkbsl
-
             # MSG2F
             pi = self.pi(SwapTypes.XMR_SWAP)
             xmr_swap.a_lock_tx_script = pi.genScriptLockTxScript(ci_from, xmr_swap.pkal, xmr_swap.pkaf)
@@ -2660,10 +2663,21 @@ class BasicSwap(BaseApp):
             msg_buf.bid_msg_id = bid_id
             msg_buf.pkal = xmr_swap.pkal
             msg_buf.kbvl = kbvl
-            if coin_to == Coins.XMR:
+
+            if ci_to.curve_type() == Curves.ed25519:
+                xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl)
                 msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[:16000]
-            else:
+            elif ci_to.curve_type() == Curves.secp256k1:
+                for i in range(10):
+                    xmr_swap.kbsl_dleag = ci_to.signRecoverable(kbsl, 'proof kbsl owned for swap')
+                    pk_recovered = ci_to.verifySigAndRecover(xmr_swap.kbsl_dleag, 'proof kbsl owned for swap')
+                    if pk_recovered == xmr_swap.pkbsl:
+                        break
+                    self.log.debug('kbsl recovered pubkey mismatch, retrying.')
+                assert (pk_recovered == xmr_swap.pkbsl)
                 msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag
+            else:
+                raise ValueError('Unknown curve')
 
             # MSG2F
             msg_buf.a_lock_tx = xmr_swap.a_lock_tx
@@ -2680,7 +2694,7 @@ class BasicSwap(BaseApp):
             bid.accept_msg_id = self.sendSmsg(offer.addr_from, bid.bid_addr, payload_hex, msg_valid)
             xmr_swap.bid_accept_msg_id = bid.accept_msg_id
 
-            if coin_to == Coins.XMR:
+            if ci_to.curve_type() == Curves.ed25519:
                 msg_buf2 = XmrSplitMessage(
                     msg_id=bid_id,
                     msg_type=XmrSplitMsgTypes.BID_ACCEPT,
@@ -4496,7 +4510,7 @@ class BasicSwap(BaseApp):
         ci_from = self.ci(Coins(offer.coin_from))
         ci_to = self.ci(Coins(offer.coin_to))
 
-        if offer.coin_to == Coins.XMR:
+        if ci_to.curve_type() == Curves.ed25519:
             if len(xmr_swap.kbsf_dleag) < ci_to.lengthDLEAG():
                 q = session.query(XmrSplitData).filter(sa.and_(XmrSplitData.bid_id == bid.bid_id, XmrSplitData.msg_type == XmrSplitMsgTypes.BID)).order_by(XmrSplitData.msg_sequence.asc())
                 for row in q:
@@ -4514,11 +4528,13 @@ class BasicSwap(BaseApp):
             xmr_swap.pkbsf = xmr_swap.kbsf_dleag[33: 33 + 32]
             if not ci_to.verifyPubkey(xmr_swap.pkbsf):
                 raise ValueError('Invalid coin b pubkey.')
-        else:
-            xmr_swap.pkasf = xmr_swap.kbsf_dleag[0: 33]
+        elif ci_to.curve_type() == Curves.secp256k1:
+            xmr_swap.pkasf = ci_to.verifySigAndRecover(xmr_swap.kbsf_dleag, 'proof kbsf owned for swap')
             if not ci_from.verifyPubkey(xmr_swap.pkasf):
                 raise ValueError('Invalid coin a pubkey.')
             xmr_swap.pkbsf = xmr_swap.pkasf
+        else:
+            raise ValueError('Unknown curve')
 
         ensure(ci_to.verifyKey(xmr_swap.vkbvf), 'Invalid key, vkbvf')
         ensure(ci_from.verifyPubkey(xmr_swap.pkaf), 'Invalid pubkey, pkaf')
@@ -4548,7 +4564,7 @@ class BasicSwap(BaseApp):
         ci_from = self.ci(offer.coin_from)
         ci_to = self.ci(offer.coin_to)
 
-        if offer.coin_to == Coins.XMR:
+        if ci_to.curve_type() == Curves.ed25519:
             if len(xmr_swap.kbsl_dleag) < ci_to.lengthDLEAG():
                 q = session.query(XmrSplitData).filter(sa.and_(XmrSplitData.bid_id == bid.bid_id, XmrSplitData.msg_type == XmrSplitMsgTypes.BID_ACCEPT)).order_by(XmrSplitData.msg_sequence.asc())
                 for row in q:
@@ -4565,11 +4581,13 @@ class BasicSwap(BaseApp):
             xmr_swap.pkbsl = xmr_swap.kbsl_dleag[33: 33 + 32]
             if not ci_to.verifyPubkey(xmr_swap.pkbsl):
                 raise ValueError('Invalid coin b pubkey.')
-        else:
-            xmr_swap.pkasl = xmr_swap.kbsl_dleag[0: 33]
+        elif ci_to.curve_type() == Curves.secp256k1:
+            xmr_swap.pkasl = ci_to.verifySigAndRecover(xmr_swap.kbsl_dleag, 'proof kbsl owned for swap')
             if not ci_from.verifyPubkey(xmr_swap.pkasl):
                 raise ValueError('Invalid coin a pubkey.')
             xmr_swap.pkbsl = xmr_swap.pkasl
+        else:
+            raise ValueError('Unknown curve')
 
         # vkbv and vkbvl are verified in processXmrBidAccept
         xmr_swap.pkbv = ci_to.sumPubkeys(xmr_swap.pkbvl, xmr_swap.pkbvf)
@@ -5023,7 +5041,7 @@ class BasicSwap(BaseApp):
         ci_from = self.ci(coin_from)
         ci_to = self.ci(coin_to)
 
-        for_ed25519 = True if coin_to == Coins.XMR else False
+        for_ed25519 = True if ci_to.curve_type() == Curves.ed25519 else False
         kbsf = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KBSF, for_ed25519)
         kaf = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KAF)
 
@@ -5085,7 +5103,7 @@ class BasicSwap(BaseApp):
             kbsf = ci_from.recoverEncKey(xmr_swap.al_lock_spend_tx_esig, xmr_swap.al_lock_spend_tx_sig, xmr_swap.pkasf)
             assert (kbsf is not None)
 
-            for_ed25519 = True if coin_to == Coins.XMR else False
+            for_ed25519 = True if ci_to.curve_type() == Curves.ed25519 else False
             kbsl = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KBSL, for_ed25519)
             vkbs = ci_to.sumKeys(kbsl, kbsf)
 
@@ -5144,7 +5162,7 @@ class BasicSwap(BaseApp):
         kbsl = ci_from.recoverEncKey(xmr_swap.af_lock_refund_spend_tx_esig, af_lock_refund_spend_tx_sig, xmr_swap.pkasl)
         assert (kbsl is not None)
 
-        for_ed25519 = True if coin_to == Coins.XMR else False
+        for_ed25519 = True if ci_to.curve_type() == Curves.ed25519 else False
         kbsf = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KBSF, for_ed25519)
         vkbs = ci_to.sumKeys(kbsl, kbsf)
 
@@ -5245,7 +5263,7 @@ class BasicSwap(BaseApp):
             xmr_swap.af_lock_refund_spend_tx_esig = msg_data.af_lock_refund_spend_tx_esig
             xmr_swap.af_lock_refund_tx_sig = msg_data.af_lock_refund_tx_sig
 
-            for_ed25519 = True if coin_to == Coins.XMR else False
+            for_ed25519 = True if ci_to.curve_type() == Curves.ed25519 else False
             kbsl = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KBSL, for_ed25519)
             kal = self.getPathKey(coin_from, coin_to, bid.created_at, xmr_swap.contract_count, KeyTypes.KAL)
 
diff --git a/basicswap/interface/__init__.py b/basicswap/interface/__init__.py
index e69de29..bb1bde0 100644
--- a/basicswap/interface/__init__.py
+++ b/basicswap/interface/__init__.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2023 tecnovert
+# Distributed under the MIT software license, see the accompanying
+# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
+
+from enum import IntEnum
+
+
+class Curves(IntEnum):
+    secp256k1 = 1
+    ed25519 = 2
diff --git a/basicswap/interface/btc.py b/basicswap/interface/btc.py
index 6a3efa4..82af90b 100644
--- a/basicswap/interface/btc.py
+++ b/basicswap/interface/btc.py
@@ -14,6 +14,8 @@ from io import BytesIO
 
 from basicswap.contrib.test_framework import segwit_addr
 
+from basicswap.interface import (
+    Curves)
 from basicswap.util import (
     ensure,
     make_int,
@@ -109,6 +111,10 @@ def find_vout_for_address_from_txobj(tx_obj, addr: str) -> int:
 
 
 class BTCInterface(CoinInterface):
+    @staticmethod
+    def curve_type():
+        return Curves.secp256k1
+
     @staticmethod
     def coin_type():
         return Coins.BTC
@@ -1167,12 +1173,23 @@ class BTCInterface(CoinInterface):
         privkey = PrivateKey(k)
         return privkey.sign_recoverable(message_hash, hasher=None)[:64]
 
+    def signRecoverable(self, k, message):
+        message_hash = hashlib.sha256(bytes(message, 'utf-8')).digest()
+
+        privkey = PrivateKey(k)
+        return privkey.sign_recoverable(message_hash, hasher=None)
+
     def verifyCompactSig(self, K, message, sig):
         message_hash = hashlib.sha256(bytes(message, 'utf-8')).digest()
         pubkey = PublicKey(K)
         rv = pubkey.verify_compact(sig, message_hash, hasher=None)
         assert (rv is True)
 
+    def verifySigAndRecover(self, sig, message):
+        message_hash = hashlib.sha256(bytes(message, 'utf-8')).digest()
+        pubkey = PublicKey.from_signature_and_message(sig, message_hash, hasher=None)
+        return pubkey.format()
+
     def verifyMessage(self, address: str, message: str, signature: str, message_magic: str = None) -> bool:
         if message_magic is None:
             message_magic = self.chainparams()['message_magic']
diff --git a/basicswap/interface/xmr.py b/basicswap/interface/xmr.py
index 36135a1..a0e7848 100644
--- a/basicswap/interface/xmr.py
+++ b/basicswap/interface/xmr.py
@@ -24,6 +24,8 @@ from coincurve.dleag import (
     verify_ed25519_point,
 )
 
+from basicswap.interface import (
+    Curves)
 from basicswap.util import (
     dumpj,
     ensure,
@@ -38,6 +40,10 @@ from basicswap.chainparams import XMR_COIN, CoinInterface, Coins
 
 
 class XMRInterface(CoinInterface):
+    @staticmethod
+    def curve_type():
+        return Curves.ed25519
+
     @staticmethod
     def coin_type():
         return Coins.XMR