Fix issues with stdout and updates tests.

Not so clean since trying to change stdout encoding requires accessing
sys.stdout.buffer, so fake_env has to mock this layer also. The basic
differences between p2 and p3 are handled in p3.py.
main
Olivier Mangin 11 years ago
parent 52813439dd
commit 0479636393

@ -8,6 +8,11 @@ if sys.version_info[0] == 2:
def input(): def input():
raw_input().decode(sys.stdin.encoding or 'utf8', 'ignore') raw_input().decode(sys.stdin.encoding or 'utf8', 'ignore')
# The following has to be a function so that it can be mocked
# for test_usecase.
def _get_raw_stdout():
return sys.stdout
ustr = unicode ustr = unicode
uchr = unichr uchr = unichr
from urlparse import urlparse from urlparse import urlparse
@ -15,6 +20,12 @@ if sys.version_info[0] == 2:
from httplib import HTTPConnection from httplib import HTTPConnection
file = None file = None
_fake_stdio = io.BytesIO # Only for tests to capture std{out,err} _fake_stdio = io.BytesIO # Only for tests to capture std{out,err}
def _get_fake_stdio_ucontent(stdio):
ustdio = io.TextIOWrapper(stdio)
ustdio.seek(0)
return ustdio.read()
else: else:
import configparser import configparser
_read_config = configparser.SafeConfigParser.read_file _read_config = configparser.SafeConfigParser.read_file
@ -23,7 +34,19 @@ else:
from urllib.parse import urlparse from urllib.parse import urlparse
from urllib.request import urlopen from urllib.request import urlopen
from http.client import HTTPConnection from http.client import HTTPConnection
_fake_stdio = io.StringIO # Only for tests to capture std{out,err}
def _fake_stdio():
return io.TextIOWrapper(io.BytesIO()) # Only for tests to capture std{out,err}
def _get_fake_stdio_ucontent(stdio):
stdio.flush()
stdio.seek(0)
return stdio.read()
# The following has to be a function so that it can be mocked
# for test_usecase.
def _get_raw_stdout():
return sys.stdout.buffer
configparser = configparser configparser = configparser
input = input input = input

@ -1,10 +1,12 @@
from __future__ import print_function from __future__ import print_function
import sys import sys
import locale
import codecs
from .content import editor_input from .content import editor_input
from . import color from . import color
import locale from .p3 import _get_raw_stdout
# package-shared ui that can be accessed using : # package-shared ui that can be accessed using :
@ -16,12 +18,12 @@ _ui = None
def _get_encoding(config): def _get_encoding(config):
"""Get local terminal encoding or user preference in config.""" """Get local terminal encoding or user preference in config."""
enc = 'utf8' enc = None
try: try:
enc = locale.getdefaultlocale()[1] or 'utf8' enc = locale.getdefaultlocale()[1]
except ValueError: except ValueError:
pass # Keep default pass # Keep default
return config.get('terminal-encoding', enc) return config.get('terminal-encoding', enc or 'utf8')
def get_ui(): def get_ui():
@ -40,16 +42,18 @@ class UI:
""" """
def __init__(self, config): def __init__(self, config):
self.encoding = _get_encoding(config)
color.setup(config.color) color.setup(config.color)
self.editor = config.edit_cmd self.editor = config.edit_cmd
self.encoding = _get_encoding(config)
self._stdout = codecs.getwriter(self.encoding)(_get_raw_stdout(),
errors='replace')
def print_(self, *strings): def print_(self, *strings):
"""Like print, but rather than raising an error when a character """Like print, but rather than raising an error when a character
is not in the terminal's encoding's character set, just silently is not in the terminal's encoding's character set, just silently
replaces it. replaces it.
""" """
print(' '.join(strings).encode(self.encoding, 'replace')) print(' '.join(strings), file=self._stdout)
def input(self): def input(self):
try: try:

@ -10,7 +10,7 @@ import fake_filesystem
import fake_filesystem_shutil import fake_filesystem_shutil
import fake_filesystem_glob import fake_filesystem_glob
from pubs.p3 import input, _fake_stdio from pubs.p3 import input, _fake_stdio, _get_fake_stdio_ucontent
from pubs import content, filebroker from pubs import content, filebroker
# code for fake fs # code for fake fs
@ -176,7 +176,7 @@ def redirect(f):
stderr = _fake_stdio() stderr = _fake_stdio()
sys.stdout, sys.stderr = stdout, stderr sys.stdout, sys.stderr = stdout, stderr
try: try:
return f(*args, **kwargs), stdout, stderr return f(*args, **kwargs), _get_fake_stdio_ucontent(stdout), _get_fake_stdio_ucontent(stderr)
finally: finally:
sys.stderr, sys.stdout = old_stderr, old_stdout sys.stderr, sys.stdout = old_stderr, old_stdout
return newf return newf

@ -59,7 +59,7 @@ class CommandTestCase(unittest.TestCase):
self.fs = fake_env.create_fake_fs([content, filebroker, configs, init_cmd, import_cmd]) self.fs = fake_env.create_fake_fs([content, filebroker, configs, init_cmd, import_cmd])
self.default_pubs_dir = self.fs['os'].path.expanduser('~/.pubs') self.default_pubs_dir = self.fs['os'].path.expanduser('~/.pubs')
def execute_cmds(self, cmds, fs=None, capture_output=CAPTURE_OUTPUT): def execute_cmds(self, cmds, capture_output=CAPTURE_OUTPUT):
""" Execute a list of commands, and capture their output """ Execute a list of commands, and capture their output
A command can be a string, or a tuple of size 2 or 3. A command can be a string, or a tuple of size 2 or 3.
@ -79,7 +79,7 @@ class CommandTestCase(unittest.TestCase):
if capture_output: if capture_output:
_, stdout, stderr = fake_env.redirect(pubs_cmd.execute)(cmd[0].split()) _, stdout, stderr = fake_env.redirect(pubs_cmd.execute)(cmd[0].split())
if len(cmd) == 3 and capture_output: if len(cmd) == 3 and capture_output:
actual_out = color.undye(stdout.getvalue()) actual_out = color.undye(stdout)
correct_out = color.undye(cmd[2]) correct_out = color.undye(cmd[2])
self.assertEqual(actual_out, correct_out) self.assertEqual(actual_out, correct_out)
else: else:
@ -93,8 +93,8 @@ class CommandTestCase(unittest.TestCase):
pubs_cmd.execute(cmd.split()) pubs_cmd.execute(cmd.split())
if capture_output: if capture_output:
assert(stderr.getvalue() == '') assert(stderr == '')
outs.append(color.undye(stdout.getvalue())) outs.append(color.undye(stdout))
if PRINT_OUTPUT: if PRINT_OUTPUT:
print(outs) print(outs)
return outs return outs
@ -119,13 +119,13 @@ class TestInit(CommandTestCase):
def test_init(self): def test_init(self):
pubsdir = os.path.expanduser('~/pubs_test2') pubsdir = os.path.expanduser('~/pubs_test2')
pubs_cmd.execute('pubs init -p {}'.format(pubsdir).split()) self.execute_cmds(['pubs init -p {}'.format(pubsdir)])
self.assertEqual(set(self.fs['os'].listdir(pubsdir)), self.assertEqual(set(self.fs['os'].listdir(pubsdir)),
{'bib', 'doc', 'meta', 'notes'}) {'bib', 'doc', 'meta', 'notes'})
def test_init2(self): def test_init2(self):
pubsdir = os.path.expanduser('~/.pubs') pubsdir = os.path.expanduser('~/.pubs')
pubs_cmd.execute('pubs init'.split()) self.execute_cmds(['pubs init'])
self.assertEqual(set(self.fs['os'].listdir(pubsdir)), self.assertEqual(set(self.fs['os'].listdir(pubsdir)),
{'bib', 'doc', 'meta', 'notes'}) {'bib', 'doc', 'meta', 'notes'})
@ -295,11 +295,11 @@ class TestUsecase(DataCommandTestCase):
def test_tag_list(self): def test_tag_list(self):
correct = [b'Initializing pubs in /paper_first\n', correct = ['Initializing pubs in /paper_first\n',
b'', '',
b'', '',
b'', '',
b'search network\n', 'search network\n',
] ]
cmds = ['pubs init -p paper_first/', cmds = ['pubs init -p paper_first/',
@ -363,8 +363,8 @@ class TestUsecase(DataCommandTestCase):
'pubs export Page99', 'pubs export Page99',
] ]
outs = self.execute_cmds(cmds) outs = self.execute_cmds(cmds)
out_raw = outs[2].decode() self.assertEqual(endecoder.EnDecoder().decode_bibdata(outs[2]),
self.assertEqual(endecoder.EnDecoder().decode_bibdata(out_raw), fixtures.page_bibdata) fixtures.page_bibdata)
def test_import(self): def test_import(self):
cmds = ['pubs init', cmds = ['pubs init',

Loading…
Cancel
Save