From 553bc9b478f48580c6c51ddaa65c906cac0ee4e7 Mon Sep 17 00:00:00 2001
From: jsing <>
Date: Sat, 4 Dec 2021 14:03:22 +0000
Subject: Clean up and refactor server side DHE key exchange.

Provide ssl_kex_generate_dhe_params_auto() which handles DHE key generation
based on parameters determined by the specified key bits. Convert the
existing DHE auto parameter selection code into a function that just tells
us how many key bits to use.

Untangle and rework the server side DHE key exchange to use the ssl_kex_*
functions.

ok inoguchi@ tb@
---
 src/lib/libssl/ssl_kex.c  |  48 ++++++++++++++++-
 src/lib/libssl/ssl_lib.c  |  50 +++++------------
 src/lib/libssl/ssl_locl.h |   5 +-
 src/lib/libssl/ssl_srvr.c | 133 +++++++++++++++++++---------------------------
 4 files changed, 116 insertions(+), 120 deletions(-)

diff --git a/src/lib/libssl/ssl_kex.c b/src/lib/libssl/ssl_kex.c
index 639981bec9..78b528b168 100644
--- a/src/lib/libssl/ssl_kex.c
+++ b/src/lib/libssl/ssl_kex.c
@@ -1,6 +1,6 @@
-/* $OpenBSD: ssl_kex.c,v 1.7 2021/12/04 13:50:35 jsing Exp $ */
+/* $OpenBSD: ssl_kex.c,v 1.8 2021/12/04 14:03:22 jsing Exp $ */
 /*
- * Copyright (c) 2020 Joel Sing <jsing@openbsd.org>
+ * Copyright (c) 2020, 2021 Joel Sing <jsing@openbsd.org>
  *
  * Permission to use, copy, modify, and distribute this software for any
  * purpose with or without fee is hereby granted, provided that the above
@@ -17,6 +17,7 @@
 
 #include <stdlib.h>
 
+#include <openssl/bn.h>
 #include <openssl/dh.h>
 #include <openssl/ec.h>
 #include <openssl/ecdh.h>
@@ -40,7 +41,50 @@ ssl_kex_generate_dhe(DH *dh, DH *dh_params)
 
 	if (!DH_set0_pqg(dh, p, NULL, g))
 		goto err;
+	p = NULL;
+	g = NULL;
+
+	if (!DH_generate_key(dh))
+		goto err;
+
+	ret = 1;
+
+ err:
+	BN_free(p);
+	BN_free(g);
+
+	return ret;
+}
 
+int
+ssl_kex_generate_dhe_params_auto(DH *dh, size_t key_bits)
+{
+	BIGNUM *p = NULL, *g = NULL;
+	int ret = 0;
+
+	if (key_bits >= 8192)
+		p = get_rfc3526_prime_8192(NULL);
+	else if (key_bits >= 4096)
+		p = get_rfc3526_prime_4096(NULL);
+	else if (key_bits >= 3072)
+		p = get_rfc3526_prime_3072(NULL);
+	else if (key_bits >= 2048)
+		p = get_rfc3526_prime_2048(NULL);
+	else if (key_bits >= 1536)
+		p = get_rfc3526_prime_1536(NULL);
+	else
+		p = get_rfc2409_prime_1024(NULL);
+
+	if (p == NULL)
+		goto err;
+
+	if ((g = BN_new()) == NULL)
+		goto err;
+	if (!BN_set_word(g, 2))
+		goto err;
+
+	if (!DH_set0_pqg(dh, p, NULL, g))
+		goto err;
 	p = NULL;
 	g = NULL;
 
diff --git a/src/lib/libssl/ssl_lib.c b/src/lib/libssl/ssl_lib.c
index 662013378e..a0d3d05775 100644
--- a/src/lib/libssl/ssl_lib.c
+++ b/src/lib/libssl/ssl_lib.c
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssl_lib.c,v 1.279 2021/11/14 22:31:29 tb Exp $ */
+/* $OpenBSD: ssl_lib.c,v 1.280 2021/12/04 14:03:22 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -147,7 +147,6 @@
 #include <limits.h>
 #include <stdio.h>
 
-#include <openssl/bn.h>
 #include <openssl/dh.h>
 #include <openssl/lhash.h>
 #include <openssl/objects.h>
@@ -2319,54 +2318,29 @@ ssl_get_sign_pkey(SSL *s, const SSL_CIPHER *cipher, const EVP_MD **pmd,
 	return (pkey);
 }
 
-DH *
-ssl_get_auto_dh(SSL *s)
+size_t
+ssl_dhe_params_auto_key_bits(SSL *s)
 {
 	CERT_PKEY *cpk;
-	int keylen;
-	DH *dhp;
+	int key_bits;
 
 	if (s->cert->dh_tmp_auto == 2) {
-		keylen = 1024;
+		key_bits = 1024;
 	} else if (S3I(s)->hs.cipher->algorithm_auth & SSL_aNULL) {
-		keylen = 1024;
+		key_bits = 1024;
 		if (S3I(s)->hs.cipher->strength_bits == 256)
-			keylen = 3072;
+			key_bits = 3072;
 	} else {
 		if ((cpk = ssl_get_server_send_pkey(s)) == NULL)
-			return (NULL);
+			return 0;
 		if (cpk->privatekey == NULL ||
 		    EVP_PKEY_get0_RSA(cpk->privatekey) == NULL)
-			return (NULL);
-		if ((keylen = EVP_PKEY_bits(cpk->privatekey)) <= 0)
-			return (NULL);
+			return 0;
+		if ((key_bits = EVP_PKEY_bits(cpk->privatekey)) <= 0)
+			return 0;
 	}
 
-	if ((dhp = DH_new()) == NULL)
-		return (NULL);
-
-	dhp->g = BN_new();
-	if (dhp->g != NULL)
-		BN_set_word(dhp->g, 2);
-
-	if (keylen >= 8192)
-		dhp->p = get_rfc3526_prime_8192(NULL);
-	else if (keylen >= 4096)
-		dhp->p = get_rfc3526_prime_4096(NULL);
-	else if (keylen >= 3072)
-		dhp->p = get_rfc3526_prime_3072(NULL);
-	else if (keylen >= 2048)
-		dhp->p = get_rfc3526_prime_2048(NULL);
-	else if (keylen >= 1536)
-		dhp->p = get_rfc3526_prime_1536(NULL);
-	else
-		dhp->p = get_rfc2409_prime_1024(NULL);
-
-	if (dhp->p == NULL || dhp->g == NULL) {
-		DH_free(dhp);
-		return (NULL);
-	}
-	return (dhp);
+	return key_bits;
 }
 
 static int
diff --git a/src/lib/libssl/ssl_locl.h b/src/lib/libssl/ssl_locl.h
index 0051989ea0..d53c9ec273 100644
--- a/src/lib/libssl/ssl_locl.h
+++ b/src/lib/libssl/ssl_locl.h
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssl_locl.h,v 1.371 2021/12/04 13:50:35 jsing Exp $ */
+/* $OpenBSD: ssl_locl.h,v 1.372 2021/12/04 14:03:22 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -1343,7 +1343,7 @@ int ssl_undefined_const_function(const SSL *s);
 CERT_PKEY *ssl_get_server_send_pkey(const SSL *s);
 EVP_PKEY *ssl_get_sign_pkey(SSL *s, const SSL_CIPHER *c, const EVP_MD **pmd,
     const struct ssl_sigalg **sap);
-DH *ssl_get_auto_dh(SSL *s);
+size_t ssl_dhe_params_auto_key_bits(SSL *s);
 int ssl_cert_type(X509 *x, EVP_PKEY *pkey);
 void ssl_set_cert_masks(CERT *c, const SSL_CIPHER *cipher);
 STACK_OF(SSL_CIPHER) *ssl_get_ciphers_by_id(SSL *s);
@@ -1448,6 +1448,7 @@ int ssl3_get_client_key_exchange(SSL *s);
 int ssl3_get_cert_verify(SSL *s);
 
 int ssl_kex_generate_dhe(DH *dh, DH *dh_params);
+int ssl_kex_generate_dhe_params_auto(DH *dh, size_t key_len);
 int ssl_kex_params_dhe(DH *dh, CBB *cbb);
 int ssl_kex_public_dhe(DH *dh, CBB *cbb);
 int ssl_kex_peer_params_dhe(DH *dh, CBS *cbs, int *invalid_params);
diff --git a/src/lib/libssl/ssl_srvr.c b/src/lib/libssl/ssl_srvr.c
index 0c217d6d3e..e9ea6b141c 100644
--- a/src/lib/libssl/ssl_srvr.c
+++ b/src/lib/libssl/ssl_srvr.c
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssl_srvr.c,v 1.126 2021/11/29 16:03:56 jsing Exp $ */
+/* $OpenBSD: ssl_srvr.c,v 1.127 2021/12/04 14:03:22 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -1309,43 +1309,38 @@ ssl3_send_server_done(SSL *s)
 static int
 ssl3_send_server_kex_dhe(SSL *s, CBB *cbb)
 {
-	DH *dh = NULL, *dhp;
+	DH *dh = NULL;
 	int al;
 
+	if ((dh = DH_new()) == NULL)
+		goto err;
+
 	if (s->cert->dh_tmp_auto != 0) {
-		if ((dhp = ssl_get_auto_dh(s)) == NULL) {
+		size_t key_bits;
+
+		if ((key_bits = ssl_dhe_params_auto_key_bits(s)) == 0) {
 			al = SSL_AD_INTERNAL_ERROR;
 			SSLerror(s, ERR_R_INTERNAL_ERROR);
 			goto fatal_err;
 		}
-	} else
-		dhp = s->cert->dh_tmp;
 
-	if (dhp == NULL && s->cert->dh_tmp_cb != NULL)
-		dhp = s->cert->dh_tmp_cb(s, 0,
-		    SSL_C_PKEYLENGTH(S3I(s)->hs.cipher));
+		if (!ssl_kex_generate_dhe_params_auto(dh, key_bits))
+			goto err;
+	} else {
+		DH *dh_params = s->cert->dh_tmp;
 
-	if (dhp == NULL) {
-		al = SSL_AD_HANDSHAKE_FAILURE;
-		SSLerror(s, SSL_R_MISSING_TMP_DH_KEY);
-		goto fatal_err;
-	}
+		if (dh_params == NULL && s->cert->dh_tmp_cb != NULL)
+			dh_params = s->cert->dh_tmp_cb(s, 0,
+			    SSL_C_PKEYLENGTH(S3I(s)->hs.cipher));
 
-	if (S3I(s)->tmp.dh != NULL) {
-		SSLerror(s, ERR_R_INTERNAL_ERROR);
-		goto err;
-	}
+		if (dh_params == NULL) {
+			al = SSL_AD_HANDSHAKE_FAILURE;
+			SSLerror(s, SSL_R_MISSING_TMP_DH_KEY);
+			goto fatal_err;
+		}
 
-	if (s->cert->dh_tmp_auto != 0) {
-		dh = dhp;
-	} else if ((dh = DHparams_dup(dhp)) == NULL) {
-		SSLerror(s, ERR_R_DH_LIB);
-		goto err;
-	}
-	S3I(s)->tmp.dh = dh;
-	if (!DH_generate_key(dh)) {
-		SSLerror(s, ERR_R_DH_LIB);
-		goto err;
+		if (!ssl_kex_generate_dhe(dh, dh_params))
+			goto err;
 	}
 
 	if (!ssl_kex_params_dhe(dh, cbb))
@@ -1353,12 +1348,20 @@ ssl3_send_server_kex_dhe(SSL *s, CBB *cbb)
 	if (!ssl_kex_public_dhe(dh, cbb))
 		goto err;
 
-	return (1);
+	if (S3I(s)->tmp.dh != NULL) {
+		SSLerror(s, ERR_R_INTERNAL_ERROR);
+		goto err;
+	}
+	S3I(s)->tmp.dh = dh;
+
+	return 1;
 
  fatal_err:
 	ssl3_send_alert(s, SSL3_AL_FATAL, al);
  err:
-	return (-1);
+	DH_free(dh);
+
+	return -1;
 }
 
 static int
@@ -1787,53 +1790,35 @@ ssl3_get_client_kex_rsa(SSL *s, CBS *cbs)
 static int
 ssl3_get_client_kex_dhe(SSL *s, CBS *cbs)
 {
-	int key_size = 0;
-	int key_is_invalid, key_len, al;
-	unsigned char *key = NULL;
-	BIGNUM *bn = NULL;
-	CBS dh_Yc;
-	DH *dh;
-
-	if (!CBS_get_u16_length_prefixed(cbs, &dh_Yc))
-		goto decode_err;
-	if (CBS_len(cbs) != 0)
-		goto decode_err;
+	DH *dh_clnt = NULL;
+	DH *dh_srvr;
+	int invalid_key;
+	uint8_t *key = NULL;
+	size_t key_len = 0;
+	int ret = -1;
 
-	if (S3I(s)->tmp.dh == NULL) {
-		al = SSL_AD_HANDSHAKE_FAILURE;
+	if ((dh_srvr = S3I(s)->tmp.dh) == NULL) {
+		ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
 		SSLerror(s, SSL_R_MISSING_TMP_DH_KEY);
-		goto fatal_err;
+		goto err;
 	}
-	dh = S3I(s)->tmp.dh;
 
-	if ((bn = BN_bin2bn(CBS_data(&dh_Yc), CBS_len(&dh_Yc), NULL)) == NULL) {
-		SSLerror(s, SSL_R_BN_LIB);
+	if ((dh_clnt = DHparams_dup(dh_srvr)) == NULL)
 		goto err;
-	}
 
-	if ((key_size = DH_size(dh)) <= 0) {
-		SSLerror(s, ERR_R_DH_LIB);
+	if (!ssl_kex_peer_public_dhe(dh_clnt, cbs, &invalid_key)) {
+		ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+		SSLerror(s, SSL_R_BAD_PACKET_LENGTH);
 		goto err;
 	}
-	if ((key = malloc(key_size)) == NULL) {
-		SSLerror(s, ERR_R_MALLOC_FAILURE);
+	if (invalid_key) {
+		ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
+		SSLerror(s, SSL_R_BAD_DH_PUB_KEY_LENGTH);
 		goto err;
 	}
-	if (!DH_check_pub_key(dh, bn, &key_is_invalid)) {
-		al = SSL_AD_INTERNAL_ERROR;
-		SSLerror(s, ERR_R_DH_LIB);
-		goto fatal_err;
-	}
-	if (key_is_invalid) {
-		al = SSL_AD_ILLEGAL_PARAMETER;
-		SSLerror(s, ERR_R_DH_LIB);
-		goto fatal_err;
-	}
-	if ((key_len = DH_compute_key(key, bn, dh)) <= 0) {
-		al = SSL_AD_INTERNAL_ERROR;
-		SSLerror(s, ERR_R_DH_LIB);
-		goto fatal_err;
-	}
+
+	if (!ssl_kex_derive_dhe(dh_srvr, dh_clnt, &key, &key_len))
+		goto err;
 
 	if (!tls12_derive_master_secret(s, key, key_len))
 		goto err;
@@ -1841,21 +1826,13 @@ ssl3_get_client_kex_dhe(SSL *s, CBS *cbs)
 	DH_free(S3I(s)->tmp.dh);
 	S3I(s)->tmp.dh = NULL;
 
-	freezero(key, key_size);
-	BN_clear_free(bn);
-
-	return (1);
+	ret = 1;
 
- decode_err:
-	al = SSL_AD_DECODE_ERROR;
-	SSLerror(s, SSL_R_BAD_PACKET_LENGTH);
- fatal_err:
-	ssl3_send_alert(s, SSL3_AL_FATAL, al);
  err:
-	freezero(key, key_size);
-	BN_clear_free(bn);
+	freezero(key, key_len);
+	DH_free(dh_clnt);
 
-	return (-1);
+	return ret;
 }
 
 static int
-- 
cgit v1.2.3-55-g6feb