Skip to content

Commit

Permalink
Made tls_do_handshake take an extension argument
Browse files Browse the repository at this point in the history
- Allows to perform fast handshakes while controlling extensions
- Added a do_round_trip method, to performing send/recv whilst
  wrapping errors
  • Loading branch information
alexmgr committed Nov 10, 2016
1 parent 1125a70 commit 1b047e6
Showing 1 changed file with 40 additions and 13 deletions.
53 changes: 40 additions & 13 deletions scapy_ssl_tls/ssl_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,12 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def do_handshake(self, version, ciphers, extensions=[]):
return tls_do_handshake(self, version, ciphers, extensions)

def do_round_trip(self, pkt, recv=True):
return tls_do_round_trip(self, pkt, recv)


# entry class
class SSL(Packet):
Expand Down Expand Up @@ -1149,23 +1155,44 @@ class TLSProtocolError(Exception):

def __init__(self, *args, **kwargs):
try:
self.pkt = kwargs["pkt"]
except KeyError:
self.pkt = None
Exception.__init__(self, *args, **kwargs)
self.response = args[2]
except IndexError:
self.response = kwargs.get("response", TLS())

try:
self.request = args[1]
except IndexError:
self.request = kwargs.get("request", TLS())

Exception.__init__(self, args[0], **kwargs)


def tls_do_round_trip(tls_socket, pkt, recv=True):
resp = TLS()
try:
tls_socket.sendall(pkt)
if recv:
resp = tls_socket.recvall()
if resp.haslayer(TLSAlert):
alert = resp[TLSAlert]
level = TLS_ALERT_LEVELS.get(alert.level, "unknown")
description = TLS_ALERT_DESCRIPTIONS.get(alert.description, "unknown description")
raise TLSProtocolError("%s alert returned by server: %s" % (level.upper(), description.upper()), pkt, resp)
except socket.error as se:
raise TLSProtocolError(se, pkt, resp)
return resp


def tls_do_handshake(tls_socket, version, ciphers, extensions=[]):
client_hello = TLSRecord(version=version) / TLSHandshake() / TLSClientHello(version=version, cipher_suites=ciphers, extensions=extensions)
resp1 = tls_do_round_trip(tls_socket, client_hello)

def tls_do_handshake(tls_socket, version, ciphers):
client_hello = TLSRecord(version=version) / TLSHandshake() / TLSClientHello(version=version, cipher_suites=ciphers)
tls_socket.sendall(client_hello)
r = tls_socket.recvall()
if r.haslayer(TLSAlert):
raise TLSProtocolError("Alert returned by server", r)
client_key_exchange = TLSRecord(version=version) / TLSHandshake() / tls_socket.tls_ctx.get_client_kex_data()
client_ccs = TLSRecord(version=version) / TLSChangeCipherSpec()
tls_socket.sendall(TLS.from_records([client_key_exchange, client_ccs]))
tls_socket.sendall(to_raw(TLSFinished(), tls_socket.tls_ctx))
tls_socket.recvall()
tls_do_round_trip(tls_socket, TLS.from_records([client_key_exchange, client_ccs]), False)

resp2 = tls_do_round_trip(tls_socket, to_raw(TLSFinished(), tls_socket.tls_ctx))
return resp1, resp2


def tls_fragment_payload(pkt, record=None, size=2**14):
Expand Down

0 comments on commit 1b047e6

Please sign in to comment.