diff --git a/prophet/app.py b/prophet/app.py index 18d45e7..0bbcf64 100644 --- a/prophet/app.py +++ b/prophet/app.py @@ -1,7 +1,5 @@ import json -import pickle from datetime import datetime -from pathlib import Path import feedparser from fastapi import FastAPI @@ -10,21 +8,21 @@ from fastapi.responses import HTMLResponse from fastapi_utils.tasks import repeat_every from prophet.domain.improvement import Improvement +from prophet.domain.improvement_repo import IImprovementRepo from prophet.domain.original import Original +from prophet.infra.improvement_pickle_repo import ImprovementPickleRepo from prophet.llm import LLMClient BEE_FEED = "https://babylonbee.com/feed" BEE_FEED_TEST = "test/resources/feed_short.atom" # NOTE: Switch out when done testing -PICKLE_DIR = "/tmp/pollenprophet" - REFRESH_PERIOD = 3600 # between fetching articles, in seconds llm: LLMClient = LLMClient() +repo: IImprovementRepo = ImprovementPickleRepo() def grab_latest_originals() -> list[Original]: - # TODO: Implement skipping any we already have feed: feedparser.FeedParserDict = feedparser.parse(BEE_FEED) # noqa: F841 results: list[Original] = [] for entry in feed.entries: @@ -38,39 +36,11 @@ def grab_latest_originals() -> list[Original]: return results -def save_new_improvements(improvements: list[Improvement]) -> None: - save_dir = Path(PICKLE_DIR) - save_dir.mkdir(parents=True, exist_ok=True) - for imp in improvements: - fname = save_dir / f"{int(imp.original.date.timestamp())}_{imp.id}" - try: - with open(fname, "wb") as f: - pickle.dump(imp, f) - print(f"Saved {fname}") - except Exception as e: - print(f"Error saving file {fname}: {e}") - - -def load_existing_improvements() -> list[Improvement]: - improvements: list[Improvement] = [] - for fname in Path(PICKLE_DIR).iterdir(): - if not fname.is_file(): - continue - - try: - with open(fname, "rb") as f: - obj: Improvement = pickle.load(f) - improvements.append(obj) - except FileNotFoundError as e: - print(f"Error loading file {fname}: 
{e}") - return improvements - - def keep_only_new_originals( additional: list[Original], existing: list[Original] | None = None ): if not existing: - existing = [e.original for e in load_existing_improvements()] + existing = [e.original for e in repo.get_all()] existing_hashes = set([e.id for e in existing]) @@ -128,7 +98,7 @@ def improve_summary(original_title: str, new_title: str, original_summary: str): def refresh_articles(): adding = keep_only_new_originals(grab_latest_originals()) improved = improve_originals(adding) - save_new_improvements(improved) + repo.add_all(improved) print(f"Updated articles. Added {len(improved)} new ones.") @@ -141,7 +111,7 @@ async def fetch_update(): ## HTML (& hyperdata) responses @app.get("/improvements", response_class=HTMLResponse) def list_improvements(): - improved = load_existing_improvements() + improved = repo.get_all() return ( """ """ + "\n".join( @@ -160,7 +130,7 @@ def list_improvements(): @app.get("/originals", response_class=HTMLResponse) def list_originals(): - improved = load_existing_improvements() + improved = repo.get_all() return ( """ """ + "\n".join( @@ -228,7 +198,7 @@ if __name__ == "__main__": # save_new_improvements(improved) # migrate to newer version - improved = load_existing_improvements() + improved = repo.get_all() for imp in improved: imp.original.__post_init__() print(f"Old Title: {imp.original.title}") @@ -239,4 +209,4 @@ if __name__ == "__main__": print(f"Summary: {imp.summary}") print("-" * 50) - save_new_improvements(improved) + repo.add_all(improved) diff --git a/prophet/domain/improvement_repo.py b/prophet/domain/improvement_repo.py index 57e7083..7081e34 100644 --- a/prophet/domain/improvement_repo.py +++ b/prophet/domain/improvement_repo.py @@ -1,13 +1,19 @@ - from typing import Protocol from prophet.domain.improvement import Improvement +class ImprovementNotFoundError(Exception): + pass + + class IImprovementRepo(Protocol): def add(self, improvement: Improvement) -> None: raise 
NotImplementedError
 
+    def add_all(self, improvements: list[Improvement]) -> None:
+        raise NotImplementedError
+
     def get(self, id: str) -> Improvement:
         raise NotImplementedError
 
diff --git a/prophet/infra/improvement_pickle_repo.py b/prophet/infra/improvement_pickle_repo.py
new file mode 100644
index 0000000..7c59ec8
--- /dev/null
+++ b/prophet/infra/improvement_pickle_repo.py
@@ -0,0 +1,88 @@
+import pickle
+from pathlib import Path
+from typing import override
+
+from prophet.domain.improvement import Improvement
+from prophet.domain.improvement_repo import IImprovementRepo, ImprovementNotFoundError
+
+
+class ImprovementPickleRepo(IImprovementRepo):
+    """Improvement repository keeping one pickle file per improvement.
+
+    Files live directly in ``pickle_dir`` and are named after the improvement id.
+    """
+
+    pickle_dir: Path
+
+    def __init__(self, pickle_dir: str | Path = "/tmp/pollenprophet") -> None:
+        """Remember ``pickle_dir`` and create it (with parents) if missing."""
+        self.pickle_dir = Path(pickle_dir)
+        self.pickle_dir.mkdir(parents=True, exist_ok=True)
+
+    @override
+    def add(self, improvement: Improvement) -> None:
+        """Pickle ``improvement`` to ``pickle_dir / improvement.id``.
+
+        An existing file with the same id is overwritten. Saving is best
+        effort: a failure is printed and swallowed so that ``add_all`` keeps
+        going with the remaining items.
+        """
+        fname = self.pickle_dir / improvement.id
+        try:
+            with open(fname, "wb") as f:
+                pickle.dump(improvement, f)
+        except Exception as e:
+            # Mode "wb" never raises FileExistsError; real failures are
+            # OSError subclasses or pickling errors, so catch broadly here.
+            print(f"Error saving file {fname}: {e}")
+        else:
+            print(f"Saved {fname}")
+
+    @override
+    def add_all(self, improvements: list[Improvement]) -> None:
+        """Save every improvement in ``improvements`` via ``add``."""
+        for imp in improvements:
+            self.add(imp)
+
+    @override
+    def get(self, id: str) -> Improvement:
+        """Load the improvement stored under ``id``.
+
+        Raises:
+            ImprovementNotFoundError: if no file named ``id`` exists.
+        """
+        try:
+            with open(self.pickle_dir / id, "rb") as f:
+                # NOTE(review): unpickling runs arbitrary code; only point
+                # pickle_dir at a directory this application controls.
+                improvement: Improvement = pickle.load(f)
+        except FileNotFoundError as e:
+            raise ImprovementNotFoundError(id) from e
+
+        return improvement
+
+    @override
+    def get_all(self) -> list[Improvement]:
+        """Load every readable improvement in ``pickle_dir``.
+
+        Sub-directories are skipped; unreadable entries are reported and
+        skipped instead of aborting the whole listing.
+        """
+        improvements: list[Improvement] = []
+        for fname in self.pickle_dir.iterdir():
+            if not fname.is_file():
+                continue
+            try:
+                improvements.append(self.get(fname.name))
+            except ImprovementNotFoundError:
+                # The file vanished between listing and opening it.
+                print(f"File {fname.absolute()} disappeared while loading.")
+            except (
+                pickle.UnpicklingError,
+                AttributeError,
+                EOFError,
+                ImportError,
+                IndexError,
+            ) as e:
+                print(f"File {fname.absolute()} is not a valid Improvement: {e}")
+        return improvements