diff --git a/prophet/app.py b/prophet/app.py
index 18d45e7..0bbcf64 100644
--- a/prophet/app.py
+++ b/prophet/app.py
@@ -1,7 +1,5 @@
import json
-import pickle
from datetime import datetime
-from pathlib import Path
import feedparser
from fastapi import FastAPI
@@ -10,21 +8,21 @@ from fastapi.responses import HTMLResponse
from fastapi_utils.tasks import repeat_every
from prophet.domain.improvement import Improvement
+from prophet.domain.improvement_repo import IImprovementRepo
from prophet.domain.original import Original
+from prophet.infra.improvement_pickle_repo import ImprovementPickleRepo
from prophet.llm import LLMClient
BEE_FEED = "https://babylonbee.com/feed"
BEE_FEED_TEST = "test/resources/feed_short.atom" # NOTE: Switch out when done testing
-PICKLE_DIR = "/tmp/pollenprophet"
-
REFRESH_PERIOD = 3600 # between fetching articles, in seconds
llm: LLMClient = LLMClient()
+repo: IImprovementRepo = ImprovementPickleRepo()
def grab_latest_originals() -> list[Original]:
- # TODO: Implement skipping any we already have
feed: feedparser.FeedParserDict = feedparser.parse(BEE_FEED) # noqa: F841
results: list[Original] = []
for entry in feed.entries:
@@ -38,39 +36,11 @@ def grab_latest_originals() -> list[Original]:
return results
-def save_new_improvements(improvements: list[Improvement]) -> None:
- save_dir = Path(PICKLE_DIR)
- save_dir.mkdir(parents=True, exist_ok=True)
- for imp in improvements:
- fname = save_dir / f"{int(imp.original.date.timestamp())}_{imp.id}"
- try:
- with open(fname, "wb") as f:
- pickle.dump(imp, f)
- print(f"Saved {fname}")
- except Exception as e:
- print(f"Error saving file {fname}: {e}")
-
-
-def load_existing_improvements() -> list[Improvement]:
- improvements: list[Improvement] = []
- for fname in Path(PICKLE_DIR).iterdir():
- if not fname.is_file():
- continue
-
- try:
- with open(fname, "rb") as f:
- obj: Improvement = pickle.load(f)
- improvements.append(obj)
- except FileNotFoundError as e:
- print(f"Error loading file {fname}: {e}")
- return improvements
-
-
def keep_only_new_originals(
additional: list[Original], existing: list[Original] | None = None
):
if not existing:
- existing = [e.original for e in load_existing_improvements()]
+ existing = [e.original for e in repo.get_all()]
existing_hashes = set([e.id for e in existing])
@@ -128,7 +98,7 @@ def improve_summary(original_title: str, new_title: str, original_summary: str):
def refresh_articles():
adding = keep_only_new_originals(grab_latest_originals())
improved = improve_originals(adding)
- save_new_improvements(improved)
+ repo.add_all(improved)
print(f"Updated articles. Added {len(improved)} new ones.")
@@ -141,7 +111,7 @@ async def fetch_update():
## HTML (& hyperdata) responses
@app.get("/improvements", response_class=HTMLResponse)
def list_improvements():
- improved = load_existing_improvements()
+ improved = repo.get_all()
return (
""" """
+ "\n".join(
@@ -160,7 +130,7 @@ def list_improvements():
@app.get("/originals", response_class=HTMLResponse)
def list_originals():
- improved = load_existing_improvements()
+ improved = repo.get_all()
return (
""" """
+ "\n".join(
@@ -228,7 +198,7 @@ if __name__ == "__main__":
# save_new_improvements(improved)
# migrate to newer version
- improved = load_existing_improvements()
+ improved = repo.get_all()
for imp in improved:
imp.original.__post_init__()
print(f"Old Title: {imp.original.title}")
@@ -239,4 +209,4 @@ if __name__ == "__main__":
print(f"Summary: {imp.summary}")
print("-" * 50)
- save_new_improvements(improved)
+ repo.add_all(improved)
diff --git a/prophet/domain/improvement_repo.py b/prophet/domain/improvement_repo.py
index 57e7083..7081e34 100644
--- a/prophet/domain/improvement_repo.py
+++ b/prophet/domain/improvement_repo.py
@@ -1,13 +1,19 @@
-
from typing import Protocol
from prophet.domain.improvement import Improvement
+class ImprovementNotFoundError(Exception):
+    """Raised by a repository when no improvement is stored under the requested id."""
+
+
class IImprovementRepo(Protocol):
def add(self, improvement: Improvement) -> None:
raise NotImplementedError
+ def add_all(self, improvements: list[Improvement]) -> None:
+ raise NotImplementedError
+
def get(self, id: str) -> Improvement:
raise NotImplementedError
diff --git a/prophet/infra/improvement_pickle_repo.py b/prophet/infra/improvement_pickle_repo.py
new file mode 100644
index 0000000..7c59ec8
--- /dev/null
+++ b/prophet/infra/improvement_pickle_repo.py
@@ -0,0 +1,88 @@
+import pickle
+from pathlib import Path
+from typing import override
+
+from prophet.domain.improvement import Improvement
+from prophet.domain.improvement_repo import IImprovementRepo, ImprovementNotFoundError
+
+
+class ImprovementPickleRepo(IImprovementRepo):
+    """Repository that stores each Improvement as one pickle file named by its id.
+
+    NOTE(security): ``pickle.load`` can execute arbitrary code, so ``pickle_dir``
+    must only contain files written by this application.
+    """
+
+    # Directory that holds one pickle file per stored improvement.
+    pickle_dir: Path
+
+    def __init__(self, pickle_dir: str | Path = "/tmp/pollenprophet") -> None:
+        """Set the storage directory, creating it (and parents) when missing."""
+        self.pickle_dir = Path(pickle_dir)
+        self.pickle_dir.mkdir(parents=True, exist_ok=True)
+
+    @override
+    def add(self, improvement: Improvement) -> None:
+        """Pickle ``improvement`` to ``pickle_dir/<id>``, overwriting any older file.
+
+        Saving is best-effort: failures are printed rather than raised, so one
+        bad item does not stop ``add_all`` from saving the rest.
+        """
+        fname = self.pickle_dir / improvement.id
+        try:
+            with open(fname, "wb") as f:
+                pickle.dump(improvement, f)
+        # Mode "wb" truncates, so FileExistsError can never be raised here; the
+        # realistic failures are OS errors and objects that cannot be pickled.
+        except (OSError, pickle.PicklingError) as e:
+            print(f"Error saving file {fname}: {e}")
+        else:
+            print(f"Saved {fname}")
+
+    @override
+    def add_all(self, improvements: list[Improvement]) -> None:
+        """Save every improvement in ``improvements`` via ``add``."""
+        for imp in improvements:
+            self.add(imp)
+
+    @override
+    def get(self, id: str) -> Improvement:
+        """Load the improvement stored under ``id``.
+
+        Raises:
+            ImprovementNotFoundError: if no file named ``id`` exists.
+        """
+        try:
+            with open(self.pickle_dir / id, "rb") as f:
+                improvement: Improvement = pickle.load(f)
+        except FileNotFoundError as e:
+            # Keep the OS error as the cause so the missing path stays visible.
+            raise ImprovementNotFoundError(id) from e
+
+        return improvement
+
+    @override
+    def get_all(self) -> list[Improvement]:
+        """Load every readable improvement in ``pickle_dir``.
+
+        Directories and files that cannot be unpickled are skipped.
+        """
+        # Exceptions the pickle docs list for missing or corrupt input.
+        unreadable = (
+            ImprovementNotFoundError,
+            pickle.UnpicklingError,
+            EOFError,
+            AttributeError,
+            ImportError,
+            IndexError,
+        )
+        improvements: list[Improvement] = []
+        for fname in self.pickle_dir.iterdir():
+            # open() on a directory raises IsADirectoryError, so skip non-files.
+            if not fname.is_file():
+                continue
+            try:
+                improvements.append(self.get(fname.name))
+            except unreadable:
+                print(f"File {fname.absolute()} is not a valid Improvement.")
+        return improvements