diff --git a/src/net/tls.c b/src/net/tls.c index db01fb291..90f9f9767 100644 --- a/src/net/tls.c +++ b/src/net/tls.c @@ -1271,10 +1271,9 @@ static int tls_new_alert ( struct tls_session *tls, const void *data, uint8_t description; char next[0]; } __attribute__ (( packed )) *alert = data; - const void *end = alert->next; /* Sanity check */ - if ( end != ( data + len ) ) { + if ( sizeof ( *alert ) != len ) { DBGC ( tls, "TLS %p received overlength Alert\n", tls ); DBGC_HD ( tls, data, len ); return -EINVAL_ALERT; @@ -1310,24 +1309,28 @@ static int tls_new_server_hello ( struct tls_session *tls, uint16_t version; uint8_t random[32]; uint8_t session_id_len; - char next[0]; + uint8_t session_id[0]; } __attribute__ (( packed )) *hello_a = data; + const uint8_t *session_id; const struct { - uint8_t session_id[hello_a->session_id_len]; uint16_t cipher_suite; uint8_t compression_method; char next[0]; - } __attribute__ (( packed )) *hello_b = ( void * ) &hello_a->next; - const void *end = hello_b->next; + } __attribute__ (( packed )) *hello_b; uint16_t version; int rc; - /* Sanity check */ - if ( end > ( data + len ) ) { + /* Parse header */ + if ( ( sizeof ( *hello_a ) > len ) || + ( hello_a->session_id_len > ( len - sizeof ( *hello_a ) ) ) || + ( sizeof ( *hello_b ) > ( len - sizeof ( *hello_a ) - + hello_a->session_id_len ) ) ) { DBGC ( tls, "TLS %p received underlength Server Hello\n", tls ); DBGC_HD ( tls, data, len ); return -EINVAL_HELLO; } + session_id = hello_a->session_id; + hello_b = ( ( void * ) ( session_id + hello_a->session_id_len ) ); /* Check and store protocol version */ version = ntohs ( hello_a->version ); @@ -1380,14 +1383,7 @@ static int tls_new_server_hello ( struct tls_session *tls, */ static int tls_parse_chain ( struct tls_session *tls, const void *data, size_t len ) { - const void *end = ( data + len ); - const struct { - tls24_t length; - uint8_t data[0]; - } __attribute__ (( packed )) *certificate; - size_t certificate_len; - struct x509_certificate *cert; - const void *next; + size_t remaining = len; int rc; /* Free any existing certificate chain */ @@ -1402,25 +1398,37 @@ static int tls_parse_chain ( struct tls_session *tls, } /* Add certificates to chain */ - while ( data < end ) { + while ( remaining ) { + const struct { + tls24_t length; + uint8_t data[0]; + } __attribute__ (( packed )) *certificate = data; + size_t certificate_len; + size_t record_len; + struct x509_certificate *cert; - /* Extract raw certificate data */ - certificate = data; + /* Parse header */ + if ( sizeof ( *certificate ) > remaining ) { + DBGC ( tls, "TLS %p underlength certificate:\n", tls ); + DBGC_HDA ( tls, 0, data, remaining ); + rc = -EINVAL_CERTIFICATE; + goto err_underlength; + } certificate_len = tls_uint24 ( &certificate->length ); - next = ( certificate->data + certificate_len ); - if ( next > end ) { + if ( certificate_len > ( remaining - sizeof ( *certificate ) )){ DBGC ( tls, "TLS %p overlength certificate:\n", tls ); - DBGC_HDA ( tls, 0, data, ( end - data ) ); + DBGC_HDA ( tls, 0, data, remaining ); rc = -EINVAL_CERTIFICATE; goto err_overlength; } + record_len = ( sizeof ( *certificate ) + certificate_len ); /* Add certificate to chain */ if ( ( rc = x509_append_raw ( tls->chain, certificate->data, certificate_len ) ) != 0 ) { DBGC ( tls, "TLS %p could not append certificate: %s\n", tls, strerror ( rc ) ); - DBGC_HDA ( tls, 0, data, ( end - data ) ); + DBGC_HDA ( tls, 0, data, remaining ); goto err_parse; } cert = x509_last ( tls->chain ); @@ -1428,13 +1436,15 @@ static int tls_parse_chain ( struct tls_session *tls, tls, x509_name ( cert ) ); /* Move to next certificate in list */ - data = next; + data += record_len; + remaining -= record_len; } return 0; err_parse: err_overlength: + err_underlength: x509_chain_put ( tls->chain ); tls->chain = NULL; err_alloc_chain: @@ -1455,12 +1465,18 @@ static int tls_new_certificate ( struct tls_session *tls, tls24_t length; uint8_t certificates[0]; } __attribute__ (( packed )) *certificate = data; - size_t certificates_len = tls_uint24 ( &certificate->length ); - const void *end = ( certificate->certificates + certificates_len ); + size_t certificates_len; int rc; - /* Sanity check */ - if ( end != ( data + len ) ) { + /* Parse header */ + if ( sizeof ( *certificate ) > len ) { + DBGC ( tls, "TLS %p received underlength Server Certificate\n", + tls ); + DBGC_HD ( tls, data, len ); + return -EINVAL_CERTIFICATES; + } + certificates_len = tls_uint24 ( &certificate->length ); + if ( certificates_len > ( len - sizeof ( *certificate ) ) ) { DBGC ( tls, "TLS %p received overlength Server Certificate\n", tls ); DBGC_HD ( tls, data, len ); @@ -1521,11 +1537,10 @@ static int tls_new_server_hello_done ( struct tls_session *tls, const struct { char next[0]; } __attribute__ (( packed )) *hello_done = data; - const void *end = hello_done->next; int rc; /* Sanity check */ - if ( end != ( data + len ) ) { + if ( sizeof ( *hello_done ) != len ) { DBGC ( tls, "TLS %p received overlength Server Hello Done\n", tls ); DBGC_HD ( tls, data, len ); @@ -1557,12 +1572,11 @@ static int tls_new_finished ( struct tls_session *tls, uint8_t verify_data[12]; char next[0]; } __attribute__ (( packed )) *finished = data; - const void *end = finished->next; uint8_t digest_out[ digest->digestsize ]; uint8_t verify_data[ sizeof ( finished->verify_data ) ]; /* Sanity check */ - if ( end != ( data + len ) ) { + if ( sizeof ( *finished ) != len ) { DBGC ( tls, "TLS %p received overlength Finished\n", tls ); DBGC_HD ( tls, data, len ); return -EINVAL_FINISHED; @@ -1598,27 +1612,37 @@ static int tls_new_finished ( struct tls_session *tls, */ static int tls_new_handshake ( struct tls_session *tls, const void *data, size_t len ) { - const void *end = ( data + len ); + size_t remaining = len; int rc; - while ( data != end ) { + while ( remaining ) { const struct { uint8_t type; tls24_t length; uint8_t payload[0]; } __attribute__ (( packed )) *handshake = data; - const void *payload = &handshake->payload; - size_t payload_len = tls_uint24 ( &handshake->length ); - const void *next = ( payload + payload_len ); + const void *payload; + size_t payload_len; + size_t record_len; - /* Sanity check */ - if ( next > end ) { + /* Parse header */ + if ( sizeof ( *handshake ) > remaining ) { + DBGC ( tls, "TLS %p received underlength Handshake\n", + tls ); + DBGC_HD ( tls, data, remaining ); + return -EINVAL_HANDSHAKE; + } + payload_len = tls_uint24 ( &handshake->length ); + if ( payload_len > ( remaining - sizeof ( *handshake ) ) ) { DBGC ( tls, "TLS %p received overlength Handshake\n", tls ); DBGC_HD ( tls, data, len ); return -EINVAL_HANDSHAKE; } + payload = &handshake->payload; + record_len = ( sizeof ( *handshake ) + payload_len ); + /* Handle payload */ switch ( handshake->type ) { case TLS_SERVER_HELLO: rc = tls_new_server_hello ( tls, payload, payload_len ); @@ -1648,16 +1672,15 @@ static int tls_new_handshake ( struct tls_session *tls, * which are explicitly excluded). */ if ( handshake->type != TLS_HELLO_REQUEST ) - tls_add_handshake ( tls, data, - sizeof ( *handshake ) + - payload_len ); + tls_add_handshake ( tls, data, record_len ); /* Abort on failure */ if ( rc != 0 ) return rc; /* Move to next handshake record */ - data = next; + data += record_len; + remaining -= record_len; } return 0;