summaryrefslogtreecommitdiff
path: root/src/regress/lib/libcrypto/mlkem/parse_test_file.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/regress/lib/libcrypto/mlkem/parse_test_file.c')
-rw-r--r--src/regress/lib/libcrypto/mlkem/parse_test_file.c752
1 files changed, 752 insertions, 0 deletions
diff --git a/src/regress/lib/libcrypto/mlkem/parse_test_file.c b/src/regress/lib/libcrypto/mlkem/parse_test_file.c
new file mode 100644
index 0000000000..d79fa3dfd5
--- /dev/null
+++ b/src/regress/lib/libcrypto/mlkem/parse_test_file.c
@@ -0,0 +1,752 @@
1/* $OpenBSD: parse_test_file.c,v 1.1 2024/12/26 00:04:24 tb Exp $ */
2
3/*
4 * Copyright (c) 2024 Theo Buehler <tb@openbsd.org>
5 *
6 * Permission to use, copy, modify, and distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19#include <sys/types.h>
20
21#include <assert.h>
22#include <err.h>
23#include <stdarg.h>
24#include <stdint.h>
25#include <stdio.h>
26#include <stdlib.h>
27#include <string.h>
28
29#include "bytestring.h"
30
31#include "parse_test_file.h"
32
33struct line_data {
34 uint8_t *data;
35 size_t data_len;
36 CBS cbs;
37 int val;
38};
39
40static struct line_data *
41line_data_new(void)
42{
43 return calloc(1, sizeof(struct line_data));
44}
45
46static void
47line_data_clear(struct line_data *ld)
48{
49 freezero(ld->data, ld->data_len);
50 explicit_bzero(ld, sizeof(*ld));
51}
52
53static void
54line_data_free(struct line_data *ld)
55{
56 if (ld == NULL)
57 return;
58 line_data_clear(ld);
59 free(ld);
60}
61
62static void
63line_data_get_int(struct line_data *ld, int *out)
64{
65 *out = ld->val;
66}
67
68static void
69line_data_get_cbs(struct line_data *ld, CBS *out)
70{
71 CBS_dup(&ld->cbs, out);
72}
73
74static void
75line_data_set_int(struct line_data *ld, int val)
76{
77 ld->val = val;
78}
79
80static int
81line_data_set_from_cbb(struct line_data *ld, CBB *cbb)
82{
83 if (!CBB_finish(cbb, &ld->data, &ld->data_len))
84 return 0;
85
86 CBS_init(&ld->cbs, ld->data, ld->data_len);
87
88 return 1;
89}
90
91struct parse_state {
92 size_t line;
93 size_t test;
94
95 size_t max;
96 size_t cur;
97 struct line_data **data;
98
99 size_t instruction_max;
100 size_t instruction_cur;
101 struct line_data **instruction_data;
102
103 int running_test_case;
104};
105
106static void
107parse_state_init(struct parse_state *ps, size_t max, size_t instruction_max)
108{
109 size_t i;
110
111 assert(max > 0);
112
113 memset(ps, 0, sizeof(*ps));
114 ps->test = 1;
115
116 ps->max = max;
117 if ((ps->data = calloc(max, sizeof(*ps->data))) == NULL)
118 err(1, NULL);
119 for (i = 0; i < max; i++) {
120 if ((ps->data[i] = line_data_new()) == NULL)
121 err(1, NULL);
122 }
123
124 if ((ps->instruction_max = instruction_max) > 0) {
125 if ((ps->instruction_data = calloc(instruction_max,
126 sizeof(*ps->instruction_data))) == NULL)
127 err(1, NULL);
128 for (i = 0; i < instruction_max; i++)
129 if ((ps->instruction_data[i] = line_data_new()) == NULL)
130 err(1, NULL);
131 }
132}
133
134static void
135parse_state_finish(struct parse_state *ps)
136{
137 size_t i;
138
139 for (i = 0; i < ps->max; i++)
140 line_data_free(ps->data[i]);
141 free(ps->data);
142
143 for (i = 0; i < ps->instruction_max; i++)
144 line_data_free(ps->instruction_data[i]);
145 free(ps->instruction_data);
146}
147
148static void
149parse_state_new_line(struct parse_state *ps)
150{
151 ps->line++;
152}
153
154static void
155parse_instruction_advance(struct parse_state *ps)
156{
157 assert(ps->instruction_cur < ps->instruction_max);
158 ps->instruction_cur++;
159}
160
161static void
162parse_state_advance(struct parse_state *ps)
163{
164 assert(ps->cur < ps->max);
165
166 ps->cur++;
167 if ((ps->cur %= ps->max) == 0)
168 ps->test++;
169}
170
171struct parse {
172 struct parse_state state;
173 CBS cbs;
174 char *buf;
175 size_t buf_max;
176 const struct test_parse *tctx;
177 void *ctx;
178
179 const char *fn;
180 FILE *fp;
181};
182
183static int
184parse_instructions_parsed(struct parse *p)
185{
186 return p->state.instruction_max == p->state.instruction_cur;
187}
188
189static void
190parse_advance(struct parse *p)
191{
192 if (!parse_instructions_parsed(p)) {
193 parse_instruction_advance(&p->state);
194 return;
195 }
196 parse_state_advance(&p->state);
197}
198
199static size_t
200parse_max(struct parse *p)
201{
202 return p->state.max;
203}
204
205static size_t
206parse_instruction_max(struct parse *p)
207{
208 return p->state.instruction_max;
209}
210
211static size_t
212parse_cur(struct parse *p)
213{
214 if (!parse_instructions_parsed(p)) {
215 assert(p->state.instruction_cur < p->state.instruction_max);
216 return p->state.instruction_cur;
217 }
218
219 assert(p->state.cur < parse_max(p));
220 return p->state.cur;
221}
222
223static size_t
224parse_must_run_test_case(struct parse *p)
225{
226 return parse_instructions_parsed(p) && parse_max(p) - parse_cur(p) == 1;
227}
228
229static const struct line_spec *
230parse_states(struct parse *p)
231{
232 if (!parse_instructions_parsed(p))
233 return p->tctx->instructions;
234 return p->tctx->states;
235}
236
237static const struct line_spec *
238parse_instruction_states(struct parse *p)
239{
240 return p->tctx->instructions;
241}
242
243static const struct line_spec *
244parse_state(struct parse *p)
245{
246 return &parse_states(p)[parse_cur(p)];
247}
248
249static size_t
250line(struct parse *p)
251{
252 return p->state.line;
253}
254
255static size_t
256test(struct parse *p)
257{
258 return p->state.test;
259}
260
261static const char *
262name(struct parse *p)
263{
264 if (p->state.running_test_case)
265 return "running test case";
266 return parse_state(p)->name;
267}
268
269static const char *
270label(struct parse *p)
271{
272 return parse_state(p)->label;
273}
274
275static const char *
276match(struct parse *p)
277{
278 return parse_state(p)->match;
279}
280
281static enum line
282parse_line_type(struct parse *p)
283{
284 return parse_state(p)->type;
285}
286
287static void
288parse_vinfo(struct parse *p, const char *fmt, va_list ap)
289{
290 fprintf(stderr, "%s:%zu test #%zu (%s): ",
291 p->fn, line(p), test(p), name(p));
292 vfprintf(stderr, fmt, ap);
293 fprintf(stderr, "\n");
294}
295
296void
297parse_info(struct parse *p, const char *fmt, ...)
298{
299 va_list ap;
300
301 va_start(ap, fmt);
302 parse_vinfo(p, fmt, ap);
303 va_end(ap);
304}
305
306void
307parse_errx(struct parse *p, const char *fmt, ...)
308{
309 va_list ap;
310
311 va_start(ap, fmt);
312 parse_vinfo(p, fmt, ap);
313 va_end(ap);
314
315 exit(1);
316}
317
318int
319parse_length_equal(struct parse *p, const char *descr, size_t want, size_t got)
320{
321 if (want == got)
322 return 1;
323
324 parse_info(p, "%s length: want %zu, got %zu", descr, want, got);
325 return 0;
326}
327
328static void
329hexdump(const uint8_t *buf, size_t len, const uint8_t *compare)
330{
331 const char *mark = "", *newline;
332 size_t i;
333
334 for (i = 1; i <= len; i++) {
335 if (compare != NULL)
336 mark = (buf[i - 1] != compare[i - 1]) ? "*" : " ";
337 newline = i % 8 ? "" : "\n";
338 fprintf(stderr, " %s0x%02x,%s", mark, buf[i - 1], newline);
339 }
340 if ((len % 8) != 0)
341 fprintf(stderr, "\n");
342}
343
344int
345parse_data_equal(struct parse *p, const char *descr, CBS *want,
346 const uint8_t *got, size_t got_len)
347{
348 if (!parse_length_equal(p, descr, CBS_len(want), got_len))
349 return 0;
350 if (CBS_mem_equal(want, got, got_len))
351 return 1;
352
353 parse_info(p, "%s differs", descr);
354 fprintf(stderr, "want:\n");
355 hexdump(CBS_data(want), CBS_len(want), got);
356 fprintf(stderr, "got:\n");
357 hexdump(got, got_len, CBS_data(want));
358 fprintf(stderr, "\n");
359
360 return 0;
361}
362
363static void
364parse_line_data_clear(struct parse *p)
365{
366 size_t i;
367
368 for (i = 0; i < parse_max(p); i++)
369 line_data_clear(p->state.data[i]);
370}
371
372static struct line_data **
373parse_state_data(struct parse *p)
374{
375 if (!parse_instructions_parsed(p))
376 return p->state.instruction_data;
377 return p->state.data;
378}
379
380static void
381parse_state_set_int(struct parse *p, int val)
382{
383 if (parse_line_type(p) != LINE_STRING_MATCH)
384 parse_errx(p, "%s: want %d, got %d", __func__,
385 LINE_STRING_MATCH, parse_line_type(p));
386 line_data_set_int(parse_state_data(p)[parse_cur(p)], val);
387}
388
389static void
390parse_state_set_from_cbb(struct parse *p, CBB *cbb)
391{
392 if (parse_line_type(p) != LINE_HEX)
393 parse_errx(p, "%s: want %d, got %d", __func__,
394 LINE_STRING_MATCH, parse_line_type(p));
395 if (!line_data_set_from_cbb(parse_state_data(p)[parse_cur(p)], cbb))
396 parse_errx(p, "line_data_set_from_cbb");
397}
398
399int
400parse_get_int(struct parse *p, size_t idx, int *out)
401{
402 assert(parse_must_run_test_case(p));
403 assert(idx < parse_max(p));
404 assert(parse_states(p)[idx].type == LINE_STRING_MATCH);
405
406 line_data_get_int(p->state.data[idx], out);
407
408 return 1;
409}
410
411int
412parse_get_cbs(struct parse *p, size_t idx, CBS *out)
413{
414 assert(parse_must_run_test_case(p));
415 assert(idx < parse_max(p));
416 assert(parse_states(p)[idx].type == LINE_HEX);
417
418 line_data_get_cbs(p->state.data[idx], out);
419
420 return 1;
421}
422
423int
424parse_instruction_get_int(struct parse *p, size_t idx, int *out)
425{
426 assert(parse_must_run_test_case(p));
427 assert(idx < parse_instruction_max(p));
428 assert(parse_instruction_states(p)[idx].type == LINE_STRING_MATCH);
429
430 line_data_get_int(p->state.instruction_data[idx], out);
431
432 return 1;
433}
434
435int
436parse_instruction_get_cbs(struct parse *p, size_t idx, CBS *out)
437{
438 assert(parse_must_run_test_case(p));
439 assert(idx < parse_instruction_max(p));
440 assert(parse_instruction_states(p)[idx].type == LINE_HEX);
441
442 line_data_get_cbs(p->state.instruction_data[idx], out);
443
444 return 1;
445}
446
447static int
448CBS_peek_bytes(CBS *cbs, CBS *out, size_t len)
449{
450 CBS dup;
451
452 CBS_dup(cbs, &dup);
453 return CBS_get_bytes(&dup, out, len);
454}
455
456static int
457parse_peek_string_cbs(struct parse *p, const char *str)
458{
459 CBS cbs;
460 size_t len = strlen(str);
461
462 if (!CBS_peek_bytes(&p->cbs, &cbs, len))
463 parse_errx(p, "CBS_peek_data");
464
465 return CBS_mem_equal(&cbs, (const uint8_t *)str, len);
466}
467
468static int
469parse_get_string_cbs(struct parse *p, const char *str)
470{
471 CBS cbs;
472 size_t len = strlen(str);
473
474 if (!CBS_get_bytes(&p->cbs, &cbs, len))
475 parse_errx(p, "CBS_get_bytes");
476
477 return CBS_mem_equal(&cbs, (const uint8_t *)str, len);
478}
479
480static int
481parse_get_string_end_cbs(struct parse *p, const char *str)
482{
483 CBS cbs;
484 int equal = 1;
485
486 CBS_init(&cbs, (const uint8_t *)str, strlen(str));
487
488 if (CBS_len(&p->cbs) < CBS_len(&cbs))
489 parse_errx(p, "line too short to match %s", str);
490
491 while (CBS_len(&cbs) > 0) {
492 uint8_t want, got;
493
494 if (!CBS_get_last_u8(&cbs, &want))
495 parse_errx(p, "CBS_get_last_u8");
496 if (!CBS_get_last_u8(&p->cbs, &got))
497 parse_errx(p, "CBS_get_last_u8");
498 if (want != got)
499 equal = 0;
500 }
501
502 return equal;
503}
504
505static void
506parse_check_label_matches(struct parse *p)
507{
508 const char *sep = ": ";
509
510 if (!parse_get_string_cbs(p, label(p)))
511 parse_errx(p, "label mismatch %s", label(p));
512
513 /* Now we expect either ": " or " = ". */
514 if (!parse_peek_string_cbs(p, sep))
515 sep = " = ";
516 if (!parse_get_string_cbs(p, sep))
517 parse_errx(p, "error getting \"%s\"", sep);
518}
519
520static int
521parse_empty_or_comment_line(struct parse *p)
522{
523 if (CBS_len(&p->cbs) == 0) {
524 return 1;
525 }
526 if (parse_peek_string_cbs(p, "#")) {
527 if (!CBS_skip(&p->cbs, CBS_len(&p->cbs)))
528 parse_errx(p, "CBS_skip");
529 return 1;
530 }
531 return 0;
532}
533
534static void
535parse_string_match_line(struct parse *p)
536{
537 int string_matches;
538
539 parse_check_label_matches(p);
540
541 string_matches = parse_get_string_cbs(p, match(p));
542 parse_state_set_int(p, string_matches);
543
544 if (!string_matches) {
545 if (!CBS_skip(&p->cbs, CBS_len(&p->cbs)))
546 parse_errx(p, "CBS_skip");
547 }
548}
549
550static int
551parse_get_hex_nibble_cbs(CBS *cbs, uint8_t *out_nibble)
552{
553 uint8_t c;
554
555 if (!CBS_get_u8(cbs, &c))
556 return 0;
557
558 if (c >= '0' && c <= '9') {
559 *out_nibble = c - '0';
560 return 1;
561 }
562 if (c >= 'a' && c <= 'f') {
563 *out_nibble = c - 'a' + 10;
564 return 1;
565 }
566 if (c >= 'A' && c <= 'F') {
567 *out_nibble = c - 'A' + 10;
568 return 1;
569 }
570
571 return 0;
572}
573
574static void
575parse_hex_line(struct parse *p)
576{
577 CBB cbb;
578
579 parse_check_label_matches(p);
580
581 if (!CBB_init(&cbb, 0))
582 parse_errx(p, "CBB_init");
583
584 while (CBS_len(&p->cbs) > 0) {
585 uint8_t hi, lo;
586
587 if (!parse_get_hex_nibble_cbs(&p->cbs, &hi))
588 parse_errx(p, "parse_get_hex_nibble_cbs");
589 if (!parse_get_hex_nibble_cbs(&p->cbs, &lo))
590 parse_errx(p, "parse_get_hex_nibble_cbs");
591
592 if (!CBB_add_u8(&cbb, hi << 4 | lo))
593 parse_errx(p, "CBB_add_u8");
594 }
595
596 parse_state_set_from_cbb(p, &cbb);
597}
598
599static void
600parse_maybe_prepare_instruction_line(struct parse *p)
601{
602 if (parse_instructions_parsed(p))
603 return;
604
605 /* Should not happen due to parse_empty_or_comment_line(). */
606 if (CBS_len(&p->cbs) == 0)
607 parse_errx(p, "empty instruction line");
608
609 if (!parse_peek_string_cbs(p, "["))
610 parse_errx(p, "expected instruction line");
611 if (!parse_get_string_cbs(p, "["))
612 parse_errx(p, "expected start of instruction line");
613 if (!parse_get_string_end_cbs(p, "]"))
614 parse_errx(p, "expected end of instruction line");
615}
616
617static void
618parse_check_line_consumed(struct parse *p)
619{
620 if (CBS_len(&p->cbs) > 0)
621 parse_errx(p, "%zu unprocessed bytes", CBS_len(&p->cbs));
622}
623
624static int
625parse_run_test_case(struct parse *p)
626{
627 const struct test_parse *tctx = p->tctx;
628
629 p->state.running_test_case = 1;
630 return tctx->run_test_case(p->ctx);
631}
632
633static void
634parse_reinit(struct parse *p)
635{
636 const struct test_parse *tctx = p->tctx;
637
638 p->state.running_test_case = 0;
639 parse_line_data_clear(p);
640 tctx->finish(p->ctx);
641 tctx->init(p->ctx, p);
642}
643
644static int
645parse_maybe_run_test_case(struct parse *p)
646{
647 int failed = 0;
648
649 if (parse_must_run_test_case(p)) {
650 failed |= parse_run_test_case(p);
651 parse_reinit(p);
652 }
653
654 parse_advance(p);
655
656 return failed;
657}
658
659static int
660parse_process_line(struct parse *p)
661{
662 if (parse_empty_or_comment_line(p))
663 return 0;
664
665 parse_maybe_prepare_instruction_line(p);
666
667 switch (parse_line_type(p)) {
668 case LINE_STRING_MATCH:
669 parse_string_match_line(p);
670 break;
671 case LINE_HEX:
672 parse_hex_line(p);
673 break;
674 default:
675 parse_errx(p, "unknown line type %d", parse_line_type(p));
676 }
677 parse_check_line_consumed(p);
678
679 return parse_maybe_run_test_case(p);
680}
681
682static void
683parse_init(struct parse *p, const char *fn, const struct test_parse *tctx,
684 void *ctx)
685{
686 FILE *fp;
687
688 memset(p, 0, sizeof(*p));
689
690 if ((fp = fopen(fn, "r")) == NULL)
691 err(1, "error opening %s", fn);
692
693 /* Poor man's basename since POSIX basename is stupid. */
694 if ((p->fn = strrchr(fn, '/')) != NULL)
695 p->fn++;
696 else
697 p->fn = fn;
698
699 p->fp = fp;
700 parse_state_init(&p->state, tctx->num_states, tctx->num_instructions);
701 p->tctx = tctx;
702 p->ctx = ctx;
703 tctx->init(ctx, p);
704}
705
706static int
707parse_next_line(struct parse *p)
708{
709 ssize_t len;
710 uint8_t u8;
711
712 if ((len = getline(&p->buf, &p->buf_max, p->fp)) == -1)
713 return 0;
714
715 CBS_init(&p->cbs, (const uint8_t *)p->buf, len);
716 parse_state_new_line(&p->state);
717
718 if (!CBS_get_last_u8(&p->cbs, &u8))
719 parse_errx(p, "CBS_get_last_u8");
720
721 assert(u8 == '\n');
722
723 return 1;
724}
725
726static void
727parse_finish(struct parse *p)
728{
729 parse_state_finish(&p->state);
730
731 free(p->buf);
732
733 if (ferror(p->fp))
734 err(1, "%s", p->fn);
735 fclose(p->fp);
736}
737
738int
739parse_test_file(const char *fn, const struct test_parse *tctx, void *ctx)
740{
741 struct parse p;
742 int failed = 0;
743
744 parse_init(&p, fn, tctx, ctx);
745
746 while (parse_next_line(&p))
747 failed |= parse_process_line(&p);
748
749 parse_finish(&p);
750
751 return failed;
752}