Skip to content

Commit

Permalink
allow use of RSA and ECDSA together
Browse files Browse the repository at this point in the history
  • Loading branch information
tomato42 committed Nov 6, 2019
1 parent 2408a3b commit ca065ca
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 25 deletions.
101 changes: 101 additions & 0 deletions tests/tlstest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
AlertDescription, HTTPTLSConnection, TLSSocketServerMixIn, \
POP3_TLS, m2cryptoLoaded, pycryptoLoaded, gmpyLoaded, tackpyLoaded, \
Checker, __version__
from tlslite.handshakesettings import VirtualHost, Keypair

from tlslite.errors import *
from tlslite.utils.cryptomath import prngName, getRandomBytes
Expand Down Expand Up @@ -360,6 +361,84 @@ def connect():

test_no += 1

print("Test {0} - good RSA and ECDSA, TLSv1.3, rsa"
.format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
connection.handshakeClientCert(settings=settings)
testConnClient(connection)
assert connection.session.cipherSuite in\
constants.CipherSuite.tls13Suites
assert isinstance(connection.session.serverCertChain, X509CertChain)
assert connection.session.serverCertChain.getEndEntityPublicKey().key_type\
== "rsa"
assert connection.version == (3, 4)
connection.close()

test_no += 1

print("Test {0} - good RSA and ECDSA, TLSv1.3, ecdsa"
.format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
settings.rsaSigHashes = []
connection.handshakeClientCert(settings=settings)
testConnClient(connection)
assert connection.session.cipherSuite in\
constants.CipherSuite.tls13Suites
assert isinstance(connection.session.serverCertChain, X509CertChain)
assert connection.session.serverCertChain.getEndEntityPublicKey().key_type\
== "ecdsa"
assert connection.version == (3, 4)
connection.close()

test_no += 1

print("Test {0} - good RSA and ECDSA, TLSv1.2, rsa"
.format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 3)
settings.maxVersion = (3, 3)
connection.handshakeClientCert(settings=settings)
testConnClient(connection)
assert connection.session.cipherSuite in\
constants.CipherSuite.ecdheCertSuites, connection.session.cipherSuite
assert isinstance(connection.session.serverCertChain, X509CertChain)
assert connection.session.serverCertChain.getEndEntityPublicKey().key_type\
== "rsa"
assert connection.version == (3, 3)
connection.close()

test_no += 1

print("Test {0} - good RSA and ECDSA, TLSv1.2, ecdsa"
.format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 3)
settings.maxVersion = (3, 3)
settings.rsaSigHashes = []
connection.handshakeClientCert(settings=settings)
testConnClient(connection)
assert connection.session.cipherSuite in\
constants.CipherSuite.ecdheEcdsaSuites, connection.session.cipherSuite
assert isinstance(connection.session.serverCertChain, X509CertChain)
assert connection.session.serverCertChain.getEndEntityPublicKey().key_type\
== "ecdsa"
assert connection.version == (3, 3)
connection.close()

test_no += 1

print("Test {0} - good X.509, mismatched key_share".format(test_no))
synchro.recv(1)
connection = connect()
Expand Down Expand Up @@ -1502,6 +1581,28 @@ def connect():

test_no += 1

for prot in ["TLSv1.3", "TLSv1.2"]:
for c_type, exp_chain in (("rsa", x509Chain),
("ecdsa", x509ecdsaChain)):
print("Test {0} - good RSA and ECDSA, {2}, {1}"
.format(test_no, c_type, prot))
synchro.send(b'R')
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 3)
settings.maxVersion = (3, 4)
v_host = VirtualHost()
v_host.keys = [Keypair(x509ecdsaKey, x509ecdsaChain.x509List)]
settings.virtual_hosts = [v_host]
connection.handshakeServer(certChain=x509Chain,
privateKey=x509Key, settings=settings)
assert connection.extendedMasterSecret
assert connection.session.serverCertChain == exp_chain
testConnServer(connection)
connection.close()

test_no += 1

print("Test {0} - good X.509, mismatched key_share".format(test_no))
synchro.send(b'R')
connection = connect()
Expand Down
52 changes: 27 additions & 25 deletions tlslite/tlsconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3136,20 +3136,18 @@ def _serverGetClientHello(self, settings, private_key, cert_chain,
for result in self._sendMsg(alert):
yield result

sig_scheme = None
if version >= (3, 4):
try:
sig_scheme, cert_chain, private_key = \
self._pickServerKeyExchangeSig(settings,
clientHello,
cert_chain,
private_key,
version)
except TLSHandshakeFailure as alert:
for result in self._sendError(
AlertDescription.handshake_failure,
str(alert)):
yield result
try:
sig_scheme, cert_chain, private_key = \
self._pickServerKeyExchangeSig(settings,
clientHello,
cert_chain,
private_key,
version)
except TLSHandshakeFailure as alert:
for result in self._sendError(
AlertDescription.handshake_failure,
str(alert)):
yield result

#Check if there's intersection between supported curves by client and
#server
Expand Down Expand Up @@ -4121,19 +4119,23 @@ def _pickServerKeyExchangeSig(settings, clientHello, certList=None,
# sha1 should be picked
return "sha1", certList, private_key

supported = TLSConnection._sigHashesToList(settings,
certList=certList,
version=version)
alt_certs = ((X509CertChain(i.certificates), i.key) for vh in
settings.virtual_hosts for i in vh.keys)

for certs, key in chain([(certList, private_key)], alt_certs):
supported = TLSConnection._sigHashesToList(settings,
certList=certs,
version=version)

for schemeID in supported:
if schemeID in hashAndAlgsExt.sigalgs:
name = SignatureScheme.toRepr(schemeID)
if not name and schemeID[1] in (SignatureAlgorithm.rsa,
SignatureAlgorithm.ecdsa):
name = HashAlgorithm.toRepr(schemeID[0])
for schemeID in supported:
if schemeID in hashAndAlgsExt.sigalgs:
name = SignatureScheme.toRepr(schemeID)
if not name and schemeID[1] in (SignatureAlgorithm.rsa,
SignatureAlgorithm.ecdsa):
name = HashAlgorithm.toRepr(schemeID[0])

if name:
return name, certList, private_key
if name:
return name, certs, key

# if no match, we must abort per RFC 5246
raise TLSHandshakeFailure("No common signature algorithms")
Expand Down

0 comments on commit ca065ca

Please sign in to comment.