Allow running queries for papers to be extracted from

This commit is contained in:
Marty Oehme 2022-12-24 14:23:07 +01:00
parent e201f6cf5f
commit 488dd0eb41
Signed by: Marty
GPG Key ID: 73BA40D5AFAF49C9
1 changed files with 55 additions and 20 deletions

View File

@ -11,6 +11,8 @@ from pubs.events import DocAddEvent, NoteEvent
from pubs import repo from pubs import repo
from pubs.utils import resolve_citekey_list from pubs.utils import resolve_citekey_list
from pubs.content import check_file, read_text_file, write_file from pubs.content import check_file, read_text_file, write_file
from pubs.query import get_paper_filter
class ExtractPlugin(PapersPlugin): class ExtractPlugin(PapersPlugin):
"""Extract annotations from any pdf document. """Extract annotations from any pdf document.
@ -35,10 +37,6 @@ class ExtractPlugin(PapersPlugin):
self.pubsdir = os.path.expanduser(conf["main"]["pubsdir"]) self.pubsdir = os.path.expanduser(conf["main"]["pubsdir"])
self.broker = self.repository.databroker self.broker = self.repository.databroker
# TODO implement custom annotation formatting, akin to main config citekey format
# e.g. `> [{page}] {annotation}`
# or `:: {annotation} :: {page} ::`
# and so on
self.on_import = conf["plugins"].get("extract", {}).get("on_import", False) self.on_import = conf["plugins"].get("extract", {}).get("on_import", False)
self.minimum_similarity = float( self.minimum_similarity = float(
conf["plugins"].get("extract", {}).get("minimum_similarity", 0.75) conf["plugins"].get("extract", {}).get("minimum_similarity", 0.75)
@ -56,11 +54,6 @@ class ExtractPlugin(PapersPlugin):
"""Allow the usage of the pubs extract subcommand""" """Allow the usage of the pubs extract subcommand"""
# TODO option for ignoring missing documents or erroring. # TODO option for ignoring missing documents or erroring.
extract_parser = subparsers.add_parser(self.name, help=self.description) extract_parser = subparsers.add_parser(self.name, help=self.description)
extract_parser.add_argument(
"citekeys",
nargs=argparse.REMAINDER,
help="citekey(s) of the documents to extract from",
)
extract_parser.add_argument( extract_parser.add_argument(
"-w", "-w",
"--write", "--write",
@ -75,29 +68,58 @@ class ExtractPlugin(PapersPlugin):
action="store_true", action="store_true",
default=False, default=False,
) )
extract_parser.add_argument(
"-q",
"--query",
help="Query library instead of providing individual citekeys. For query help see pubs list command.",
action="store_true",
default=None,
dest="is_query",
)
extract_parser.add_argument(
"-i",
"--ignore-case",
action="store_false",
default=None,
dest="case_sensitive",
help="When using query mode, perform case insensitive search.",
)
extract_parser.add_argument(
"-I",
"--force-case",
action="store_true",
dest="case_sensitive",
help="When using query mode, perform case sensitive search.",
)
extract_parser.add_argument(
"--strict",
action="store_true",
default=False,
help="Force strict unicode comparison of query.",
)
extract_parser.add_argument(
"query",
nargs=argparse.REMAINDER,
help="Citekey(s)/query for the documents to extract from.",
)
extract_parser.set_defaults(func=self.command) extract_parser.set_defaults(func=self.command)
def command(self, conf, args): def command(self, conf, args):
"""Run the annotation extraction command.""" """Run the annotation extraction command."""
citekeys = resolve_citekey_list( papers = self._gather_papers(conf, args)
self.repository, conf, args.citekeys, ui=self.ui, exit_on_fail=True all_annotations = self.extract(papers)
)
if not citekeys:
return
all_annotations = self.extract(citekeys)
if args.write: if args.write:
self._to_notes(all_annotations, self.note_extension, args.edit) self._to_notes(all_annotations, self.note_extension, args.edit)
else: else:
self._to_stdout(all_annotations) self._to_stdout(all_annotations)
self.repository.close() self.repository.close()
def extract(self, citekeys): def extract(self, papers):
"""Extracts annotations from citekeys. """Extracts annotations from citekeys.
Returns all annotations belonging to the papers that Returns all annotations belonging to the papers that
are described by the citekeys passed in. are described by the citekeys passed in.
""" """
papers = self._gather_papers(citekeys)
papers_annotated = [] papers_annotated = []
for paper in papers: for paper in papers:
file = self._get_file(paper) file = self._get_file(paper)
@ -107,15 +129,28 @@ class ExtractPlugin(PapersPlugin):
self.ui.error(f"Document {file} is broken: {e}") self.ui.error(f"Document {file} is broken: {e}")
return papers_annotated return papers_annotated
def _gather_papers(self, citekeys): def _gather_papers(self, conf, args):
"""Get all papers for citekeys. """Get all papers for citekeys.
Returns all Paper objects described by the citekeys Returns all Paper objects described by the citekeys
passed in. passed in.
""" """
papers = [] papers = []
for key in citekeys: if not args.is_query:
papers.append(self.repository.pull_paper(key)) citekeys = resolve_citekey_list(
self.repository, conf, args.query, ui=self.ui, exit_on_fail=True
)
for key in citekeys:
papers.append(self.repository.pull_paper(key))
else:
papers = filter(
get_paper_filter(
args.query,
case_sensitive=args.case_sensitive,
strict=args.strict,
),
self.repository.all_papers(),
)
return papers return papers
def _get_file(self, paper): def _get_file(self, paper):