From 1b047e62f4af6467042af846e55eb0b5e2db1530 Mon Sep 17 00:00:00 2001 From: Alex Moneger Date: Wed, 9 Nov 2016 16:48:01 -0800 Subject: [PATCH] Made tls_do_handshake take an extension argument - Allows to perform fast handshakes while controlling extensions - Added a do_round_trip method, to performing send/recv whilst wrapping errors --- scapy_ssl_tls/ssl_tls.py | 53 ++++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/scapy_ssl_tls/ssl_tls.py b/scapy_ssl_tls/ssl_tls.py index e213ad2..57d94f1 100644 --- a/scapy_ssl_tls/ssl_tls.py +++ b/scapy_ssl_tls/ssl_tls.py @@ -989,6 +989,12 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() + def do_handshake(self, version, ciphers, extensions=[]): + return tls_do_handshake(self, version, ciphers, extensions) + + def do_round_trip(self, pkt, recv=True): + return tls_do_round_trip(self, pkt, recv) + # entry class class SSL(Packet): @@ -1149,23 +1155,44 @@ class TLSProtocolError(Exception): def __init__(self, *args, **kwargs): try: - self.pkt = kwargs["pkt"] - except KeyError: - self.pkt = None - Exception.__init__(self, *args, **kwargs) + self.response = args[2] + except IndexError: + self.response = kwargs.get("response", TLS()) + + try: + self.request = args[1] + except IndexError: + self.request = kwargs.get("request", TLS()) + Exception.__init__(self, args[0], **kwargs) + + +def tls_do_round_trip(tls_socket, pkt, recv=True): + resp = TLS() + try: + tls_socket.sendall(pkt) + if recv: + resp = tls_socket.recvall() + if resp.haslayer(TLSAlert): + alert = resp[TLSAlert] + level = TLS_ALERT_LEVELS.get(alert.level, "unknown") + description = TLS_ALERT_DESCRIPTIONS.get(alert.description, "unknown description") + raise TLSProtocolError("%s alert returned by server: %s" % (level.upper(), description.upper()), pkt, resp) + except socket.error as se: + raise TLSProtocolError(se, pkt, resp) + return resp + + +def tls_do_handshake(tls_socket, version, ciphers, extensions=[]): + client_hello = TLSRecord(version=version) / TLSHandshake() / TLSClientHello(version=version, cipher_suites=ciphers, extensions=extensions) + resp1 = tls_do_round_trip(tls_socket, client_hello) -def tls_do_handshake(tls_socket, version, ciphers): - client_hello = TLSRecord(version=version) / TLSHandshake() / TLSClientHello(version=version, cipher_suites=ciphers) - tls_socket.sendall(client_hello) - r = tls_socket.recvall() - if r.haslayer(TLSAlert): - raise TLSProtocolError("Alert returned by server", r) client_key_exchange = TLSRecord(version=version) / TLSHandshake() / tls_socket.tls_ctx.get_client_kex_data() client_ccs = TLSRecord(version=version) / TLSChangeCipherSpec() - tls_socket.sendall(TLS.from_records([client_key_exchange, client_ccs])) - tls_socket.sendall(to_raw(TLSFinished(), tls_socket.tls_ctx)) - tls_socket.recvall() + tls_do_round_trip(tls_socket, TLS.from_records([client_key_exchange, client_ccs]), False) + + resp2 = tls_do_round_trip(tls_socket, to_raw(TLSFinished(), tls_socket.tls_ctx)) + return resp1, resp2 def tls_fragment_payload(pkt, record=None, size=2**14):