diff --git a/extract/extract.py b/extract/extract.py index 88b6b8d..f25bd73 100644 --- a/extract/extract.py +++ b/extract/extract.py @@ -8,10 +8,12 @@ import Levenshtein from pubs.plugins import PapersPlugin from pubs.events import DocAddEvent, NoteEvent -from pubs import repo +from pubs import repo, pretty from pubs.utils import resolve_citekey_list from pubs.content import check_file, read_text_file, write_file +from pubs.query import get_paper_filter +CONFIRMATION_PAPER_THRESHOLD=5 class ExtractPlugin(PapersPlugin): """Extract annotations from any pdf document. @@ -36,32 +38,18 @@ class ExtractPlugin(PapersPlugin): self.pubsdir = os.path.expanduser(conf["main"]["pubsdir"]) self.broker = self.repository.databroker - # TODO implement custom annotation formatting, akin to main config citekey format - # e.g. `> [{page}] {annotation}` - # or `:: {annotation} :: {page} ::` - # and so on - self.on_import = conf["plugins"].get("extract", {}).get("on_import", False) - self.minimum_similarity = float( - conf["plugins"].get("extract", {}).get("minimum_similarity", 0.75) - ) - self.formatting = ( - conf["plugins"] - .get("extract", {}) - .get( - "formatting", - "{newline}{quote_begin}> {quote} {quote_end}[{page}]{note_begin}{newline}Note: {note}{note_end}", - ) + settings = conf["plugins"].get("extract", {}) + self.on_import = settings.get("on_import", False) + self.minimum_similarity = float(settings.get("minimum_similarity", 0.75)) + self.formatting = settings.get( + "formatting", + "{newline}{quote_begin}> {quote} {quote_end}[{page}]{note_begin}{newline}Note: {note}{note_end}", ) - def update_parser(self, subparsers, conf): + def update_parser(self, subparsers, _): """Allow the usage of the pubs extract subcommand""" # TODO option for ignoring missing documents or erroring. extract_parser = subparsers.add_parser(self.name, help=self.description) - extract_parser.add_argument( - "citekeys", - nargs=argparse.REMAINDER, - help="citekey(s) of the documents to extract from", - ) extract_parser.add_argument( "-w", "--write", @@ -76,29 +64,58 @@ class ExtractPlugin(PapersPlugin): action="store_true", default=False, ) + extract_parser.add_argument( + "-q", + "--query", + help="Query library instead of providing individual citekeys. For query help see pubs list command.", + action="store_true", + default=None, + dest="is_query", + ) + extract_parser.add_argument( + "-i", + "--ignore-case", + action="store_false", + default=None, + dest="case_sensitive", + help="When using query mode, perform case insensitive search.", + ) + extract_parser.add_argument( + "-I", + "--force-case", + action="store_true", + dest="case_sensitive", + help="When using query mode, perform case sensitive search.", + ) + extract_parser.add_argument( + "--strict", + action="store_true", + default=False, + help="Force strict unicode comparison of query.", + ) + extract_parser.add_argument( + "query", + nargs=argparse.REMAINDER, + help="Citekey(s)/query for the documents to extract from.", + ) extract_parser.set_defaults(func=self.command) def command(self, conf, args): """Run the annotation extraction command.""" - citekeys = resolve_citekey_list( - self.repository, conf, args.citekeys, ui=self.ui, exit_on_fail=True - ) - if not citekeys: - return - all_annotations = self.extract(citekeys) + papers = self._gather_papers(conf, args) + all_annotations = self.extract(papers) if args.write: self._to_notes(all_annotations, self.note_extension, args.edit) else: self._to_stdout(all_annotations) self.repository.close() - def extract(self, citekeys): + def extract(self, papers): """Extracts annotations from citekeys. Returns all annotations belonging to the papers that are described by the citekeys passed in. """ - papers = self._gather_papers(citekeys) papers_annotated = [] for paper in papers: file = self._get_file(paper) @@ -108,15 +125,44 @@ class ExtractPlugin(PapersPlugin): self.ui.error(f"Document {file} is broken: {e}") return papers_annotated - def _gather_papers(self, citekeys): + def _gather_papers(self, conf, args): """Get all papers for citekeys. Returns all Paper objects described by the citekeys passed in. """ papers = [] - for key in citekeys: - papers.append(self.repository.pull_paper(key)) + if not args.is_query: + keys = resolve_citekey_list( + self.repository, conf, args.query, ui=self.ui, exit_on_fail=True + ) + if not keys: + return [] + for key in keys: + papers.append(self.repository.pull_paper(key)) + else: + papers = list( + filter( + get_paper_filter( + args.query, + case_sensitive=args.case_sensitive, + strict=args.strict, + ), + self.repository.all_papers(), + ) + ) + if len(papers) > CONFIRMATION_PAPER_THRESHOLD: + self.ui.message( + "\n".join( + pretty.paper_oneliner( + p, citekey_only=False, max_authors=conf["main"]["max_authors"] + ) + for p in papers + ) + ) + self.ui.input_yn( + question=f"Extract annotations for these papers?", default="y" + ) return papers def _get_file(self, paper): @@ -207,9 +253,9 @@ class ExtractPlugin(PapersPlugin): paper = contents[0] annotations = contents[1] if annotations: - output += f"------ {paper.citekey} ------\n\n" + output += f"------ {paper.citekey} ------\n" for annot in annotations: - output += f"{annot}\n\n" + output += f"{annot}\n" output += "\n" print(output)