From c8709a0fa66b08c1956d57cf8aaa4b76a89a4914 Mon Sep 17 00:00:00 2001 From: Alex Moneger Date: Fri, 18 Nov 2016 11:05:07 -0800 Subject: [PATCH] Half working secure re-negotiation - This change does not work fully. It's exploratory work to address issue #75. It's messy and hacky --- examples/tls_client_with_renegotiation.py | 51 +++++++++ scapy_ssl_tls/ssl_tls.py | 79 ++++++++------ scapy_ssl_tls/ssl_tls_crypto.py | 121 ++++++++++++---------- 3 files changed, 169 insertions(+), 82 deletions(-) create mode 100644 examples/tls_client_with_renegotiation.py diff --git a/examples/tls_client_with_renegotiation.py b/examples/tls_client_with_renegotiation.py new file mode 100644 index 0000000..b443120 --- /dev/null +++ b/examples/tls_client_with_renegotiation.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import with_statement +from __future__ import print_function + +try: + # This import works from the project directory + from scapy_ssl_tls.ssl_tls import * +except ImportError: + # If you installed this package via pip, you just need to execute this + from scapy.layers.ssl_tls import * + + +tls_version = TLSVersion.TLS_1_2 +ciphers = [TLSCipherSuite.ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLSCipherSuite.EMPTY_RENEGOTIATION_INFO_SCSV] +extensions = [TLSExtension() / TLSExtRenegotiationInfo(data="")] + + +def tls_client(ip): + with TLSSocket(socket.socket(), client=True) as tls_socket: + try: + tls_socket.connect(ip) + tls_ctx = tls_socket.tls_ctx + except socket.timeout: + print("Failed to open connection to server: %s" % (ip,), file=sys.stderr) + else: + print("Connected to server: %s" % (ip,)) + try: + server_hello, server_kex = tls_socket.do_handshake(tls_version, ciphers, extensions) + client_verify_data = tls_ctx.client_ctx.verify_data + renegotiation = [TLSExtension() / TLSExtRenegotiationInfo(data=client_verify_data)] + # RSA_WITH_AES_128_CBC_SHA DHE_RSA_WITH_AES_256_CBC_SHA256 + server_hello, server_kex = tls_socket.do_secure_renegotiation(tls_version, [TLSCipherSuite.RSA_WITH_AES_128_CBC_SHA], renegotiation) + # client_hello = TLSHandshake() / TLSClientHello(version=tls_version, cipher_suites=, extensions=renegotiation) + # r = tls_socket.do_round_trip(to_raw(client_hello, tls_ctx)) + server_kex.show() + # http_response = tls_socket.do_round_trip(to_raw(TLSPlaintext(data="GET / HTTP/1.1\r\nHOST: localhost\r\n\r\n"), tls_socket.tls_ctx)) + # http_response.show() + except TLSProtocolError as pe: + print(pe) + + +if __name__ == "__main__": + if len(sys.argv) > 2: + server = (sys.argv[1], int(sys.argv[2])) + else: + server = ("127.0.0.1", 8443) + tls_client(server) diff --git a/scapy_ssl_tls/ssl_tls.py b/scapy_ssl_tls/ssl_tls.py index 57d94f1..97cca45 100644 --- a/scapy_ssl_tls/ssl_tls.py +++ b/scapy_ssl_tls/ssl_tls.py @@ -956,11 +956,12 @@ def __getattr__(self, attr): except AttributeError: return getattr(self._s, attr) - def sendall(self, pkt, timeout=2): + def sendall(self, pkt, timeout=2, save=True): prev_timeout = self._s.gettimeout() self._s.settimeout(timeout) self._s.sendall(str(pkt)) - self.tls_ctx.insert(pkt) + if save: + self.tls_ctx.insert(pkt) self._s.settimeout(prev_timeout) def recvall(self, size=8192, timeout=0.5): @@ -995,6 +996,9 @@ def do_handshake(self, version, ciphers, extensions=[]): def do_round_trip(self, pkt, recv=True): return tls_do_round_trip(self, pkt, recv) + def do_secure_renegotiation(self, version, ciphers, extensions=[]): + return tls_do_secure_renegotiation(self, version, ciphers, extensions) + # entry class class SSL(Packet): @@ -1042,6 +1046,7 @@ def do_dissect(self, raw_bytes): payload_len = record(raw_bytes[pos:pos + record_header_len]).length if self.tls_ctx is not None: payload = record(raw_bytes[pos:pos + record_header_len + payload_len], ctx=self.tls_ctx) + payload = self.do_decrypt(payload) self.tls_ctx.insert(payload) else: payload = record(raw_bytes[pos:pos + record_header_len + payload_len]) @@ -1053,6 +1058,23 @@ def do_dissect(self, raw_bytes): # This will always be empty (equivalent to returning "") return raw_bytes[pos:] + def do_decrypt(self, record): + encrypted_payload, layer = self._get_encrypted_payload(record) + if encrypted_payload is not None: + try: + if self.tls_ctx.client: + cleartext = self.tls_ctx.server_ctx.crypto_ctx.decrypt(encrypted_payload, + record.content_type) + else: + cleartext = self.tls_ctx.client_ctx.crypto_ctx.decrypt(encrypted_payload, + record.content_type) + pkt = layer(cleartext, ctx=self.tls_ctx) + record[self.guessed_next_layer].payload = pkt + # Decryption failed, raise error otherwise we'll be in inconsistent state with sender + except ValueError as ve: + raise ValueError("Decryption failed: %s" % ve) + return record + def _get_encrypted_payload(self, record): encrypted_payload = None decrypted_type = None @@ -1074,30 +1096,6 @@ def _get_encrypted_payload(self, record): decrypted_type = TLSPlaintext return encrypted_payload, decrypted_type - def post_dissect(self, s): - if self.tls_ctx is not None: - for record in self.records: - encrypted_payload, layer = self._get_encrypted_payload(record) - if encrypted_payload is not None: - try: - if self.tls_ctx.client: - cleartext = self.tls_ctx.server_ctx.crypto_ctx.decrypt(encrypted_payload, - record.content_type) - else: - cleartext = self.tls_ctx.client_ctx.crypto_ctx.decrypt(encrypted_payload, - record.content_type) - pkt = layer(cleartext, ctx=self.tls_ctx) - original_record = record - record[self.guessed_next_layer].payload = pkt - # If the encrypted is in the history packet list, update it with the unencrypted version - if original_record in self.tls_ctx.history: - record_index = self.tls_ctx.history.index(original_record) - self.tls_ctx.history[record_index] = record - # Decryption failed, raise error otherwise we'll be in inconsistent state with sender - except ValueError as ve: - raise ValueError("Decryption failed: %s" % ve) - return s - TLS = SSL cleartext_handler = {TLSPlaintext: lambda pkt, tls_ctx: (TLSContentType.APPLICATION_DATA, pkt[TLSPlaintext].data), @@ -1105,7 +1103,8 @@ def post_dissect(self, s): str(TLSHandshake(type=TLSHandshakeType.FINISHED) / tls_ctx.get_verify_data())), TLSChangeCipherSpec: lambda pkt, tls_ctx: (TLSContentType.CHANGE_CIPHER_SPEC, str(pkt)), - TLSAlert: lambda pkt, tls_ctx: (TLSContentType.ALERT, str(pkt))} + TLSAlert: lambda pkt, tls_ctx: (TLSContentType.ALERT, str(pkt)), + TLSHandshake: lambda pkt, tls_ctx: (TLSContentType.HANDSHAKE, str(pkt))} def to_raw(pkt, tls_ctx, include_record=True, compress_hook=None, pre_encrypt_hook=None, encrypt_hook=None): @@ -1167,10 +1166,10 @@ def __init__(self, *args, **kwargs): Exception.__init__(self, args[0], **kwargs) -def tls_do_round_trip(tls_socket, pkt, recv=True): +def tls_do_round_trip(tls_socket, pkt, recv=True, save=True): resp = TLS() try: - tls_socket.sendall(pkt) + tls_socket.sendall(pkt, save) if recv: resp = tls_socket.recvall() if resp.haslayer(TLSAlert): @@ -1184,6 +1183,7 @@ def tls_do_round_trip(tls_socket, pkt, recv=True): def tls_do_handshake(tls_socket, version, ciphers, extensions=[]): + print(version, ciphers) client_hello = TLSRecord(version=version) / TLSHandshake() / TLSClientHello(version=version, cipher_suites=ciphers, extensions=extensions) resp1 = tls_do_round_trip(tls_socket, client_hello) @@ -1195,6 +1195,27 @@ def tls_do_handshake(tls_socket, version, ciphers, extensions=[]): return resp1, resp2 +def tls_do_secure_renegotiation(tls_socket, version, ciphers, extensions=[]): + # This is hacky, depends on insertion sequence + tls_ctx = tls_socket.tls_ctx + client_hello = TLSHandshake() / TLSClientHello(version=version, cipher_suites=ciphers, extensions=extensions) + tls_ctx.insert(client_hello) + resp1 = tls_do_round_trip(tls_socket, to_raw(client_hello, tls_ctx), save=False) + + client_key_exchange = TLSHandshake() / tls_ctx.get_client_kex_data() + client_ccs = TLSChangeCipherSpec() + + print(tls_ctx) + tls_do_round_trip(tls_socket, to_raw(client_key_exchange, tls_ctx), False, save=False) + tls_ctx.insert(client_key_exchange) + tls_do_round_trip(tls_socket, to_raw(client_ccs, tls_ctx), False, save=False) + tls_ctx.insert(client_ccs) + + print(tls_ctx) + resp2 = tls_do_round_trip(tls_socket, to_raw(TLSFinished(), tls_socket.tls_ctx)) + return resp1, resp2 + + def tls_fragment_payload(pkt, record=None, size=2**14): if size <= 0: raise ValueError("Fragment size must be strictly positive") diff --git a/scapy_ssl_tls/ssl_tls_crypto.py b/scapy_ssl_tls/ssl_tls_crypto.py index d8b7ac6..78132d2 100644 --- a/scapy_ssl_tls/ssl_tls_crypto.py +++ b/scapy_ssl_tls/ssl_tls_crypto.py @@ -79,6 +79,7 @@ def __init__(self, name): self.nonce = 0 self.random = None self.session_id = None + self.verify_data = None self.crypto_ctx = None self.compression = None self.asym_keystore = tlsk.EmptyAsymKeystore() @@ -105,10 +106,11 @@ def __str__(self): {name}: random: {random} session_id: {sess_id} + verify_data: {verify_data} {asym_ks} {kex_ks} {sym_ks}""" - return template.format(name=self.name, random=repr(self.random), sess_id=repr(self.session_id), + return template.format(name=self.name, random=repr(self.random), sess_id=repr(self.session_id), verify_data=repr(self.verify_data), asym_ks=self.asym_keystore, kex_ks=self.kex_keystore, sym_ks=self.sym_keystore) @@ -139,6 +141,7 @@ def __init__(self, client=True): self.premaster_secret = None self.master_secret = None self.prf = None + self.num_ccs = 0 def __str__(self): template = """ @@ -249,30 +252,28 @@ def __handle_cert_list(self, cert_list): def __handle_server_kex(self, server_kex): # DHE case if server_kex.haslayer(tls.TLSServerDHParams): - if isinstance(self.server_ctx.kex_keystore, tlsk.EmptyKexKeystore): - p = str_to_int(server_kex[tls.TLSServerDHParams].p) - g = str_to_int(server_kex[tls.TLSServerDHParams].g) - public = str_to_int(server_kex[tls.TLSServerDHParams].y_s) - self.server_ctx.kex_keystore = tlsk.DHKeyStore(g, p, public) + p = str_to_int(server_kex[tls.TLSServerDHParams].p) + g = str_to_int(server_kex[tls.TLSServerDHParams].g) + public = str_to_int(server_kex[tls.TLSServerDHParams].y_s) + self.server_ctx.kex_keystore = tlsk.DHKeyStore(g, p, public) elif server_kex.haslayer(tls.TLSServerECDHParams): - if isinstance(self.server_ctx.kex_keystore, tlsk.EmptyKexKeystore): + try: + curve_id = server_kex[tls.TLSServerECDHParams].curve_name + # TODO: DO NOT assume uncompressed EC points! + point = ansi_str_to_point(server_kex[tls.TLSServerECDHParams].p) + curve_name = tls.TLS_ELLIPTIC_CURVES[curve_id] + # Unknown curve case. Just record raw values, but do nothing with them + except KeyError: + self.server_ctx.kex_keystore = tlsk.ECDHKeyStore(None, point) + warnings.warn("Unknown elliptic curve id: %d. Client KEX calculation is up to you" % curve_id) + # We are on a known curve + else: try: - curve_id = server_kex[tls.TLSServerECDHParams].curve_name - # TODO: DO NOT assume uncompressed EC points! - point = ansi_str_to_point(server_kex[tls.TLSServerECDHParams].p) - curve_name = tls.TLS_ELLIPTIC_CURVES[curve_id] - # Unknown curve case. Just record raw values, but do nothing with them - except KeyError: + curve = ec_reg.get_curve(curve_name) + self.server_ctx.kex_keystore = tlsk.ECDHKeyStore(curve, ec.Point(curve, *point)) + except ValueError: self.server_ctx.kex_keystore = tlsk.ECDHKeyStore(None, point) - warnings.warn("Unknown elliptic curve id: %d. Client KEX calculation is up to you" % curve_id) - # We are on a known curve - else: - try: - curve = ec_reg.get_curve(curve_name) - self.server_ctx.kex_keystore = tlsk.ECDHKeyStore(curve, ec.Point(curve, *point)) - except ValueError: - self.server_ctx.kex_keystore = tlsk.ECDHKeyStore(None, point) - warnings.warn("Unsupported elliptic curve: %s" % curve_name) + warnings.warn("Unsupported elliptic curve: %s" % curve_name) else: warnings.warn("Unknown server key exchange") @@ -290,44 +291,43 @@ def __handle_client_kex(self, client_kex): # End workaround self.premaster_secret = PKCS1_v1_5.new(private).decrypt(self.encrypted_premaster_secret, None) elif client_kex.haslayer(tls.TLSClientDHParams): - # Check if we have an unitialized keystore, and if so build a new one - if isinstance(self.client_ctx.kex_keystore, tlsk.EmptyKexKeystore): - server_kex_keystore = self.server_ctx.kex_keystore - # Check if server side is a DH keystore. Something is messed up otherwise - if isinstance(server_kex_keystore, tlsk.DHKeyStore): - client_public = str_to_int(client_kex[tls.TLSClientDHParams].data) - self.client_ctx.kex_keystore = tlsk.DHKeyStore(server_kex_keystore.g, - server_kex_keystore.p, client_public) - else: - raise RuntimeError("Server keystore is not a DH keystore") - # TODO: Calculate PMS + server_kex_keystore = self.server_ctx.kex_keystore + # Check if server side is a DH keystore. Something is messed up otherwise + if isinstance(server_kex_keystore, tlsk.DHKeyStore): + client_public = str_to_int(client_kex[tls.TLSClientDHParams].data) + self.client_ctx.kex_keystore = tlsk.DHKeyStore(server_kex_keystore.g, + server_kex_keystore.p, client_public) + else: + raise RuntimeError("Server keystore is not a DH keystore") + # TODO: Calculate PMS elif client_kex.haslayer(tls.TLSClientECDHParams): - # Check if we have an unitialized keystore, and if so build a new one - if isinstance(self.client_ctx.kex_keystore, tlsk.EmptyKexKeystore): - server_kex_keystore = self.server_ctx.kex_keystore - # Check if server side is a ECDH keystore. Something is messed up otherwise - if isinstance(server_kex_keystore, tlsk.ECDHKeyStore): - curve = server_kex_keystore.curve - point = ansi_str_to_point(client_kex[tls.TLSClientECDHParams].data) - self.client_ctx.kex_keystore = tlsk.ECDHKeyStore(curve, ec.Point(curve, *point)) - # TODO: Calculate PMS + server_kex_keystore = self.server_ctx.kex_keystore + # Check if server side is a ECDH keystore. Something is messed up otherwise + if isinstance(server_kex_keystore, tlsk.ECDHKeyStore): + curve = server_kex_keystore.curve + point = ansi_str_to_point(client_kex[tls.TLSClientECDHParams].data) + self.client_ctx.kex_keystore = tlsk.ECDHKeyStore(curve, ec.Point(curve, *point)) + # TODO: Calculate PMS else: warnings.warn("Unknown client key exchange") + + + def __generate_secrets(self): + self.__generate_client_secrets() + self.__generate_server_secrets() + + def __generate_client_secrets(self): self.sec_params = TLSSecurityParameters.from_pre_master_secret(self.prf, self.negotiated.ciphersuite, self.premaster_secret, self.client_ctx.random, self.server_ctx.random) - self.__generate_secrets() - - def __generate_secrets(self): - if isinstance(self.client_ctx.sym_keystore, tlsk.EmptySymKeyStore): - self.client_ctx.sym_keystore = self.sec_params.client_keystore - if isinstance(self.server_ctx.sym_keystore, tlsk.EmptySymKeyStore): - self.server_ctx.sym_keystore = self.sec_params.server_keystore self.master_secret = self.sec_params.master_secret - # Retrieve ciphers used for client/server encryption and decryption - # TODO: use factory to assign CryptoContext + self.client_ctx.sym_keystore = self.sec_params.client_keystore factory = CryptoContextFactory(self) self.client_ctx.crypto_ctx = factory.new(self.client_ctx) + + def __generate_server_secrets(self): + self.server_ctx.sym_keystore = self.sec_params.server_keystore + factory = CryptoContextFactory(self) self.server_ctx.crypto_ctx = factory.new(self.server_ctx) def _process(self, pkt): @@ -346,15 +346,23 @@ def _process(self, pkt): self.__handle_server_kex(pkt[tls.TLSServerKeyExchange]) if pkt.haslayer(tls.TLSClientKeyExchange): self.__handle_client_kex(pkt[tls.TLSClientKeyExchange]) + if pkt.haslayer(tls.TLSChangeCipherSpec): + # Dirty hack to initialize the crypto store only once per channel + if self.num_ccs % 2 == 0: + self.__generate_client_secrets() + else: + self.__generate_server_secrets() + self.num_ccs += 1 def _generate_random_pms(self, version): return "%s%s" % (struct.pack("!H", version), os.urandom(46)) def get_encrypted_pms(self, pms=None): - cleartext = pms or self.premaster_secret + if pms is not None: + self.premaster_secret = pms public = self.server_ctx.asym_keystore.public if public is not None: - self.encrypted_premaster_secret = PKCS1_v1_5.new(public).encrypt(cleartext) + self.encrypted_premaster_secret = PKCS1_v1_5.new(public).encrypt(self.premaster_secret) else: raise ValueError("Cannot calculate encrypted MS. No server certificate found in connection") return self.encrypted_premaster_secret @@ -411,8 +419,10 @@ def _walk_handshake_msgs(self): def get_verify_data(self, data=None): if self.client: label = TLSPRF.TLS_MD_CLIENT_FINISH_CONST + ctx = self.client_ctx else: label = TLSPRF.TLS_MD_SERVER_FINISH_CONST + ctx = self.server_ctx if data is None: verify_data = [] for handshake in self._walk_handshake_msgs(): @@ -434,6 +444,7 @@ def get_verify_data(self, data=None): "%s%s" % (MD5.new("".join(verify_data)).digest(), SHA.new("".join(verify_data)).digest()), num_bytes=12) + ctx.verify_data = ctx.verify_data or prf_verify_data return prf_verify_data def get_handshake_hash(self, hash_): @@ -1133,6 +1144,10 @@ class TLSSecurityParameters(object): "key_exchange": {"type": ECDHE, "name": tls.TLSKexNames.ECDHE, "sig": ECDSA}, "cipher": {"type": AES, "name": "AES", "key_len": 32, "mode": AES.MODE_CCM, "mode_name": CipherMode.AEAD}, "hash": {"type": NullHash, "name": "NULL"}}, + tls.TLSCipherSuite.DHE_RSA_WITH_AES_256_CBC_SHA256: {"name": tls.TLS_CIPHER_SUITES[0x006b], "export": False, + "key_exchange": {"type": DHE, "name": tls.TLSKexNames.DHE, "sig": RSA}, + "cipher": {"type": AES, "name": "AES", "key_len": 32, "mode": AES.MODE_CBC, "mode_name": CipherMode.CBC}, + "hash": {"type": SHA256, "name": "SHA256"}} # 0x0087: DHE_DSS_WITH_CAMELLIA_256_CBC_SHA => Camelia support should use camcrypt or the camelia patch for pycrypto # 0x0088: DHE_RSA_WITH_CAMELLIA_256_CBC_SHA => Camelia support should use camcrypt or the camelia patch for pycrypto }