Refactors paper filtering from queries.

main
Olivier Mangin 7 years ago
parent e3f2e7db26
commit c54de5c3b6

@ -25,9 +25,9 @@ def parser(subparsers, conf):
parser.add_argument('--no-docs', action='store_true', parser.add_argument('--no-docs', action='store_true',
dest='nodocs', default=False, dest='nodocs', default=False,
help='list only pubs without attached documents.') help='list only pubs without attached documents.')
parser.add_argument('query', nargs='*', parser.add_argument('query', nargs='*',
help='Paper query ("author:Einstein", "title:learning", "year:2000" or "tags:math")') help='Paper query ("author:Einstein", "title:learning",'
'"year:2000" or "tags:math")')
return parser return parser
@ -38,7 +38,7 @@ def date_added(p):
def command(conf, args): def command(conf, args):
ui = get_ui() ui = get_ui()
rp = repo.Repository(conf) rp = repo.Repository(conf)
papers = filter(lambda p: filter_paper(p, args.query, papers = filter(get_paper_filter(args.query,
case_sensitive=args.case_sensitive), case_sensitive=args.case_sensitive),
rp.all_papers()) rp.all_papers())
if args.nodocs: if args.nodocs:
@ -60,67 +60,82 @@ FIELD_ALIASES = {
'authors': 'author', 'authors': 'author',
't': 'title', 't': 'title',
'tags': 'tag', 'tags': 'tag',
'y': 'year',
} }
def _get_field_value(query_block): class QueryFilter(object):
split_block = query_block.split(':')
if len(split_block) != 2: def __init__(self, query, case_sensitive=None):
raise InvalidQuery("Invalid query (%s)" % query_block) if case_sensitive is None:
field = split_block[0] case_sensitive = not query.islower()
if field in FIELD_ALIASES: self.case = case_sensitive
field = FIELD_ALIASES[field] self.query = self._lower(query)
value = split_block[1]
return (field, value) def __call__(self, paper):
raise NotImplementedError
def _lower(self, s):
return s if self.case else s.lower()
class FieldFilter(QueryFilter):
"""Generic filter of form `query in paper['field']`"""
def _lower(s, lower=True): def __init__(self, field, query, case_sensitive=None):
return s.lower() if lower else s super(FieldFilter, self).__init__(query, case_sensitive=case_sensitive)
self.field = field
def __call__(self, paper):
return (self.field in paper.bibdata and
self.query in self._lower(paper.bibdata[self.field]))
def _check_author_match(paper, query, case_sensitive=False):
class AuthorFilter(QueryFilter):
def __call__(self, paper):
"""Only checks within last names.""" """Only checks within last names."""
if not 'author' in paper.bibdata: if 'author' not in paper.bibdata:
return False return False
return any([query in _lower(bibstruct.author_last(p), lower=(not case_sensitive)) else:
for p in paper.bibdata['author']]) return any([self.query in self._lower(bibstruct.author_last(author))
for author in paper.bibdata['author']])
class TagFilter(QueryFilter):
def _check_tag_match(paper, query, case_sensitive=False): def __call__(self, paper):
return any([query in _lower(t, lower=(not case_sensitive)) return any([self.query in self._lower(t) for t in paper.tags])
for t in paper.tags])
def _check_field_match(paper, field, query, case_sensitive=False): def _get_field_value(query_block):
return query in _lower(paper.bibdata[field], split_block = query_block.split(':')
lower=(not case_sensitive)) if len(split_block) != 2:
raise InvalidQuery("Invalid query (%s)" % query_block)
field = split_block[0]
if field in FIELD_ALIASES:
field = FIELD_ALIASES[field]
value = split_block[1]
return (field, value)
def _check_query_block(paper, query_block, case_sensitive=None): def _query_block_to_filter(query_block, case_sensitive=None):
field, value = _get_field_value(query_block) field, value = _get_field_value(query_block)
if case_sensitive is None:
case_sensitive = not value.islower()
elif not case_sensitive:
value = value.lower()
if field == 'tag': if field == 'tag':
return _check_tag_match(paper, value, case_sensitive=case_sensitive) return TagFilter(value, case_sensitive=case_sensitive)
elif field == 'author': elif field == 'author':
return _check_author_match(paper, value, case_sensitive=case_sensitive) return AuthorFilter(value, case_sensitive=case_sensitive)
elif field in paper.bibdata:
return _check_field_match(paper, field, value,
case_sensitive=case_sensitive)
else: else:
return False return FieldFilter(field, value, case_sensitive=case_sensitive)
# TODO implement search by type of document # TODO implement search by type of document
def filter_paper(paper, query, case_sensitive=None): def get_paper_filter(query, case_sensitive=None):
"""If case_sensitive is not given, only check case if query """If case_sensitive is not given, only check case if query
is not lowercase. is not lowercase.
:args query: list of query blocks (strings) :args query: list of query blocks (strings)
""" """
return all([_check_query_block(paper, query_block, filters = [_query_block_to_filter(query_block, case_sensitive=case_sensitive)
case_sensitive=case_sensitive) for query_block in query]
for query_block in query]) return lambda paper: all([f(paper) for f in filters])

@ -1,10 +1,10 @@
import unittest import unittest
import dotdot import dotdot
from pubs.commands.list_cmd import (_check_author_match, from pubs.commands.list_cmd import (AuthorFilter,
_check_field_match, FieldFilter,
_check_query_block, _query_block_to_filter,
filter_paper, get_paper_filter,
InvalidQuery) InvalidQuery)
from pubs.paper import Paper from pubs.paper import Paper
@ -16,28 +16,30 @@ page_paper = Paper.from_bibentry(fixtures.page_bibentry)
turing_paper = Paper.from_bibentry(fixtures.turing_bibentry, turing_paper = Paper.from_bibentry(fixtures.turing_bibentry,
metadata=fixtures.turing_metadata) metadata=fixtures.turing_metadata)
class TestAuthorFilter(unittest.TestCase): class TestAuthorFilter(unittest.TestCase):
def test_fails_if_no_author(self): def test_fails_if_no_author(self):
no_doe = doe_paper.deepcopy() no_doe = doe_paper.deepcopy()
no_doe.bibentry['author'] = [] no_doe.bibentry['author'] = []
self.assertTrue(not _check_author_match(no_doe, 'whatever')) self.assertTrue(not AuthorFilter('whatever')(no_doe))
def test_match_case(self): def test_match_case(self):
self.assertTrue(_check_author_match(doe_paper, 'doe')) self.assertTrue(AuthorFilter('doe')(doe_paper))
self.assertTrue(_check_author_match(doe_paper, 'doe', self.assertTrue(AuthorFilter('doe', case_sensitive=False)(doe_paper))
case_sensitive=False)) self.assertTrue(AuthorFilter('Doe')(doe_paper))
def test_do_not_match_case(self): def test_do_not_match_case(self):
self.assertFalse(_check_author_match(doe_paper, 'dOe')) self.assertFalse(AuthorFilter('dOe')(doe_paper))
self.assertFalse(_check_author_match(doe_paper, 'doe', self.assertFalse(AuthorFilter('dOe', case_sensitive=True)(doe_paper))
case_sensitive=True)) self.assertFalse(AuthorFilter('doe', case_sensitive=True)(doe_paper))
self.assertTrue(AuthorFilter('dOe', case_sensitive=False)(doe_paper))
def test_match_not_first_author(self): def test_match_not_first_author(self):
self.assertTrue(_check_author_match(page_paper, 'motwani')) self.assertTrue(AuthorFilter('motwani')(page_paper))
def test_do_not_match_first_name(self): def test_do_not_match_first_name(self):
self.assertTrue(not _check_author_match(page_paper, 'larry')) self.assertTrue(not AuthorFilter('larry')(page_paper))
class TestCheckTag(unittest.TestCase): class TestCheckTag(unittest.TestCase):
@ -47,55 +49,52 @@ class TestCheckTag(unittest.TestCase):
class TestCheckField(unittest.TestCase): class TestCheckField(unittest.TestCase):
def test_match_case(self): def test_match_case(self):
self.assertTrue(_check_field_match(doe_paper, 'title', 'nice')) self.assertTrue(FieldFilter('title', 'nice')(doe_paper))
self.assertTrue(_check_field_match(doe_paper, 'title', 'nice', self.assertTrue(
case_sensitive=False)) FieldFilter('title', 'nice', case_sensitive=False)(doe_paper))
self.assertTrue(_check_field_match(doe_paper, 'year', '2013')) self.assertTrue(FieldFilter('year', '2013')(doe_paper))
def test_do_not_match_case(self): def test_do_not_match_case(self):
self.assertTrue(_check_field_match(doe_paper, 'title', self.assertTrue(
'Title', case_sensitive=True)) FieldFilter('title', 'Title', case_sensitive=True)(doe_paper))
self.assertFalse(_check_field_match(doe_paper, 'title', 'nice', self.assertFalse(
case_sensitive=True)) FieldFilter('title', 'nice', case_sensitive=True)(doe_paper))
class TestCheckQueryBlock(unittest.TestCase): class TestCheckQueryBlock(unittest.TestCase):
def test_raise_invalid_if_no_value(self): def test_raise_invalid_if_no_value(self):
with self.assertRaises(InvalidQuery): with self.assertRaises(InvalidQuery):
_check_query_block(doe_paper, 'title') _query_block_to_filter('title')
def test_raise_invalid_if_too_much(self): def test_raise_invalid_if_too_much(self):
with self.assertRaises(InvalidQuery): with self.assertRaises(InvalidQuery):
_check_query_block(doe_paper, 'whatever:value:too_much') _query_block_to_filter('whatever:value:too_much')
class TestFilterPaper(unittest.TestCase): class TestFilterPaper(unittest.TestCase):
def test_case(self): def test_case(self):
self.assertTrue (filter_paper(doe_paper, ['title:nice'])) self.assertTrue(get_paper_filter(['title:nice'])(doe_paper))
self.assertTrue (filter_paper(doe_paper, ['title:Nice'])) self.assertTrue(get_paper_filter(['title:Nice'])(doe_paper))
self.assertFalse(filter_paper(doe_paper, ['title:nIce'])) self.assertFalse(get_paper_filter(['title:nIce'])(doe_paper))
def test_fields(self): def test_fields(self):
self.assertTrue (filter_paper(doe_paper, ['year:2013'])) self.assertTrue(get_paper_filter(['year:2013'])(doe_paper))
self.assertFalse(filter_paper(doe_paper, ['year:2014'])) self.assertFalse(get_paper_filter(['year:2014'])(doe_paper))
self.assertTrue (filter_paper(doe_paper, ['author:doe'])) self.assertTrue(get_paper_filter(['author:doe'])(doe_paper))
self.assertTrue (filter_paper(doe_paper, ['author:Doe'])) self.assertTrue(get_paper_filter(['author:Doe'])(doe_paper))
def test_tags(self): def test_tags(self):
self.assertTrue (filter_paper(turing_paper, ['tag:computer'])) self.assertTrue(get_paper_filter(['tag:computer'])(turing_paper))
self.assertFalse(filter_paper(turing_paper, ['tag:Ai'])) self.assertFalse(get_paper_filter(['tag:Ai'])(turing_paper))
self.assertTrue (filter_paper(turing_paper, ['tag:AI'])) self.assertTrue(get_paper_filter(['tag:AI'])(turing_paper))
self.assertTrue (filter_paper(turing_paper, ['tag:ai'])) self.assertTrue(get_paper_filter(['tag:ai'])(turing_paper))
def test_multiple(self): def test_multiple(self):
self.assertTrue (filter_paper(doe_paper, self.assertTrue(get_paper_filter(['author:doe', 'year:2013'])(doe_paper))
['author:doe', 'year:2013'])) self.assertFalse(get_paper_filter(['author:doe', 'year:2014'])(doe_paper))
self.assertFalse(filter_paper(doe_paper, self.assertFalse(get_paper_filter(['author:doee', 'year:2014'])(doe_paper))
['author:doe', 'year:2014']))
self.assertFalse(filter_paper(doe_paper,
['author:doee', 'year:2014']))
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save