# vim: set ts=8 sw=4 sts=4 et ai tw=79:
"""
minidom_xhtml -- Provides parseStringXHTML to parse XHTML entities
Copyright (C) 2014  Walter Doekes <wdoekes>, OSSO B.V.

Problem::

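    When you feed the xml.dom.minidom parseString an XHTML document that
    contains HTML entities other than the five predefined XML ones
    (&amp;, &lt;, &gt;, &quot; and &apos;), these entities get eaten; if
    the document has no DOCTYPE referencing an external DTD, parsing
    fails with an ExpatError instead.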

    parseString('...&agrave;bc...') yields u'...bc...'

Cause::

    parseString doesn't preload the XHTML entities as defined in
    http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd. It does, however,
    cope with entities declared inline in the DOCTYPE's internal subset.

Solution::

    This module provides parseStringXHTML, which adds the necessary
    entity declarations from xhtml-lat1.ent, xhtml-special.ent and
    xhtml-symbol.ent (expected in the same directory as this module) to
    the document's DOCTYPE before parsing.

    parseStringXHTML('...&agrave;bc...') yields u'...<U+00E0>bc...'

Info::

    license = Public Domain
    version = 1 (2014-01-15)
"""
import os.path
from unittest import TestCase, main
from xml.dom.minidom import parseString
from xml.parsers.expat import ExpatError


# xhtml1-strict.dtd pulls its entity sets in with declarations such as
#   <!ENTITY % HTMLlat1 PUBLIC
#     "-//W3C//ENTITIES Latin 1 for XHTML//EN"
#     "xhtml-lat1.ent">
# (likewise for xhtml-symbol and xhtml-special). expat does not load that
# DTD, so the files are kept next to this module and loaded manually.
_XHTML_ENTITY_FILES = (
    'xhtml-lat1.ent', 'xhtml-symbol.ent', 'xhtml-special.ent')

# DOCTYPE prepended to documents that have none; the entity declarations
# and the closing "]>" are appended after it.
_XHTML_DOCTYPE_OPEN = '''<!DOCTYPE html
                     PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"
                     "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd"
                [\n'''

# Cached text of all entity files; filled on first use.
_xhtml_entities = None


def _xhtml_entity_declarations():
    """Return the concatenated contents of the XHTML .ent files as text.

    The files are read once from this module's directory and cached.
    Raises IOError/OSError if one of them is missing.
    """
    global _xhtml_entities
    if _xhtml_entities is None:
        here = os.path.dirname(__file__)
        chunks = []
        for name in _XHTML_ENTITY_FILES:
            with open(os.path.join(here, name), 'rb') as file:
                chunks.append(file.read().decode('utf-8'))
        _xhtml_entities = u'\n'.join(chunks)
    return _xhtml_entities


def _insert_entity_declarations(document, entities):
    """Return *document* with *entities* placed in its DOCTYPE.

    *document* and *entities* must both be bytes or both be text; the
    result has the same type.
    """
    # Literal helper so the same code works on bytes and on text.
    if isinstance(document, bytes):
        def lit(text):
            return text.encode('ascii')
    else:
        def lit(text):
            return text
    newline = lit('\n')

    document = document.lstrip()  # remove leading blanks

    # Keep leading processing instructions (notably the <?xml ...?>
    # declaration, which must stay first) and comments in front of the
    # DOCTYPE we are about to look for or add.
    prefix = document[:0]
    while True:
        if document.startswith(lit('<?')):
            opening, closing = lit('<?'), lit('?>')
        elif document.startswith(lit('<!--')):
            opening, closing = lit('<!--'), lit('-->')
        else:
            break
        end = document.find(closing, len(opening))
        if end == -1:
            break  # unterminated; leave it for parseString to reject
        end += len(closing)
        prefix += document[:end] + newline
        document = document[end:].lstrip()

    if not document.startswith(lit('<!DOCTYPE')):
        # No DOCTYPE: supply an XHTML one carrying the declarations.
        return (prefix + lit(_XHTML_DOCTYPE_OPEN) + entities +
                lit('\n]>') + document)

    gt = document.find(lit('>'))
    bracket = document.find(lit('['))
    if bracket != -1 and (gt == -1 or bracket < gt):
        # The DOCTYPE already has an internal subset "[...]": add ours at
        # its start. XML keeps the first declaration of an entity, so the
        # XHTML definitions win over redeclarations in the document.
        cut = bracket + 1
        return (prefix + document[:cut] + newline + entities + newline +
                document[cut:])
    if gt == -1:
        # Unterminated DOCTYPE: pass it through so parseString raises.
        return prefix + document
    # Plain DOCTYPE: give it an internal subset holding the declarations.
    return (prefix + document[:gt] + lit('\n[\n') + entities + lit('\n]') +
            document[gt:])


def parseStringXHTML(input):
    """
    Same as xml.dom.minidom.parseString, but it adds the ENTITY
    declarations so characters like &agrave; will get translated
    into U+00E0 properly.

    input may be bytes or text. The declarations from xhtml-lat1.ent,
    xhtml-symbol.ent and xhtml-special.ent (read from this module's
    directory) are put into the document's DOCTYPE internal subset:

    * no DOCTYPE: an XHTML 1.0 Strict DOCTYPE holding them is inserted
      after any leading <?xml ...?> declaration, other processing
      instructions and comments;
    * DOCTYPE with an internal subset: they are added at its start;
    * DOCTYPE without one: an internal subset holding them is added.

    Returns the Document produced by parseString. Raises ExpatError for
    malformed input, and IOError/OSError if an .ent file is missing.
    """
    entities = _xhtml_entity_declarations()
    if isinstance(input, bytes):
        # Stay in bytes so the document's own encoding is left untouched.
        # NOTE(review): assumes the .ent files are ASCII (the W3C ones
        # are), so their UTF-8 bytes suit any ASCII-compatible encoding.
        entities = entities.encode('utf-8')
    return parseString(_insert_entity_declarations(input, entities))


# Paragraph text shared by every test document: one entity plain
# parseString already knows (&amp;), XHTML-only ones (&agrave;, &mdash;)
# and a literal non-ASCII character.
TEST_INPUT = u'Agr&agrave;ve &amp; an &mdash; extra &mdash; d\u00e4sh.'
# What plain parseString yields when it does not raise: the XHTML-only
# entities silently vanish.
BAD_OUTPUT = u'Agrve & an  extra  d\u00e4sh.'
# What parseStringXHTML is expected to yield.
GOOD_OUTPUT = u'Agr\u00e0ve & an \u2014 extra \u2014 d\u00e4sh.'
# One full document per template, each with TEST_INPUT inside its <p>.
TEST_INPUTS = tuple(template % (TEST_INPUT,) for template in (
    # parseString silently eats the entity characters here.
    u'''<!DOCTYPE html
          PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"
          "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
        <html xmlns="http://www.w3.org/1999/xhtml" lang="en" xml:lang="en">
            <head><title>title</title>
                  <meta http-equiv="Content-Type"
                        content="text/html; charset=UTF-8"/></head>
            <body><p>%s</p></body></html>''',
    # parseString throws an ExpatError here (no DOCTYPE at all).
    u'''<body><p>%s</p></body>''',
    # .. and here (a DOCTYPE that references no DTD).
    u'''<!DOCTYPE whatever><body><p>%s</p></body>''',
))


def _input_test(n):
    """Return a test method that checks TEST_INPUTS[n] via run_func."""
    # A real closure per n avoids the late-binding lambda pitfall.
    def test(self):
        self.run_func(n)
    test.__name__ = 'test_%d' % n
    return test


def _add_input_tests(cls):
    """Class decorator: give *cls* a test_<n> method per TEST_INPUTS entry."""
    for n in range(len(TEST_INPUTS)):
        setattr(cls, 'test_%d' % n, _input_test(n))
    return cls


@_add_input_tests
class BaseTest(object):
    """
    Mixin that runs every TEST_INPUTS document through a parse function.

    Concrete subclasses also derive from TestCase and implement
    get_parse_func_and_output(). The decorator adds test_0 .. test_<k>,
    one per TEST_INPUTS entry, without leaking loop variables into the
    class namespace.
    """

    def get_parse_func_and_output(self):
        """Return (parse_func, expected_text) for this test class."""
        raise NotImplementedError()

    def run_func(self, n):
        """Parse TEST_INPUTS[n] and compare the first <p>'s text.

        When the expectation is BAD_OUTPUT, an ExpatError from the parser
        is accepted as well; for any other expectation it is re-raised.
        """
        # parseString takes bstring input, wholeText returns unicode.
        parse_func, output = self.get_parse_func_and_output()
        try:
            parsed = parse_func(TEST_INPUTS[n].encode('utf-8'))
        except ExpatError:
            if output != BAD_OUTPUT:
                # The fixed version shouldn't have any of that.
                raise
            # We expect an ExpatError for certain test cases of
            # parseString.
            return
        text = parsed.getElementsByTagName('p')[0].firstChild
        # assertEqual/assertIsInstance: the assertEquals alias is gone in
        # Python 3.12.
        self.assertIsInstance(text.wholeText, type(u''))  # is-unicode
        self.assertEqual(text.wholeText, output)


class TestParseStringFailure(BaseTest, TestCase):
    """Plain parseString: XHTML-only entities are dropped or rejected."""

    def get_parse_func_and_output(self):
        # Expect the entity-less text (run_func also accepts ExpatError).
        expected = BAD_OUTPUT
        return parseString, expected


class TestParseStringXHTMLSuccess(BaseTest, TestCase):
    """parseStringXHTML: every entity resolves to its character."""

    def get_parse_func_and_output(self):
        expected = GOOD_OUTPUT
        return parseStringXHTML, expected


if __name__ == '__main__':
    # Run this module's TestCase classes through unittest's command line.
    main(module='__main__')
