/*
 * Extracted from a cgit view of
 * root/src/regress/lib/libcrypto/bn/bn_mod_sqrt.c
 * (blob e193755b7482b96eda675fd1f53e76f765339de7).
 */
/*	$OpenBSD: bn_mod_sqrt.c,v 1.1 2022/12/01 20:50:10 tb Exp $ */
/*
 * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <stdio.h>

#include <openssl/bn.h>

/*
 * Test vectors for BN_mod_sqrt(), all values in hex.
 *
 * For entries expected to succeed, sqrt is a square root of a modulo p,
 * i.e. sqrt * sqrt == a (mod p). Entries with bn_mod_sqrt_fails set must
 * make BN_mod_sqrt() return NULL; their sqrt field is only a placeholder
 * and is not compared. The modulus is not always prime: the second entry
 * uses the composite 0x460201 == 0x4D5 * 0xE7D.
 */
static struct mod_sqrt_test {
	const char *sqrt;	/* expected square root of a mod p */
	const char *a;		/* value whose square root is taken */
	const char *p;		/* modulus */
	int bn_mod_sqrt_fails;	/* nonzero if BN_mod_sqrt() must fail */
} mod_sqrt_test_data[] = {
	{
		.sqrt = "1",
		.a = "1",
		.p = "2",
		.bn_mod_sqrt_fails = 0,
	},
	{
		.sqrt = "-1",
		.a = "20a7ee",
		.p = "460201", /* 460201 == 4D5 * E7D */
		.bn_mod_sqrt_fails = 1,
	},
	{
		.sqrt = "-1",
		.a = "65bebdb00a96fc814ec44b81f98b59fba3c30203928fa521"
		     "4c51e0a97091645280c947b005847f239758482b9bfc45b0"
		     "66fde340d1fe32fc9c1bf02e1b2d0ed",
		.p = "9df9d6cc20b8540411af4e5357ef2b0353cb1f2ab5ffc3e2"
		     "46b41c32f71e951f",
		.bn_mod_sqrt_fails = 1,
	},
};

static const size_t N_TESTS =
    sizeof(mod_sqrt_test_data) / sizeof(mod_sqrt_test_data[0]);

int mod_sqrt_test(struct mod_sqrt_test *test);

int
mod_sqrt_test(struct mod_sqrt_test *test)
{
	BN_CTX *ctx = NULL;
	BIGNUM *a = NULL, *p = NULL, *want = NULL, *got = NULL, *diff = NULL;
	int failed = 1;

	if ((ctx = BN_CTX_new()) == NULL) {
		fprintf(stderr, "BN_CTX_new failed\n");
		goto out;
	}

	if (!BN_hex2bn(&a, test->a)) {
		fprintf(stderr, "BN_hex2bn(a) failed\n");
		goto out;
	}
	if (!BN_hex2bn(&p, test->p)) {
		fprintf(stderr, "BN_hex2bn(p) failed\n");
		goto out;
	}
	if (!BN_hex2bn(&want, test->sqrt)) {
		fprintf(stderr, "BN_hex2bn(want) failed\n");
		goto out;
	}

	if (((got = BN_mod_sqrt(NULL, a, p, ctx)) == NULL) !=
	   test->bn_mod_sqrt_fails) {
		fprintf(stderr, "BN_mod_sqrt %s unexpectedly\n",
		    test->bn_mod_sqrt_fails ? "succeeded" : "failed");
		goto out;
	}

	if (test->bn_mod_sqrt_fails) {
		failed = 0;
		goto out;
	}

	if ((diff = BN_new()) == NULL) {
		fprintf(stderr, "diff = BN_new() failed\n");
		goto out;
	}

	if (!BN_mod_sub(diff, want, got, p, ctx)) {
		fprintf(stderr, "BN_mod_sub failed\n");
		goto out;
	}

	if (!BN_is_zero(diff)) {
		fprintf(stderr, "want != got\n");
		goto out;
	}

	failed = 0;

 out:
	BN_CTX_free(ctx);
	BN_free(a);
	BN_free(p);
	BN_free(want);
	BN_free(got);
	BN_free(diff);

	return failed;
}

/*
 * Run every entry of mod_sqrt_test_data. The exit status is nonzero if at
 * least one case failed; all cases are run regardless.
 */
int
main(void)
{
	int any_failed = 0;
	size_t idx = 0;

	while (idx < N_TESTS) {
		any_failed |= mod_sqrt_test(&mod_sqrt_test_data[idx]);
		idx++;
	}

	return any_failed;
}