From b95a82a8db13f8e79ab8027f1fdfee0836bf02ec Mon Sep 17 00:00:00 2001
From: tb <>
Date: Thu, 6 Apr 2023 08:38:53 +0000
Subject: wycheproof: use EVP_MD instead of importing "hash"

---
 src/regress/lib/libcrypto/wycheproof/wycheproof.go | 122 +++++++++------------
 1 file changed, 52 insertions(+), 70 deletions(-)

diff --git a/src/regress/lib/libcrypto/wycheproof/wycheproof.go b/src/regress/lib/libcrypto/wycheproof/wycheproof.go
index b3c9225bb6..0698ac90b7 100644
--- a/src/regress/lib/libcrypto/wycheproof/wycheproof.go
+++ b/src/regress/lib/libcrypto/wycheproof/wycheproof.go
@@ -1,4 +1,4 @@
-/* $OpenBSD: wycheproof.go,v 1.141 2023/03/25 09:21:17 tb Exp $ */
+/* $OpenBSD: wycheproof.go,v 1.142 2023/04/06 08:38:53 tb Exp $ */
 /*
  * Copyright (c) 2018 Joel Sing <jsing@openbsd.org>
  * Copyright (c) 2018,2019,2022 Theo Buehler <tb@openbsd.org>
@@ -75,14 +75,10 @@ import "C"
 
 import (
 	"bytes"
-	"crypto/sha1"
-	"crypto/sha256"
-	"crypto/sha512"
 	"encoding/base64"
 	"encoding/hex"
 	"encoding/json"
 	"fmt"
-	"hash"
 	"io/ioutil"
 	"log"
 	"os"
@@ -564,23 +560,6 @@ func nidFromString(ns string) (int, error) {
 	return -1, fmt.Errorf("unknown NID %q", ns)
 }
 
-func hashFromString(hs string) (hash.Hash, error) {
-	switch hs {
-	case "SHA-1":
-		return sha1.New(), nil
-	case "SHA-224":
-		return sha256.New224(), nil
-	case "SHA-256":
-		return sha256.New(), nil
-	case "SHA-384":
-		return sha512.New384(), nil
-	case "SHA-512":
-		return sha512.New(), nil
-	default:
-		return nil, fmt.Errorf("unknown hash %q", hs)
-	}
-}
-
 func hashEvpMdFromString(hs string) (*C.EVP_MD, error) {
 	switch hs {
 	case "SHA-1":
@@ -598,6 +577,26 @@ func hashEvpMdFromString(hs string) (*C.EVP_MD, error) {
 	}
 }
 
+func hashEvpDigestMessage(md *C.EVP_MD, msg []byte) ([]byte, C.size, error) {
+	size := C.EVP_MD_size(md)
+	if size <= 0 || size > C.EVP_MAX_MD_SIZE {
+		return nil, 0, fmt.Errorf("unexpected MD size %d", size)
+	}
+
+	msgLen := len(msg)
+	if msgLen == 0 {
+		msg = append(msg, 0)
+	}
+
+	digest := make([]byte, size)
+
+	if C.EVP_Digest(unsafe.Pointer(&msg[0]), C.size_t(msgLen), (*C.uchar)(unsafe.Pointer(&digest[0])), nil, md, nil) != 1 {
+		return nil, 0, fmt.Errorf("EVP_Digest failed")
+	}
+
+	return digest, int(size), nil
+}
+
 func checkAesCbcPkcs5(ctx *C.EVP_CIPHER_CTX, doEncrypt int, key []byte, keyLen int,
 	iv []byte, ivLen int, in []byte, inLen int, out []byte, outLen int,
 	wt *wycheproofTestAesCbcPkcs5) bool {
@@ -1337,19 +1336,15 @@ func encodeDSAP1363Sig(wtSig string) (*C.uchar, C.int) {
 	return cDer, derLen
 }
 
-func runDSATest(dsa *C.DSA, variant testVariant, h hash.Hash, wt *wycheproofTestDSA) bool {
+func runDSATest(dsa *C.DSA, md *C.EVP_MD, variant testVariant, wt *wycheproofTestDSA) bool {
 	msg, err := hex.DecodeString(wt.Msg)
 	if err != nil {
 		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
 	}
 
-	h.Reset()
-	h.Write(msg)
-	msg = h.Sum(nil)
-
-	msgLen := len(msg)
-	if msgLen == 0 {
-		msg = append(msg, 0)
+	msg, msgLen, err := hashEvpDigestMessage(md, msg)
+	if err != nil {
+		log.Fatalf("%v", err)
 	}
 
 	var ret C.int
@@ -1433,7 +1428,7 @@ func runDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTestG
 		log.Fatalf("DSA_set0_key returned %d", ret)
 	}
 
-	h, err := hashFromString(wtg.SHA)
+	md, err := hashEvpMdFromString(wtg.SHA)
 	if err != nil {
 		log.Fatalf("Failed to get hash: %v", err)
 	}
@@ -1475,13 +1470,13 @@ func runDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTestG
 
 	success := true
 	for _, wt := range wtg.Tests {
-		if !runDSATest(dsa, variant, h, wt) {
+		if !runDSATest(dsa, md, variant, wt) {
 			success = false
 		}
-		if !runDSATest(dsaDER, variant, h, wt) {
+		if !runDSATest(dsaDER, md, variant, wt) {
 			success = false
 		}
-		if !runDSATest(dsaPEM, variant, h, wt) {
+		if !runDSATest(dsaPEM, md, variant, wt) {
 			success = false
 		}
 	}
@@ -1722,19 +1717,15 @@ func runECDHWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDHWeb
 	return success
 }
 
-func runECDSATest(ecKey *C.EC_KEY, nid int, h hash.Hash, variant testVariant, wt *wycheproofTestECDSA) bool {
+func runECDSATest(ecKey *C.EC_KEY, md *C.EVP_MD, nid int, variant testVariant, wt *wycheproofTestECDSA) bool {
 	msg, err := hex.DecodeString(wt.Msg)
 	if err != nil {
 		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
 	}
 
-	h.Reset()
-	h.Write(msg)
-	msg = h.Sum(nil)
-
-	msgLen := len(msg)
-	if msgLen == 0 {
-		msg = append(msg, 0)
+	msg, msgLen, err := hashEvpDigestMessage(md, msg)
+	if err != nil {
+		log.Fatalf("%v", err)
 	}
 
 	var ret C.int
@@ -1810,14 +1801,14 @@ func runECDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTes
 	if err != nil {
 		log.Fatalf("Failed to get MD NID: %v", err)
 	}
-	h, err := hashFromString(wtg.SHA)
+	md, err := hashEvpMdFromString(wtg.SHA)
 	if err != nil {
 		log.Fatalf("Failed to get hash: %v", err)
 	}
 
 	success := true
 	for _, wt := range wtg.Tests {
-		if !runECDSATest(ecKey, nid, h, variant, wt) {
+		if !runECDSATest(ecKey, md, nid, variant, wt) {
 			success = false
 		}
 	}
@@ -1914,14 +1905,14 @@ func runECDSAWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDSAW
 	if err != nil {
 		log.Fatalf("Failed to get MD NID: %v", err)
 	}
-	h, err := hashFromString(wtg.SHA)
+	md, err := hashEvpMdFromString(wtg.SHA)
 	if err != nil {
 		log.Fatalf("Failed to get hash: %v", err)
 	}
 
 	success := true
 	for _, wt := range wtg.Tests {
-		if !runECDSATest(ecKey, nid, h, Webcrypto, wt) {
+		if !runECDSATest(ecKey, md, nid, Webcrypto, wt) {
 			success = false
 		}
 	}
@@ -2512,25 +2503,23 @@ func runRsaesPkcs1TestGroup(algorithm string, wtg *wycheproofTestGroupRsaesPkcs1
 	return success
 }
 
-func runRsassaTest(rsa *C.RSA, h hash.Hash, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
+func runRsassaTest(rsa *C.RSA, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
 	msg, err := hex.DecodeString(wt.Msg)
 	if err != nil {
 		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
 	}
 
-	h.Reset()
-	h.Write(msg)
-	msg = h.Sum(nil)
+	msg, _, err = hashEvpDigestMessage(sha, msg)
+	if err != nil {
+		log.Fatalf("%v", err)
+	}
 
 	sig, err := hex.DecodeString(wt.Sig)
 	if err != nil {
 		log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
 	}
 
-	msgLen, sigLen := len(msg), len(sig)
-	if msgLen == 0 {
-		msg = append(msg, 0)
-	}
+	sigLen := len(sig)
 	if sigLen == 0 {
 		sig = append(sig, 0)
 	}
@@ -2599,11 +2588,6 @@ func runRsassaTestGroup(algorithm string, wtg *wycheproofTestGroupRsassa) bool {
 	rsaN = nil
 	rsaE = nil
 
-	h, err := hashFromString(wtg.SHA)
-	if err != nil {
-		log.Fatalf("Failed to get hash: %v", err)
-	}
-
 	sha, err := hashEvpMdFromString(wtg.SHA)
 	if err != nil {
 		log.Fatalf("Failed to get hash: %v", err)
@@ -2616,32 +2600,30 @@ func runRsassaTestGroup(algorithm string, wtg *wycheproofTestGroupRsassa) bool {
 
 	success := true
 	for _, wt := range wtg.Tests {
-		if !runRsassaTest(rsa, h, sha, mgfSha, wtg.SLen, wt) {
+		if !runRsassaTest(rsa, sha, mgfSha, wtg.SLen, wt) {
 			success = false
 		}
 	}
 	return success
 }
 
-func runRSATest(rsa *C.RSA, nid int, h hash.Hash, wt *wycheproofTestRSA) bool {
+func runRSATest(rsa *C.RSA, md *C.EVP_MD, nid int, wt *wycheproofTestRSA) bool {
 	msg, err := hex.DecodeString(wt.Msg)
 	if err != nil {
 		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
 	}
 
-	h.Reset()
-	h.Write(msg)
-	msg = h.Sum(nil)
+	msg, msgLen, err := hashEvpDigestMessage(md, msg)
+	if err != nil {
+		log.Fatalf("%v", err)
+	}
 
 	sig, err := hex.DecodeString(wt.Sig)
 	if err != nil {
 		log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
 	}
 
-	msgLen, sigLen := len(msg), len(sig)
-	if msgLen == 0 {
-		msg = append(msg, 0)
-	}
+	sigLen := len(sig)
 	if sigLen == 0 {
 		sig = append(sig, 0)
 	}
@@ -2695,14 +2677,14 @@ func runRSATestGroup(algorithm string, wtg *wycheproofTestGroupRSA) bool {
 	if err != nil {
 		log.Fatalf("Failed to get MD NID: %v", err)
 	}
-	h, err := hashFromString(wtg.SHA)
+	md, err := hashEvpMdFromString(wtg.SHA)
 	if err != nil {
 		log.Fatalf("Failed to get hash: %v", err)
 	}
 
 	success := true
 	for _, wt := range wtg.Tests {
-		if !runRSATest(rsa, nid, h, wt) {
+		if !runRSATest(rsa, md, nid, wt) {
 			success = false
 		}
 	}
-- 
cgit v1.2.3-55-g6feb