From acb1fa879beb68c10287b8438f18b80b2700042f Mon Sep 17 00:00:00 2001
From: Benoit Germain <benoit.germain@ubisoft.com>
Date: Thu, 4 Jul 2024 08:49:45 +0200
Subject: Error handling in coroutine lanes

---
 src/lane.cpp | 42 ++++++++++++++++++++----------------------
 src/lane.h   |  9 ++++++++-
 2 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/src/lane.cpp b/src/lane.cpp
index d6c9960..10060ad 100644
--- a/src/lane.cpp
+++ b/src/lane.cpp
@@ -499,7 +499,7 @@ namespace {
 // xxh64 of string "kStackTraceRegKey" generated at https://www.pelock.com/products/hash-calculator
 static constexpr RegistryUniqueKey kStackTraceRegKey{ 0x3F327747CACAA904ull };
 
-[[nodiscard]] static int lane_error(lua_State* L_)
+int Lane::LuaErrorHandler(lua_State* L_)
 {
     // error message (any type)
     STACK_CHECK_START_ABS(L_, 1);                                                                  // L_: some_error
@@ -510,7 +510,7 @@ static constexpr RegistryUniqueKey kStackTraceRegKey{ 0x3F327747CACAA904ull };
         return 1; // just pass on
     }
 
-    STACK_GROW(L_, 3);
+    STACK_GROW(L_, 4); // lua_setfield consumes a stack slot, so we have to account for it
     bool const _extended{ kExtendedStackTraceRegKey.readBoolValue(L_) };
     STACK_CHECK(L_, 1);
 
@@ -537,19 +537,19 @@ static constexpr RegistryUniqueKey kStackTraceRegKey{ 0x3F327747CACAA904ull };
             lua_newtable(L_);                                                                      // L_: some_error {} {}
 
             lua_pushstring(L_, _ar.source);                                                        // L_: some_error {} {} source
-            lua_setfield(L_, -2, "source");                                                        // L_: some_error {} {}
+            luaG_setfield(L_, -2, std::string_view{ "source" });                                   // L_: some_error {} {}
 
             lua_pushinteger(L_, _ar.currentline);                                                  // L_: some_error {} {} currentline
-            lua_setfield(L_, -2, "currentline");                                                   // L_: some_error {} {}
+            luaG_setfield(L_, -2, std::string_view{ "currentline" });                              // L_: some_error {} {}
 
-            lua_pushstring(L_, _ar.name);                                                          // L_: some_error {} {} name
-            lua_setfield(L_, -2, "name");                                                          // L_: some_error {} {}
+            lua_pushstring(L_, _ar.name ? _ar.name : "<?>");                                       // L_: some_error {} {} name
+            luaG_setfield(L_, -2, std::string_view{ "name" });                                     // L_: some_error {} {}
 
             lua_pushstring(L_, _ar.namewhat);                                                      // L_: some_error {} {} namewhat
-            lua_setfield(L_, -2, "namewhat");                                                      // L_: some_error {} {}
+            luaG_setfield(L_, -2, std::string_view{ "namewhat" });                                 // L_: some_error {} {}
 
             lua_pushstring(L_, _ar.what);                                                          // L_: some_error {} {} what
-            lua_setfield(L_, -2, "what");                                                          // L_: some_error {} {}
+            luaG_setfield(L_, -2, std::string_view{ "what" });                                     // L_: some_error {} {}
         } else if (_ar.currentline > 0) {
             luaG_pushstring(L_, "%s:%d", _ar.short_src, _ar.currentline);                          // L_: some_error {} "blah:blah"
         } else {
@@ -630,7 +630,7 @@ static void push_stack_trace(lua_State* L_, Lane::ErrorTraceLevel errorTraceLeve
     STACK_GROW(L_, 5);
 
     int const _finalizers_index{ lua_gettop(L_) };
-    int const _err_handler_index{ (errorTraceLevel_ != Lane::Minimal) ? (lua_pushcfunction(L_, lane_error), lua_gettop(L_)) : 0 };
+    int const _err_handler_index{ (errorTraceLevel_ != Lane::Minimal) ? (lua_pushcfunction(L_, Lane::LuaErrorHandler), lua_gettop(L_)) : 0 };
 
     LuaError _rc{ LuaError::OK };
     for (int _n = static_cast<int>(lua_rawlen(L_, _finalizers_index)); _n > 0; --_n) {
@@ -763,7 +763,7 @@ static void lane_main(Lane* const lane_)
     LuaError _rc{ LuaError::ERRRUN };
     if (lane_->status == Lane::Pending) { // nothing wrong happened during preparation, we can work
         // At this point, the lane function and arguments are on the stack, possibly preceded by the error handler
-        int const _errorHandlerCount{ lane_->errorTraceLevel == Lane::Minimal ? 0 : 1};
+        int const _errorHandlerCount{ lane_->errorHandlerCount() };
         int _nargs{ lua_gettop(_L) - 1 - _errorHandlerCount };
         {
             std::unique_lock _guard{ lane_->doneMutex };
@@ -777,7 +777,7 @@ static void lane_main(Lane* const lane_)
             // S and L are different: we run as a coroutine in Lua thread L created in state S
             do {
                 int _nresults{};
-                _rc = luaG_resume(_L, nullptr, _nargs, &_nresults);                                // L: eh? retvals|err
+                _rc = luaG_resume(_L, nullptr, _nargs, &_nresults);                                // L: eh? retvals|err...
                 if (_rc == LuaError::YIELD) {
                     // change our status to suspended, and wait until someone wants us to resume
                     std::unique_lock _guard{ lane_->doneMutex };
@@ -794,6 +794,15 @@ static void lane_main(Lane* const lane_)
                     _nargs = lua_gettop(_L);
                 }
             } while (_rc == LuaError::YIELD);
+            if (_rc != LuaError::OK) {                                                             // : err...
+                // for some reason, in my tests with Lua 5.4, when the coroutine raises an error, I have 3 copies of it on the stack
+                // or false + the error message when running Lua 5.1
+                // since the rest of our code wants only the error message, let us keep only the latter.
+                lua_replace(_L, 1);                                                                // L: err...
+                lua_settop(_L, 1);                                                                 // L: err
+                // now we build the stack trace table if the error trace level requests it
+                std::ignore = Lane::LuaErrorHandler(_L);                                           // L: err
+            }
         }
 
         if (_errorHandlerCount) {
@@ -1098,17 +1107,6 @@ void Lane::PushMetatable(lua_State* L_)
 
 // #################################################################################################
 
-[[nodiscard]] int Lane::pushErrorHandler() const
-{
-    if (errorTraceLevel != ErrorTraceLevel::Minimal) {
-        lua_pushcfunction(L, lane_error);
-        return 1;
-    }
-    return 0;
-}
-
-// #################################################################################################
-
 void Lane::pushStatusString(lua_State* L_) const
 {
     std::string_view const _str{ threadStatusString() };
diff --git a/src/lane.h b/src/lane.h
index 6dccd3a..f0fd0ac 100644
--- a/src/lane.h
+++ b/src/lane.h
@@ -158,12 +158,19 @@ class Lane
         lua_close(_L); // this collects our coroutine thread at the same time
     }
     [[nodiscard]] std::string_view errorTraceLevelString() const;
+    [[nodiscard]] int errorHandlerCount() const noexcept
+    {
+        // don't push a error handler when in coroutine mode, as the first lua_resume wants only the function and its arguments on the stack
+        return ((errorTraceLevel == Lane::Minimal) || isCoroutine()) ? 0 : 1; 
+    }
+    [[nodiscard]] bool isCoroutine() const noexcept { return S != L; }
     [[nodiscard]] std::string_view getDebugName() const
     {
         std::lock_guard<std::mutex> _guard{ debugNameMutex };
         return debugName;
     }
-    [[nodiscard]] int pushErrorHandler() const;
+    static int LuaErrorHandler(lua_State* L_);
+    [[nodiscard]] int pushErrorHandler() const noexcept { return (errorHandlerCount() == 0) ? 0 : (lua_pushcfunction(L, LuaErrorHandler), 1); }
     [[nodiscard]] std::string_view pushErrorTraceLevel(lua_State* L_) const;
     static void PushMetatable(lua_State* L_);
     void pushStatusString(lua_State* L_) const;
-- 
cgit v1.2.3-55-g6feb