diff --git a/pubs/commands/list_cmd.py b/pubs/commands/list_cmd.py index f3fa7b9..de44549 100644 --- a/pubs/commands/list_cmd.py +++ b/pubs/commands/list_cmd.py @@ -27,7 +27,7 @@ def parser(subparsers, conf): help='list only pubs without attached documents.') parser.add_argument('query', nargs='*', help='Paper query ("author:Einstein", "title:learning",' - '"year:2000" or "tags:math")') + '"year:2000", "year:2000-2010, or "tags:math")') return parser @@ -108,6 +108,44 @@ class TagFilter(QueryFilter): return any([self.query in self._lower(t) for t in paper.tags]) +class YearFilter(QueryFilter): + """Note: a query like `year:` or `year:-` would match any paper + whose year field is set and can be converted to an int. + """ + + def __init__(self, query, case_sensitive=None): + split = query.split('-') + self.start = self._str_to_year(split[0]) + if len(split) == 1: + self.end = self.start + elif len(split) == 2: + self.end = self._str_to_year(split[1]) + if (len(split) > 2 or ( + self.start is not None and + self.end is not None and + self.start > self.end)): + raise ValueError('Invalid year range "{}"'.format(query)) + + def __call__(self, paper): + """Only checks within last names.""" + if 'year' not in paper.bibdata: + return False + else: + try: + year = int(paper.bibdata['year']) + return ((self.start is None or year >= self.start) and + (self.end is None or year <= self.end)) + except ValueError: + return False + + @staticmethod + def _str_to_year(s): + try: + return int(s) if s else None + except ValueError: + raise ValueError('Invalid year "{}"'.format(s)) + + def _get_field_value(query_block): split_block = query_block.split(':') if len(split_block) != 2: diff --git a/tests/test_queries.py b/tests/test_queries.py index 1e52af0..0edf6a3 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -3,6 +3,7 @@ import unittest import dotdot from pubs.commands.list_cmd import (AuthorFilter, FieldFilter, + YearFilter, _query_block_to_filter, get_paper_filter, InvalidQuery) @@ -46,6 +47,32 @@ class TestCheckTag(unittest.TestCase): pass +class TestCheckYear(unittest.TestCase): + + def test_single_year(self): + self.assertTrue(YearFilter('2013')(doe_paper)) + self.assertFalse(YearFilter('2014')(doe_paper)) + + def test_before_year(self): + self.assertTrue(YearFilter('-2013')(doe_paper)) + self.assertTrue(YearFilter('-2014')(doe_paper)) + self.assertFalse(YearFilter('-2012')(doe_paper)) + + def test_after_year(self): + self.assertTrue(YearFilter('2013-')(doe_paper)) + self.assertTrue(YearFilter('2012-')(doe_paper)) + self.assertFalse(YearFilter('2014-')(doe_paper)) + + def test_year_range(self): + self.assertTrue(YearFilter('')(doe_paper)) + self.assertTrue(YearFilter('-')(doe_paper)) + self.assertTrue(YearFilter('2013-2013')(doe_paper)) + self.assertTrue(YearFilter('2012-2014')(doe_paper)) + self.assertFalse(YearFilter('2014-2015')(doe_paper)) + with self.assertRaises(ValueError): + YearFilter('2015-2014')(doe_paper) + + class TestCheckField(unittest.TestCase): def test_match_case(self):