From bde03eeb489fff399715b17465ac07fdbb33acf1 Mon Sep 17 00:00:00 2001 From: Roberto Ierusalimschy Date: Thu, 12 Nov 2015 16:07:25 -0200 Subject: in 'table.sort': tighter checks for invalid order function + "random" pivot for larger intervals (to avoid attacks with bad data) --- ltablib.c | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) (limited to 'ltablib.c') diff --git a/ltablib.c b/ltablib.c index 7dae7049..c093187f 100644 --- a/ltablib.c +++ b/ltablib.c @@ -1,5 +1,5 @@ /* -** $Id: ltablib.c,v 1.83 2015/09/17 15:53:50 roberto Exp roberto $ +** $Id: ltablib.c,v 1.84 2015/11/06 16:07:14 roberto Exp roberto $ ** Library for Table Manipulation ** See Copyright Notice in lua.h */ @@ -233,7 +233,6 @@ static int unpack (lua_State *L) { ** ======================================================= */ - static void set2 (lua_State *L, int i, int j) { lua_seti(L, 1, i); lua_seti(L, 1, j); @@ -269,14 +268,14 @@ static int partition (lua_State *L, int lo, int up) { for (;;) { /* next loop: repeat ++i while a[i] < P */ while (lua_geti(L, 1, ++i), sort_comp(L, -1, -2)) { - if (i >= up) + if (i == up - 1) /* a[i] < P but a[up - 1] == P ?? */ luaL_error(L, "invalid order function for sorting"); lua_pop(L, 1); /* remove a[i] */ } /* after the loop, a[i] >= P and a[lo .. i - 1] < P */ /* next loop: repeat --j while P < a[j] */ while (lua_geti(L, 1, --j), sort_comp(L, -3, -1)) { - if (j < lo) + if (j < i) /* j < i but a[j] > P ?? */ luaL_error(L, "invalid order function for sorting"); lua_pop(L, 1); /* remove a[j] */ } @@ -294,6 +293,20 @@ static int partition (lua_State *L, int lo, int up) { } +/* +** Choose a "random" pivot in the middle part of the interval [lo, up]. +** Use 'time' and 'clock' as sources of "randomness". +*/ +static int choosePivot (int lo, int up) { + unsigned int t = (unsigned int)(unsigned long)time(NULL); /* time */ + unsigned int c = (unsigned int)(unsigned long)clock(); /* clock */ + unsigned int r4 = (unsigned int)(up - lo) / 4u; /* range/4 */ + unsigned int p = (c + t) % (r4 * 2) + (lo + r4); + lua_assert(lo + r4 <= p && p <= up - r4); + return (int)p; +} + + static void auxsort (lua_State *L, int lo, int up) { while (lo < up) { /* loop for tail recursion */ int p; @@ -306,7 +319,10 @@ static void auxsort (lua_State *L, int lo, int up) { lua_pop(L, 2); /* remove both values */ if (up - lo == 1) /* only 2 elements? */ return; /* already sorted */ - p = (lo + up)/2; + if (up - lo < 100) /* small interval? */ + p = (lo + up)/2; /* middle element is a good pivot */ + else /* for larger intervals, it is worth a random pivot */ + p = choosePivot(lo, up); lua_geti(L, 1, p); lua_geti(L, 1, lo); if (sort_comp(L, -2, -1)) /* a[p] < a[lo]? */ @@ -338,6 +354,7 @@ static void auxsort (lua_State *L, int lo, int up) { } /* tail call auxsort(L, lo, up) */ } + static int sort (lua_State *L) { int n = (int)aux_getn(L, 1, TAB_RW); luaL_checkstack(L, 50, ""); /* assume array is smaller than 2^50 */ -- cgit v1.2.3-55-g6feb