# vim: set ts=8 sw=4 sts=4 et ai:
#
# Implement JSON extension Date(<milliseconds>), described here:
# http://weblogs.asp.net/bleroy/archive/2008/01/18/dates-and-json.aspx
#
# Copyright (C) Walter Doekes, OSSO B.V, 2012
# License: Public Domain
# Version: 2
#
# It abuses the fact that, according to the JSON specs, both "/" and "\/"
# mean the same thing. And, since you wouldn't normally escape the '/', we
# can use that to store new objects.
#
# In the Microsoft Ajax Library, they added this extension:
#
#  "\/Date(<milliseconds>)\/" -- time zone agnostic datetime
#
# Usage:
#
#  encoded = json.dumps(datetime.date.today(), cls=ExtendedJSONEncoder)
#  decoded = json.loads(encoded, cls=ExtendedJSONDecoder)
#
# Changes:
#
#  v2: added more tests
#  v2: fixed so negative milliseconds work too
#  v2: added lots of bloaty _version=2 style for python2.7 json
#

from datetime import date, datetime
from time import mktime


#-----------------------------------------------------------------------
# Encoding extensions
#-----------------------------------------------------------------------

# Probe the running json module: when json.encoder exposes
# _make_iterencode the "new style" encoder below is used (version 2),
# otherwise the "old style" one (version 1).
_encode_version = 2
try:
    from json.encoder import _make_iterencode
except ImportError:
    _encode_version = 1
else:
    # The import was only a probe; do not keep the name around.
    del _make_iterencode


#-----------------------------------------------------------------------
# Encoding extensions (old style)
#-----------------------------------------------------------------------

if _encode_version == 1:
    from json.encoder import JSONEncoder

    class ExtendedJSONEncoder(JSONEncoder):
        """JSON encoder that also understands date and datetime objects.

        Such values are written in the Microsoft Date(<milliseconds>)
        string form; every other value is left to the base class.
        """

        def _iterencode(self, o, markers=None):
            """Return the Date(...) text for a date/datetime, else defer.

            The timestamp is derived with time.mktime() from o.timetuple();
            anything that is not a date or datetime is passed on unchanged
            to JSONEncoder._iterencode together with `markers`.
            """
            if not isinstance(o, (date, datetime)):
                return super(ExtendedJSONEncoder, self)._iterencode(
                    o, markers=markers)

            # Whole seconds, scaled to milliseconds.
            milliseconds = long(mktime(o.timetuple()) * 1000)
            # Plain date objects carry no sub-second part.
            if hasattr(o, 'microsecond'):
                milliseconds += long(o.microsecond / 1000)
            return r'"\/Date(%d)\/"' % (milliseconds,)


#-----------------------------------------------------------------------
# Encoding extensions (new style)
#-----------------------------------------------------------------------

elif _encode_version == 2:
    from json.encoder import JSONEncoder, encode_basestring_ascii, FLOAT_REPR, INFINITY

    # BEWARE: huge chunks of code copy-pasted here, because there is no
    # sane hook to work with

    class ExtendedJSONEncoder(JSONEncoder):
        """JSONEncoder that encodes through this module's _make_iterencode.

        iterencode() is a copy of the Python 2.7 JSONEncoder.iterencode()
        that always builds its encoder with the module-level
        _make_iterencode defined in this file instead of the stock one.
        """

        def iterencode(self, o, _one_shot=False):
            """Encode the given object and yield each string
            representation as available.

            For example::

                for chunk in JSONEncoder().iterencode(bigobject):
                    mysocket.write(chunk)

            """
            if self.check_circular:
                markers = {}
            else:
                markers = None
            if self.ensure_ascii:
                _encoder = encode_basestring_ascii
            else:
                # Imported here because the surrounding module only imports
                # encode_basestring_ascii; without this the
                # ensure_ascii=False path failed with a NameError.
                from json.encoder import encode_basestring
                _encoder = encode_basestring
            if self.encoding != 'utf-8':
                # Wrap the chosen encoder so byte strings are decoded with
                # the configured encoding first.
                def _encoder(o, _orig_encoder=_encoder, _encoding=self.encoding):
                    if isinstance(o, str):
                        o = o.decode(_encoding)
                    return _orig_encoder(o)

            def floatstr(o, allow_nan=self.allow_nan,
                    _repr=FLOAT_REPR, _inf=INFINITY, _neginf=-INFINITY):
                # Check for specials.  Note that this type of test is processor
                # and/or platform-specific, so do tests which don't depend on the
                # internals.

                if o != o:
                    text = 'NaN'
                elif o == _inf:
                    text = 'Infinity'
                elif o == _neginf:
                    text = '-Infinity'
                else:
                    return _repr(o)

                if not allow_nan:
                    raise ValueError(
                        "Out of range float values are not JSON compliant: " +
                        repr(o))

                return text

            # DIFFERENCE WITH PYTHON2.7 VERSION HERE
            # We need our custom _make_iterencode.
            _iterencode = _make_iterencode(
                markers, self.default, _encoder, self.indent, floatstr,
                self.key_separator, self.item_separator, self.sort_keys,
                self.skipkeys, _one_shot)
            # END DIFFERENCE
            return _iterencode(o, 0)

    def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
            _key_separator, _item_separator, _sort_keys, _skipkeys, _one_shot,
            ## HACK: hand-optimized bytecode; turn globals into locals
            ValueError=ValueError,
            basestring=basestring,
            dict=dict,
            float=float,
            id=id,
            int=int,
            isinstance=isinstance,
            list=list,
            long=long,
            str=str,
            tuple=tuple,
        ):
        """Build and return the recursive _iterencode(o, level) generator.

        This is the Python 2.7 json.encoder._make_iterencode with one
        addition: date and datetime objects are emitted in the Microsoft
        Date(<milliseconds>) string form.

        markers         -- dict used for circular reference detection
                           (keyed by id()), or None to disable the check
        _default        -- called for objects no branch handles; its
                           result is encoded in place of the object
        _encoder        -- callable turning a string into its JSON text
        _indent         -- None, or the number of spaces per nesting level
        _floatstr       -- callable turning a float into its JSON text
        _key_separator  -- text placed between a key and its value
        _item_separator -- text placed between items
        _sort_keys      -- when true, dict items are emitted sorted by key
        _skipkeys       -- when true, unsupported key types are skipped
                           instead of raising TypeError
        _one_shot       -- accepted for signature compatibility; not used
                           in this body

        The remaining keyword parameters only bind builtins as locals.
        """

        def _iterencode_list(lst, _current_indent_level):
            if not lst:
                yield '[]'
                return
            if markers is not None:
                markerid = id(lst)
                if markerid in markers:
                    raise ValueError("Circular reference detected")
                markers[markerid] = lst
            buf = '['
            if _indent is not None:
                _current_indent_level += 1
                newline_indent = '\n' + (' ' * (_indent * _current_indent_level))
                separator = _item_separator + newline_indent
                buf += newline_indent
            else:
                newline_indent = None
                separator = _item_separator
            first = True
            for value in lst:
                if first:
                    first = False
                else:
                    buf = separator
                if isinstance(value, basestring):
                    yield buf + _encoder(value)
                elif value is None:
                    yield buf + 'null'
                elif value is True:
                    yield buf + 'true'
                elif value is False:
                    yield buf + 'false'
                elif isinstance(value, (int, long)):
                    yield buf + str(value)
                elif isinstance(value, float):
                    yield buf + _floatstr(value)
                else:
                    yield buf
                    if isinstance(value, (list, tuple)):
                        chunks = _iterencode_list(value, _current_indent_level)
                    elif isinstance(value, dict):
                        chunks = _iterencode_dict(value, _current_indent_level)
                    else:
                        chunks = _iterencode(value, _current_indent_level)
                    for chunk in chunks:
                        yield chunk
            if newline_indent is not None:
                _current_indent_level -= 1
                yield '\n' + (' ' * (_indent * _current_indent_level))
            yield ']'
            if markers is not None:
                del markers[markerid]

        def _iterencode_dict(dct, _current_indent_level):
            if not dct:
                yield '{}'
                return
            if markers is not None:
                markerid = id(dct)
                if markerid in markers:
                    raise ValueError("Circular reference detected")
                markers[markerid] = dct
            yield '{'
            if _indent is not None:
                _current_indent_level += 1
                newline_indent = '\n' + (' ' * (_indent * _current_indent_level))
                item_separator = _item_separator + newline_indent
                yield newline_indent
            else:
                newline_indent = None
                item_separator = _item_separator
            first = True
            if _sort_keys:
                items = sorted(dct.items(), key=lambda kv: kv[0])
            else:
                items = dct.iteritems()
            for key, value in items:
                if isinstance(key, basestring):
                    pass
                # JavaScript is weakly typed for these, so it makes sense to
                # also allow them.  Many encoders seem to do something like this.
                elif isinstance(key, float):
                    key = _floatstr(key)
                elif key is True:
                    key = 'true'
                elif key is False:
                    key = 'false'
                elif key is None:
                    key = 'null'
                elif isinstance(key, (int, long)):
                    key = str(key)
                elif _skipkeys:
                    continue
                else:
                    raise TypeError("key " + repr(key) + " is not a string")
                if first:
                    first = False
                else:
                    yield item_separator
                yield _encoder(key)
                yield _key_separator
                if isinstance(value, basestring):
                    yield _encoder(value)
                elif value is None:
                    yield 'null'
                elif value is True:
                    yield 'true'
                elif value is False:
                    yield 'false'
                elif isinstance(value, (int, long)):
                    yield str(value)
                elif isinstance(value, float):
                    yield _floatstr(value)
                else:
                    if isinstance(value, (list, tuple)):
                        chunks = _iterencode_list(value, _current_indent_level)
                    elif isinstance(value, dict):
                        chunks = _iterencode_dict(value, _current_indent_level)
                    else:
                        chunks = _iterencode(value, _current_indent_level)
                    for chunk in chunks:
                        yield chunk
            if newline_indent is not None:
                _current_indent_level -= 1
                yield '\n' + (' ' * (_indent * _current_indent_level))
            yield '}'
            if markers is not None:
                del markers[markerid]

        def _iterencode(o, _current_indent_level):
            if isinstance(o, basestring):
                yield _encoder(o)
            elif o is None:
                yield 'null'
            elif o is True:
                yield 'true'
            elif o is False:
                yield 'false'
            elif isinstance(o, (int, long)):
                yield str(o)
            elif isinstance(o, float):
                yield _floatstr(o)
            elif isinstance(o, (list, tuple)):
                for chunk in _iterencode_list(o, _current_indent_level):
                    yield chunk
            elif isinstance(o, dict):
                for chunk in _iterencode_dict(o, _current_indent_level):
                    yield chunk
            # DIFFERENCE WITH PYTHON2.7 VERSION HERE
            # Yes.. the only place we wanted to edit.
            elif isinstance(o, (date, datetime)):
                # Microsoft JSON Extension for Date(<milliseconds>)
                # Whole seconds scaled to milliseconds, plus the sub-second
                # part that only datetime objects carry.
                milliseconds = long(mktime(o.timetuple()) * 1000)
                if hasattr(o, 'microsecond'):
                    milliseconds += long(o.microsecond // 1000)
                yield r'"\/Date(%d)\/"' % (milliseconds,)
            # END DIFFERENCE
            else:
                if markers is not None:
                    markerid = id(o)
                    if markerid in markers:
                        raise ValueError("Circular reference detected")
                    markers[markerid] = o
                o = _default(o)
                for chunk in _iterencode(o, _current_indent_level):
                    yield chunk
                if markers is not None:
                    del markers[markerid]

        return _iterencode


#-----------------------------------------------------------------------
# Decoding extensions
# (this one is a bit hairier than _encode_version 1, but still better
# than _encode_version 2)
#-----------------------------------------------------------------------

# Probe the running json module for the decoder as well: when
# json.scanner still provides Scanner the "old style" decoder below is
# used (version 1), otherwise the "new style" one (version 2).
_decode_version = 1
try:
    from json.scanner import Scanner
except ImportError:
    _decode_version = 2

def _isnumeric(data):
    if not data:
        return False
    if data.isdigit():
        return True
    if data[0] == '-' and data[1:].isdigit():
        return True
    return False


#-----------------------------------------------------------------------
# Decoding extensions (old style)
#-----------------------------------------------------------------------

if _decode_version == 1:
    from json import JSONDecoder, decoder
    from json.decoder import JSONObject, JSONArray, JSONString, JSONConstant, JSONNumber
    from json.scanner import Scanner, pattern

    def MicrosoftJSONExtensions(match, context):
        """Scanner callback for JSON strings that start with an escaped slash.

        match   -- regex match covering the opening quote and the escaped
                   slash; match.string is the whole document being scanned
        context -- passed in by the scanner; not used here

        Returns a (value, end) tuple where value is the decoded datetime and
        end is the index just past the closing quote.

        Raises ValueError when the string is unterminated, does not end in
        an escaped slash, or holds anything other than Date(<milliseconds>).
        """
        s = match.string
        e = match.end()
        escaped = False

        # Walk forward to the closing quote, skipping backslash-escaped
        # characters so an escaped quote does not end the string early.
        for i in xrange(e, len(s)):
            ch = s[i]
            if escaped:
                escaped = False
            else:
                if ch == '\\':
                    escaped = True
                elif ch == '"':
                    break
        else:
            raise ValueError('Something went wrong...')

        # The payload must be closed by an escaped slash as well.
        if s[i-2:i] != '\\/':
            raise ValueError('Something went wrong...')

        extension = s[e:i-2]

        # Date(<milliseconds>) => datetime
        if (extension.startswith('Date(') and extension.endswith(')')
                and _isnumeric(extension[5:-1])):
            date_ms = long(extension[5:-1])
            value = datetime.fromtimestamp(date_ms / 1000.0)
        else:
            raise ValueError('Unknown extension: %s' % (extension,))

        return (value, i + 1)
    pattern(r'"\\/')(MicrosoftJSONExtensions)

    # Callbacks handed to the Scanner. MicrosoftJSONExtensions is listed
    # ahead of JSONString; presumably the Scanner tries the entries in
    # order, so the more specific '"\/' prefix is tested before the
    # generic string rule -- TODO confirm against json.scanner.
    ANYTHING = [
        JSONObject,
        JSONArray,
        MicrosoftJSONExtensions,
        JSONString,
        JSONConstant,
        JSONNumber,
    ]

    # We must update the decoder.JSONScanner, because the JSON* functions from there
    # reference that directly.
    # Observe that this may affect global scope!
    # (The attribute is replaced on the json.decoder module itself, so the
    # change is not limited to ExtendedJSONDecoder.)
    decoder.JSONScanner = Scanner(ANYTHING)

    class ExtendedJSONDecoder(JSONDecoder):
        """JSONDecoder whose scanner is built from the ANYTHING list."""
        # Separate Scanner instance built from the same callback list.
        _scanner = Scanner(ANYTHING)


#-----------------------------------------------------------------------
# Decoding extensions (new style)
#-----------------------------------------------------------------------

elif _decode_version == 2:
    from json import scanner
    from json.decoder import JSONDecoder, DEFAULT_ENCODING, BACKSLASH, STRINGCHUNK, errmsg

    def scanstring(s, end, encoding=None, strict=True,
            _b=BACKSLASH, _m=STRINGCHUNK.match):
        """Scan the string s for a JSON string. End is the index of the
        character in s after the quote that started the JSON string.
        Unescapes all valid JSON string escape sequences and raises ValueError
        on attempt to decode an invalid string. If strict is False then literal
        control characters are allowed in the string.

        Extension: when the string starts with an escaped slash followed by
        Date(<milliseconds>), an escaped slash and the closing quote, a
        datetime built with datetime.fromtimestamp() is returned instead of
        a string. Anything else is decoded as an ordinary string.

        Returns a tuple of the decoded value and the index of the character
        in s after the end quote."""
        # Local import: sys.maxunicode is needed by the surrogate-pair check
        # below. Without it that check raised NameError for any escape in
        # the high-surrogate range.
        import sys

        if encoding is None:
            encoding = DEFAULT_ENCODING
        chunks = []
        _append = chunks.append
        begin = end - 1
        while 1:
            chunk = _m(s, end)
            if chunk is None:
                raise ValueError(
                    errmsg("Unterminated string starting at", s, begin))
            end = chunk.end()
            content, terminator = chunk.groups()
            # Content contains zero or more unescaped string characters
            if content:
                if not isinstance(content, unicode):
                    content = unicode(content, encoding)
                _append(content)
            # Terminator is the end of string, a literal control character,
            # or a backslash denoting that an escape sequence follows
            if terminator == '"':
                break
            elif terminator != '\\':
                if strict:
                    #msg = "Invalid control character %r at" % (terminator,)
                    msg = "Invalid control character {0!r} at".format(terminator)
                    raise ValueError(errmsg(msg, s, end))
                else:
                    _append(terminator)
                    continue
            try:
                esc = s[end]
            except IndexError:
                raise ValueError(
                    errmsg("Unterminated string starting at", s, begin))
            # If not a unicode escape sequence, must be in the lookup table
            if esc != 'u':
                try:
                    char = _b[esc]
                except KeyError:
                    msg = "Invalid \\escape: " + repr(esc)
                    raise ValueError(errmsg(msg, s, end))
                # DIFFERENCE WITH PYTHON2.7
                # Only an escaped slash at the very start of the string
                # (nothing collected yet) can introduce the Date extension.
                if not chunks and char == '/':
                    if s.startswith('/Date(', end):
                        # Locate the ")", escaped slash and closing quote
                        # that must end the payload; the digits-only check
                        # guarantees nothing else sits in between.
                        tmp = s.find(r')\/"', end + 6)
                        if tmp != -1 and _isnumeric(s[end+6:tmp]):
                            date_ms = long(s[end+6:tmp])
                            value = datetime.fromtimestamp(date_ms / 1000.0)
                            # Skip the four closing characters as well.
                            return value, tmp + 4
                # END DIFFERENCE
                end += 1
            else:
                # Unicode escape sequence
                esc = s[end + 1:end + 5]
                next_end = end + 5
                if len(esc) != 4:
                    msg = "Invalid \\uXXXX escape"
                    raise ValueError(errmsg(msg, s, end))
                uni = int(esc, 16)
                # Check for surrogate pair on UCS-4 systems
                if 0xd800 <= uni <= 0xdbff and sys.maxunicode > 65535:
                    msg = "Invalid \\uXXXX\\uXXXX surrogate pair"
                    if not s[end + 5:end + 7] == '\\u':
                        raise ValueError(errmsg(msg, s, end))
                    esc2 = s[end + 7:end + 11]
                    if len(esc2) != 4:
                        raise ValueError(errmsg(msg, s, end))
                    uni2 = int(esc2, 16)
                    uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00))
                    next_end += 6
                char = unichr(uni)
                end = next_end
            # Append the unescaped character
            _append(char)
        return u''.join(chunks), end
    
    class ExtendedJSONDecoder(JSONDecoder):
        """JSONDecoder that parses strings with this module's scanstring()."""

        def __init__(self, *args, **kwargs):
            """Set up the base decoder, then install the custom string parser.

            All arguments are passed through unchanged to JSONDecoder.
            """
            super(ExtendedJSONDecoder, self).__init__(*args, **kwargs)
            # Assigned before the scanner is rebuilt below, which is handed
            # this decoder instance.
            self.parse_string = scanstring
            # Must re-create scanner, and use the py_* version, because
            # the C-version doesn't take our custom parser.
            self.scan_once = scanner.py_make_scanner(self)


#-----------------------------------------------------------------------
# Test
#-----------------------------------------------------------------------
       
if __name__ == '__main__':
    import json, unittest

    class ExtensionsTest(unittest.TestCase):
        def get_now(self):
            # Round datetime to millisecond precision
            now = datetime.now()
            return now.replace(microsecond=int(now.microsecond / 1000) * 1000)

        def stays_the_same(self, data):
            encoded = json.dumps(data, cls=ExtendedJSONEncoder)
            decoded = json.loads(encoded, cls=ExtendedJSONDecoder)
            self.assertEquals(data, decoded)

        def test_date_to_string(self):
            once_upon_a_time = datetime(1911, 2, 3, 4, 5, 6, 7000)
            encoded = json.dumps(once_upon_a_time, cls=ExtendedJSONEncoder)
            expected = r'"\/Date(-1859055265993)\/"'
            self.assertEquals(encoded, expected)

        def test_string_to_date(self):
            once_upon_a_time = r'"\/Date(-1859055265992)\/"'
            decoded = json.loads(once_upon_a_time, cls=ExtendedJSONDecoder)
            expected = datetime(1911, 2, 3, 4, 5, 6, 8000)
            self.assertEquals(decoded, expected)

        def test_string(self):
            self.stays_the_same('/Date(123)/')

        def test_date(self):
            # Special, since dates will get cast to a datetime
            today = date.today()
            encoded = json.dumps(today, cls=ExtendedJSONEncoder)
            decoded = json.loads(encoded, cls=ExtendedJSONDecoder)
            self.assertEquals(today, decoded.date())

        def test_datetime(self):
            self.stays_the_same(self.get_now())

        def test_array(self):
            self.stays_the_same([1, self.get_now(), 3])

        def test_object(self):
            self.stays_the_same({'key1': self.get_now(), 'key2': 'something'})

    unittest.main()
