From d0f34d91373fa265d4445e456e4a10ce206c1559 Mon Sep 17 00:00:00 2001 From: Roberto Ierusalimschy Date: Mon, 18 Jan 2021 11:40:45 -0300 Subject: Allow yields in '__close' metamethods ater errors Completes commit b07fc10e91a. '__close' metamethods can yield even when they are being called due to an error. '__close' metamethods from C functions are still not allowed to yield. --- ldo.c | 127 ++++++++++++++++++++++++++++-------------------------- lstate.h | 22 ++++++++-- testes/locals.lua | 48 +++++++++++++++++---- 3 files changed, 126 insertions(+), 71 deletions(-) diff --git a/ldo.c b/ldo.c index aa159cf0..45cfd592 100644 --- a/ldo.c +++ b/ldo.c @@ -565,25 +565,64 @@ void luaD_callnoyield (lua_State *L, StkId func, int nResults) { /* -** Completes the execution of an interrupted C function, calling its -** continuation function. +** Finish the job of 'lua_pcallk' after it was interrupted by an yield. +** (The caller, 'finishCcall', does the final call to 'adjustresults'.) +** The main job is to complete the 'luaD_pcall' called by 'lua_pcallk'. +** If a '__close' method yields here, eventually control will be back +** to 'finishCcall' (when that '__close' method finally returns) and +** 'finishpcallk' will run again and close any still pending '__close' +** methods. Similarly, if a '__close' method errs, 'precover' calls +** 'unroll' which calls ''finishCcall' and we are back here again, to +** close any pending '__close' methods. +** Note that, up to the call to 'luaF_close', the corresponding +** 'CallInfo' is not modified, so that this repeated run works like the +** first one (except that it has at least one less '__close' to do). In +** particular, field CIST_RECST preserves the error status across these +** multiple runs, changing only if there is a new error. */ -static void finishCcall (lua_State *L, int status) { - CallInfo *ci = L->ci; +static int finishpcallk (lua_State *L, CallInfo *ci) { + int status = getcistrecst(ci); /* get original status */ + if (status == LUA_OK) /* no error? */ + status = LUA_YIELD; /* was interrupted by an yield */ + else { /* error */ + StkId func = restorestack(L, ci->u2.funcidx); + L->allowhook = getoah(ci->callstatus); /* restore 'allowhook' */ + luaF_close(L, func, status, 1); /* can yield or raise an error */ + func = restorestack(L, ci->u2.funcidx); /* stack may be moved */ + luaD_seterrorobj(L, status, func); + luaD_shrinkstack(L); /* restore stack size in case of overflow */ + setcistrecst(ci, LUA_OK); /* clear original status */ + } + ci->callstatus &= ~CIST_YPCALL; + L->errfunc = ci->u.c.old_errfunc; + /* if it is here, there were errors or yields; unlike 'lua_pcallk', + do not change status */ + return status; +} + + +/* +** Completes the execution of a C function interrupted by an yield. +** The interruption must have happened while the function was +** executing 'lua_callk' or 'lua_pcallk'. In the second case, the +** call to 'finishpcallk' finishes the interrupted execution of +** 'lua_pcallk'. After that, it calls the continuation of the +** interrupted function and finally it completes the job of the +** 'luaD_call' that called the function. +** In the call to 'adjustresults', we do not know the number of +** results of the function called by 'lua_callk'/'lua_pcallk', +** so we are conservative and use LUA_MULTRET (always adjust). +*/ +static void finishCcall (lua_State *L, CallInfo *ci) { int n; + int status = LUA_YIELD; /* default if there were no errors */ /* must have a continuation and must be able to call it */ lua_assert(ci->u.c.k != NULL && yieldable(L)); - /* error status can only happen in a protected call */ - lua_assert((ci->callstatus & CIST_YPCALL) || status == LUA_YIELD); - if (ci->callstatus & CIST_YPCALL) { /* was inside a pcall? */ - ci->callstatus &= ~CIST_YPCALL; /* continuation is also inside it */ - L->errfunc = ci->u.c.old_errfunc; /* with the same error function */ - } - /* finish 'lua_callk'/'lua_pcall'; CIST_YPCALL and 'errfunc' already - handled */ - adjustresults(L, ci->nresults); + if (ci->callstatus & CIST_YPCALL) /* was inside a 'lua_pcallk'? */ + status = finishpcallk(L, ci); /* finish it */ + adjustresults(L, LUA_MULTRET); /* finish 'lua_callk' */ lua_unlock(L); - n = (*ci->u.c.k)(L, status, ci->u.c.ctx); /* call continuation function */ + n = (*ci->u.c.k)(L, status, ci->u.c.ctx); /* call continuation */ lua_lock(L); api_checknelems(L, n); luaD_poscall(L, ci, n); /* finish 'luaD_call' */ @@ -600,7 +639,7 @@ static void unroll (lua_State *L, void *ud) { UNUSED(ud); while ((ci = L->ci) != &L->base_ci) { /* something in the stack */ if (!isLua(ci)) /* C function? */ - finishCcall(L, LUA_YIELD); /* complete its execution */ + finishCcall(L, ci); /* complete its execution */ else { /* Lua function */ luaV_finishOp(L); /* finish interrupted instruction */ luaV_execute(L, ci); /* execute down to higher C 'boundary' */ @@ -623,40 +662,6 @@ static CallInfo *findpcall (lua_State *L) { } -/* -** Auxiliary structure to call 'recover' in protected mode. -*/ -struct RecoverS { - int status; - CallInfo *ci; -}; - - -/* -** Recovers from an error in a coroutine: completes the execution of the -** interrupted 'luaD_pcall', completes the interrupted C function which -** called 'lua_pcallk', and continues running the coroutine. If there is -** an error in 'luaF_close', this function will be called again and the -** coroutine will continue from where it left. -*/ -static void recover (lua_State *L, void *ud) { - struct RecoverS *r = cast(struct RecoverS *, ud); - int status = r->status; - CallInfo *ci = r->ci; /* recover point */ - StkId func = restorestack(L, ci->u2.funcidx); - /* "finish" luaD_pcall */ - L->ci = ci; - L->allowhook = getoah(ci->callstatus); /* restore original 'allowhook' */ - luaF_close(L, func, status, 0); /* may change the stack */ - func = restorestack(L, ci->u2.funcidx); - luaD_seterrorobj(L, status, func); - luaD_shrinkstack(L); /* restore stack size in case of overflow */ - L->errfunc = ci->u.c.old_errfunc; - finishCcall(L, status); /* finish 'lua_pcallk' callee */ - unroll(L, NULL); /* continue running the coroutine */ -} - - /* ** Signal an error in the call to 'lua_resume', not in the execution ** of the coroutine itself. (Such errors should not be handled by any @@ -705,19 +710,21 @@ static void resume (lua_State *L, void *ud) { /* -** Calls 'recover' in protected mode, repeating while there are -** recoverable errors, that is, errors inside a protected call. (Any -** error interrupts 'recover', and this loop protects it again so it -** can continue.) Stops with a normal end (status == LUA_OK), an yield +** Unrolls a coroutine in protected mode while there are recoverable +** errors, that is, errors inside a protected call. (Any error +** interrupts 'unroll', and this loop protects it again so it can +** continue.) Stops with a normal end (status == LUA_OK), an yield ** (status == LUA_YIELD), or an unprotected error ('findpcall' doesn't ** find a recover point). */ -static int p_recover (lua_State *L, int status) { - struct RecoverS r; - r.status = status; - while (errorstatus(status) && (r.ci = findpcall(L)) != NULL) - r.status = luaD_rawrunprotected(L, recover, &r); - return r.status; +static int precover (lua_State *L, int status) { + CallInfo *ci; + while (errorstatus(status) && (ci = findpcall(L)) != NULL) { + L->ci = ci; /* go down to recovery functions */ + setcistrecst(ci, status); /* status to finish 'pcall' */ + status = luaD_rawrunprotected(L, unroll, NULL); + } + return status; } @@ -738,7 +745,7 @@ LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs, api_checknelems(L, (L->status == LUA_OK) ? nargs + 1 : nargs); status = luaD_rawrunprotected(L, resume, &nargs); /* continue running after recoverable errors */ - status = p_recover(L, status); + status = precover(L, status); if (likely(!errorstatus(status))) lua_assert(status == L->status); /* normal end or yield */ else { /* unrecoverable error */ diff --git a/lstate.h b/lstate.h index 38a6c9b6..38248e57 100644 --- a/lstate.h +++ b/lstate.h @@ -191,17 +191,33 @@ typedef struct CallInfo { */ #define CIST_OAH (1<<0) /* original value of 'allowhook' */ #define CIST_C (1<<1) /* call is running a C function */ -#define CIST_FRESH (1<<2) /* call is on a fresh "luaV_execute" frame */ +#define CIST_FRESH (1<<2) /* call is on a fresh "luaV_execute" frame */ #define CIST_HOOKED (1<<3) /* call is running a debug hook */ #define CIST_YPCALL (1<<4) /* call is a yieldable protected call */ #define CIST_TAIL (1<<5) /* call was tail called */ #define CIST_HOOKYIELD (1<<6) /* last hook called yielded */ -#define CIST_FIN (1<<7) /* call is running a finalizer */ +#define CIST_FIN (1<<7) /* call is running a finalizer */ #define CIST_TRAN (1<<8) /* 'ci' has transfer information */ +/* Bits 9-11 are used for CIST_RECST (see below) */ +#define CIST_RECST 9 #if defined(LUA_COMPAT_LT_LE) -#define CIST_LEQ (1<<9) /* using __lt for __le */ +#define CIST_LEQ (1<<12) /* using __lt for __le */ #endif + +/* +** Field CIST_RECST stores the "recover status", used to keep the error +** status while closing to-be-closed variables in coroutines, so that +** Lua can correctly resume after an yield from a __close method called +** because of an error. (Three bits are enough for error status.) +*/ +#define getcistrecst(ci) (((ci)->callstatus >> CIST_RECST) & 7) +#define setcistrecst(ci,st) \ + check_exp(((st) & 7) == (st), /* status must fit in three bits */ \ + ((ci)->callstatus = ((ci)->callstatus & ~(7 << CIST_RECST)) \ + | ((st) << CIST_RECST))) + + /* active function is a Lua function */ #define isLua(ci) (!((ci)->callstatus & CIST_C)) diff --git a/testes/locals.lua b/testes/locals.lua index c9c93ccf..8506195e 100644 --- a/testes/locals.lua +++ b/testes/locals.lua @@ -697,34 +697,66 @@ end do - -- yielding inside closing metamethods after an error: - -- not yet implemented; raises an error + -- yielding inside closing metamethods after an error local co = coroutine.wrap(function () local function foo (err) + local z = func2close(function(_, msg) + assert(msg == nil or msg == err + 20) + coroutine.yield("z") + return 100, 200 + end) + + local y = func2close(function(_, msg) + -- still gets the original error (if any) + assert(msg == err or (msg == nil and err == 1)) + coroutine.yield("y") + if err then error(err + 20) end -- creates or changes the error + end) + local x = func2close(function(_, msg) - assert(msg == err) + assert(msg == err or (msg == nil and err == 1)) coroutine.yield("x") return 100, 200 end) - if err then error(err) else return 10, 20 end + if err == 10 then error(err) else return 10, 20 end end coroutine.yield(pcall(foo, nil)) -- no error + coroutine.yield(pcall(foo, 1)) -- error in __close return pcall(foo, 10) -- 'foo' will raise an error end) - local a, b = co() + local a, b = co() -- first foo: no error assert(a == "x" and b == nil) -- yields inside 'x'; Ok - + a, b = co() + assert(a == "y" and b == nil) -- yields inside 'y'; Ok + a, b = co() + assert(a == "z" and b == nil) -- yields inside 'z'; Ok local a, b, c = co() assert(a and b == 10 and c == 20) -- returns from 'pcall(foo, nil)' - local st, msg = co() -- error yielding after an error - assert(not st and string.find(msg, "attempt to yield")) + local a, b = co() -- second foo: error in __close + assert(a == "x" and b == nil) -- yields inside 'x'; Ok + a, b = co() + assert(a == "y" and b == nil) -- yields inside 'y'; Ok + a, b = co() + assert(a == "z" and b == nil) -- yields inside 'z'; Ok + local st, msg = co() -- reports the error in 'y' + assert(not st and msg == 21) + + local a, b = co() -- third foo: error in function body + assert(a == "x" and b == nil) -- yields inside 'x'; Ok + a, b = co() + assert(a == "y" and b == nil) -- yields inside 'y'; Ok + a, b = co() + assert(a == "z" and b == nil) -- yields inside 'z'; Ok + local st, msg = co() -- gets final error + assert(not st and msg == 10 + 20) + end -- cgit v1.2.3-55-g6feb