feat: Loop through all chosen extractors

This commit is contained in:
Marty Oehme 2024-01-23 09:10:42 +01:00
parent f477deea7c
commit 629932a5e8
Signed by: Marty
GPG key ID: EDBF2ED917B2EF6A
2 changed files with 37 additions and 27 deletions

View file

@ -7,7 +7,7 @@ import papis.notes
import papis.strings import papis.strings
from papis.document import Document from papis.document import Document
from papis_extract import exporter, extractor from papis_extract import exporter, extraction
from papis_extract.annotation import Annotation from papis_extract.annotation import Annotation
from papis_extract.formatter import Formatter, formatters from papis_extract.formatter import Formatter, formatters
@ -55,10 +55,10 @@ papis.config.register_default_settings(DEFAULT_OPTIONS)
"-e", "-e",
"extractors", "extractors",
type=click.Choice( type=click.Choice(
list(extractor.extractors.keys()), list(extraction.extractors.keys()),
case_sensitive=False, case_sensitive=False,
), ),
default=list(extractor.extractors.keys()), default=list(extraction.extractors.keys()),
multiple=True, multiple=True,
help="Choose an extractor to apply to the selected documents.", help="Choose an extractor to apply to the selected documents.",
) )
@ -76,7 +76,7 @@ def main(
doc_folder: str, doc_folder: str,
manual: bool, manual: bool,
write: bool, write: bool,
extractors: str, extractors: list[str],
template: str, template: str,
git: bool, git: bool,
force: bool, force: bool,
@ -99,35 +99,46 @@ def main(
logger.warning(papis.strings.no_documents_retrieved_message) logger.warning(papis.strings.no_documents_retrieved_message)
return return
print(extractors)
formatter = formatters.get(template) formatter = formatters.get(template)
run(documents, edit=manual, write=write, git=git, formatter=formatter, force=force) run(
documents,
edit=manual,
write=write,
git=git,
formatter=formatter,
extractors=[extraction.extractors.get(e) for e in extractors],
force=force,
)
def run( def run(
documents: list[Document], documents: list[Document],
formatter: Formatter | None, formatter: Formatter | None,
extractors: list[extraction.Extractor | None],
edit: bool = False, edit: bool = False,
write: bool = False, write: bool = False,
git: bool = False, git: bool = False,
force: bool = False, force: bool = False,
) -> None: ) -> None:
for doc in documents: for doc in documents:
annotations: list[Annotation] = extractor.start(doc) for ext in extractors:
if not ext:
continue
if write: annotations: list[Annotation] = extraction.start(ext, doc)
exporter.to_notes( if write:
formatter=formatter or formatters["markdown-atx"], exporter.to_notes(
document=doc, formatter=formatter or formatters["markdown-atx"],
annotations=annotations, document=doc,
edit=edit, annotations=annotations,
git=git, edit=edit,
force=force, git=git,
) force=force,
else: )
exporter.to_stdout( else:
formatter=formatter or formatters["markdown"], exporter.to_stdout(
document=doc, formatter=formatter or formatters["markdown"],
annotations=annotations, document=doc,
) annotations=annotations,
)

View file

@ -23,6 +23,7 @@ class Extractor(Protocol):
def start( def start(
extractor: Extractor,
document: Document, document: Document,
) -> list[Annotation]: ) -> list[Annotation]:
"""Extract all annotations from passed documents. """Extract all annotations from passed documents.
@ -30,19 +31,17 @@ def start(
Returns all annotations contained in the papis Returns all annotations contained in the papis
documents passed in. documents passed in.
""" """
pdf_extractor: Extractor = PdfExtractor()
annotations: list[Annotation] = [] annotations: list[Annotation] = []
file_available: bool = False file_available: bool = False
for file in document.get_files(): for file in document.get_files():
fname = Path(file) fname = Path(file)
if not pdf_extractor.can_process(fname): if not extractor.can_process(fname):
continue continue
file_available = True file_available = True
try: try:
annotations.extend(pdf_extractor.run(fname)) annotations.extend(extractor.run(fname))
except fitz.FileDataError as e: except fitz.FileDataError as e:
print(f"File structure errors for {file}.\n{e}") print(f"File structure errors for {file}.\n{e}")