From 4dc56c6d362f2cd8a79d83369f0b852df07dae3f Mon Sep 17 00:00:00 2001
From: Mark Pulford <mark@kyne.com.au>
Date: Sun, 8 May 2011 20:26:09 +0930
Subject: Add UTF-16 surrogate pair decode support

- Add tests for UTF-16 decoding and failures
- Add getutf8.pl to assist with UTF-16 decode testing

- Re-add test_decode_cycle() which was accidentally removed earlier
- Rename bytestring.dat to octets-escaped.dat
---
 lua_cjson.c              | 65 ++++++++++++++++++++++++++++++++++++++++-------
 tests/bytestring.dat     |  1 -
 tests/common.lua         |  4 +++
 tests/genutf8.pl         | 25 ++++++++++++++++++
 tests/octets-escaped.dat |  1 +
 tests/test.lua           | 66 ++++++++++++++++++++++++++++++++++++++++++++++--
 6 files changed, 150 insertions(+), 12 deletions(-)
 delete mode 100644 tests/bytestring.dat
 create mode 100755 tests/genutf8.pl
 create mode 100644 tests/octets-escaped.dat

diff --git a/lua_cjson.c b/lua_cjson.c
index 3af8157..52b259d 100644
--- a/lua_cjson.c
+++ b/lua_cjson.c
@@ -680,19 +680,24 @@ static int decode_hex4(const char *hex)
             digit[3];
 }
 
+/* Converts a Unicode codepoint to UTF-8.
+ * Returns UTF-8 string length, and up to 4 bytes in *utf8 */
 static int codepoint_to_utf8(char *utf8, int codepoint)
 {
+    /* 0xxxxxxx */
     if (codepoint <= 0x7F) {
         utf8[0] = codepoint;
         return 1;
     }
     
+    /* 110xxxxx 10xxxxxx */
     if (codepoint <= 0x7FF) {
         utf8[0] = (codepoint >> 6) | 0xC0;
         utf8[1] = (codepoint & 0x3F) | 0x80;
         return 2;
     }
 
+    /* 1110xxxx 10xxxxxx 10xxxxxx */
     if (codepoint <= 0xFFFF) {
         utf8[0] = (codepoint >> 12) | 0xE0;
         utf8[1] = ((codepoint >> 6) & 0x3F) | 0x80;
@@ -700,11 +705,20 @@ static int codepoint_to_utf8(char *utf8, int codepoint)
         return 3;
     }
 
+    /* 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx */
+    if (codepoint <= 0x1FFFFF) {
+        utf8[0] = (codepoint >> 18) | 0xF0;
+        utf8[1] = ((codepoint >> 12) & 0x3F) | 0x80;
+        utf8[2] = ((codepoint >> 6) & 0x3F) | 0x80;
+        utf8[3] = (codepoint & 0x3F) | 0x80;
+        return 4;
+    }
+
     return 0;
 }
 
 
-/* Called when index pointing to beginning of UCS-2 hex code: \uXXXX
+/* Called when index pointing to beginning of UTF-16 code escape: \uXXXX
  * \u is guaranteed to exist, but the remaining hex characters may be
  * missing.
  * Translate to UTF-8 and append to temporary token string.
@@ -714,25 +728,58 @@ static int codepoint_to_utf8(char *utf8, int codepoint)
  */
 static int json_append_unicode_escape(json_parse_t *json)
 {
-    char utf8[4];       /* 3 bytes of UTF-8 can handle UCS-2 */
+    char utf8[4];       /* Surrogate pairs require 4 UTF-8 bytes */
     int codepoint;
+    int surrogate_low;
     int len;
+    int escape_len = 6;
 
-    /* Fetch UCS-2 codepoint */
+    /* Fetch UTF-16 code unit */
     codepoint = decode_hex4(&json->data[json->index + 2]);
-    if (codepoint < 0) {
+    if (codepoint < 0)
         return -1;
+
+    /* UTF-16 surrogate pairs take the following 2 byte form:
+     *      11011 x yyyyyyyyyy
+     * When x = 0: y is the high 10 bits of the codepoint
+     *      x = 1: y is the low 10 bits of the codepoint
+     *
+     * Check for a surrogate pair (high or low) */
+    if ((codepoint & 0xF800) == 0xD800) {
+        /* Error if the 1st surrogate is not high */
+        if (codepoint & 0x400)
+            return -1;
+
+        /* Ensure the next code is a unicode escape */
+        if (json->data[json->index + escape_len] != '\\' ||
+            json->data[json->index + escape_len + 1] != 'u') {
+            return -1;
+        }
+
+        /* Fetch the 2nd UTF-16 code unit (expected low surrogate) */
+        surrogate_low = decode_hex4(&json->data[json->index + 2 + escape_len]);
+        if (surrogate_low < 0)
+            return -1;
+
+        /* Error if the 2nd code is not a low surrogate */
+        if ((surrogate_low & 0xFC00) != 0xDC00)
+            return -1;
+
+        /* Calculate Unicode codepoint */
+        codepoint = (codepoint & 0x3FF) << 10;
+        surrogate_low &= 0x3FF;
+        codepoint = (codepoint | surrogate_low) + 0x10000;
+        escape_len = 12;
     }
 
-    /* Convert to UTF-8 */
+    /* Convert codepoint to UTF-8 */
     len = codepoint_to_utf8(utf8, codepoint);
-    if (!len) {
+    if (!len)
         return -1;
-    }
 
-    /* Append bytes and advance index */
+    /* Append bytes and advance parse index */
     strbuf_append_mem_unsafe(json->tmp, utf8, len);
-    json->index += 6;
+    json->index += escape_len;
 
     return 0;
 }
diff --git a/tests/bytestring.dat b/tests/bytestring.dat
deleted file mode 100644
index ee99a6b..0000000
--- a/tests/bytestring.dat
+++ /dev/null
@@ -1 +0,0 @@
-"\u0000\u0001\u0002\u0003\u0004\u0005\u0006\u0007\b\t\n\u000b\f\r\u000e\u000f\u0010\u0011\u0012\u0013\u0014\u0015\u0016\u0017\u0018\u0019\u001a\u001b\u001c\u001d\u001e\u001f !\"#$%&'()*+,-.\/0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\u007f��������������������������������������������������������������������������������������������������������������������������������"
\ No newline at end of file
diff --git a/tests/common.lua b/tests/common.lua
index 9a7ed19..b8ce01d 100644
--- a/tests/common.lua
+++ b/tests/common.lua
@@ -99,6 +99,10 @@ function file_load(filename)
     local data = file:read("*a")
     file:close()
 
+    if data == nil then
+        error("Failed to read " .. filename)
+    end
+
     return data
 end
 
diff --git a/tests/genutf8.pl b/tests/genutf8.pl
new file mode 100755
index 0000000..4960663
--- /dev/null
+++ b/tests/genutf8.pl
@@ -0,0 +1,25 @@
+#!/usr/bin/perl -w
+
+# Create test comparison data using a different UTF-8 implementation.
+
+use strict;
+use Text::Iconv;
+use FileHandle;
+
+# 0xD800 - 0xDFFF are used to encode supplementary codepoints
+# 0x10000 - 0x10FFFF are supplementary codepoints
+my (@codepoints) = (0 .. 0xD7FF, 0xE000 .. 0x10FFFF);
+
+my ($utf32be) = pack("N*", @codepoints);
+my $iconv = Text::Iconv->new("UTF-32BE", "UTF-8");
+my ($utf8) = $iconv->convert($utf32be);
+defined($utf8) or die "Unable to create UTF-8 string\n";
+
+my $fh = FileHandle->new();
+$fh->open("utf8.dat", ">")
+    or die "Unable to open utf8.dat: $!\n";
+$fh->print($utf8)
+    or die "Unable to write utf8.dat\n";
+$fh->close();
+
+# vi:ai et sw=4 ts=4:
diff --git a/tests/octets-escaped.dat b/tests/octets-escaped.dat
new file mode 100644
index 0000000..ee99a6b
--- /dev/null
+++ b/tests/octets-escaped.dat
@@ -0,0 +1 @@
+"\u0000\u0001\u0002\u0003\u0004\u0005\u0006\u0007\b\t\n\u000b\f\r\u000e\u000f\u0010\u0011\u0012\u0013\u0014\u0015\u0016\u0017\u0018\u0019\u001a\u001b\u001c\u001d\u001e\u001f !\"#$%&'()*+,-.\/0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\u007f��������������������������������������������������������������������������������������������������������������������������������"
\ No newline at end of file
diff --git a/tests/test.lua b/tests/test.lua
index 9075bab..0e0aad8 100755
--- a/tests/test.lua
+++ b/tests/test.lua
@@ -3,6 +3,8 @@
 -- CJSON tests
 --
 -- Mark Pulford <mark@kyne.com.au>
+--
+-- Note: The output of this script is easier to read with "less -S"
 
 require "common"
 local json = require "cjson"
@@ -95,13 +97,73 @@ local function gen_ascii()
     return table.concat(chars)
 end
 
+-- Generate every UTF-16 codepoint, including supplementary codes
+local function gen_utf16_escaped()
+    -- Create raw table escapes
+    local utf16_escaped = {}
+    local count = 0
+
+    local function append_escape(code)
+        local esc = string.format('\\u%04X', code)
+        table.insert(utf16_escaped, esc)
+    end
+
+    table.insert(utf16_escaped, '"')
+    for i = 0, 0xD7FF do
+        append_escape(i)
+    end
+    -- Skip 0xD800 - 0xDFFF since they are used to encode supplementary
+    -- codepoints
+    for i = 0xE000, 0xFFFF do
+        append_escape(i)
+    end
+    -- Append surrogate pair for each supplementary codepoint
+    for high = 0xD800, 0xDBFF do
+        for low = 0xDC00, 0xDFFF do
+            append_escape(high)
+            append_escape(low)
+        end
+    end
+    table.insert(utf16_escaped, '"')
+   
+    return table.concat(utf16_escaped)
+end
+
 local octets_raw = gen_ascii()
-local octets_escaped = file_load("bytestring.dat")
+local octets_escaped = file_load("octets-escaped.dat")
+local utf8_loaded, utf8_raw = pcall(file_load, "utf8.dat")
+if not utf8_loaded then
+    utf8_raw = "Failed to load utf8.dat"
+end
+local utf16_escaped = gen_utf16_escaped()
+
 local escape_tests = {
+    -- Test 8bit clean
     { json.encode, { octets_raw }, true, { octets_escaped } },
-    { json.decode, { octets_escaped }, true, { octets_raw } }
+    { json.decode, { octets_escaped }, true, { octets_raw } },
+    -- Ensure high bits are removed from surrogate codes
+    { json.decode, { '"\\uF800"' }, true, { "\239\160\128" } },
+    -- Test inverted surrogate pairs
+    { json.decode, { '"\\uDB00\\uD800"' },
+      false, { "Expected value but found invalid unicode escape code at character 2" } },
+    -- Test 2x high surrogate code units
+    { json.decode, { '"\\uDB00\\uDB00"' },
+      false, { "Expected value but found invalid unicode escape code at character 2" } },
+    -- Test invalid 2nd escape
+    { json.decode, { '"\\uDB00\\"' },
+      false, { "Expected value but found invalid unicode escape code at character 2" } },
+    { json.decode, { '"\\uDB00\\uD"' },
+      false, { "Expected value but found invalid unicode escape code at character 2" } },
+    -- Test decoding of all UTF-16 escapes
+    { json.decode, { utf16_escaped }, true, { utf8_raw } }
 }
 
+function test_decode_cycle(filename)
+    local obj1 = json.decode(file_load(filename))
+    local obj2 = json.decode(json.encode(obj1))
+    return compare_values(obj1, obj2)
+end
+
 run_test_group("decode simple value", simple_value_tests)
 run_test_group("decode numeric", numeric_tests)
 
-- 
cgit v1.2.3-55-g6feb