Move pickle functions to PickleRepo

This commit is contained in:
Marty Oehme 2025-06-06 12:32:31 +02:00
parent c216c2d0d7
commit 121e678c7d
Signed by: Marty
GPG key ID: 4E535BC19C61886E
3 changed files with 65 additions and 40 deletions

View file

@ -1,7 +1,5 @@
import json
import pickle
from datetime import datetime
from pathlib import Path
import feedparser
from fastapi import FastAPI
@ -10,21 +8,21 @@ from fastapi.responses import HTMLResponse
from fastapi_utils.tasks import repeat_every
from prophet.domain.improvement import Improvement
from prophet.domain.improvement_repo import IImprovementRepo
from prophet.domain.original import Original
from prophet.infra.improvement_pickle_repo import ImprovementPickleRepo
from prophet.llm import LLMClient
# URL of the live feed that grab_latest_originals() parses.
BEE_FEED = "https://babylonbee.com/feed"
# Local feed file, presumably used instead of BEE_FEED while testing.
BEE_FEED_TEST = "test/resources/feed_short.atom" # NOTE: Switch out when done testing
# Directory in which improvements are stored as pickle files.
PICKLE_DIR = "/tmp/pollenprophet"
REFRESH_PERIOD = 3600 # between fetching articles, in seconds
# Module-wide LLM client instance.
llm: LLMClient = LLMClient()
# Module-wide improvement repository. Constructing ImprovementPickleRepo
# creates its pickle directory if it is missing, so this happens at import time.
repo: IImprovementRepo = ImprovementPickleRepo()
def grab_latest_originals() -> list[Original]:
# TODO: Implement skipping any we already have
feed: feedparser.FeedParserDict = feedparser.parse(BEE_FEED) # noqa: F841
results: list[Original] = []
for entry in feed.entries:
@ -38,39 +36,11 @@ def grab_latest_originals() -> list[Original]:
return results
def save_new_improvements(improvements: list[Improvement]) -> None:
    """Pickle each improvement into its own file below ``PICKLE_DIR``.

    The directory is created when it does not exist yet. Every file is named
    ``<unix timestamp of the original's date>_<improvement id>``. Saving is
    best-effort: a failure for one improvement is printed and the remaining
    improvements are still attempted.

    Args:
        improvements: The improvements to persist.
    """
    save_dir = Path(PICKLE_DIR)
    save_dir.mkdir(parents=True, exist_ok=True)
    for imp in improvements:
        fname = save_dir / f"{int(imp.original.date.timestamp())}_{imp.id}"
        try:
            # Serialize in memory before touching the file system, so an
            # object that cannot be pickled never leaves an empty or
            # truncated file behind for the loader to trip over.
            payload = pickle.dumps(imp)
            fname.write_bytes(payload)
        except Exception as e:
            # Deliberately broad: pickling can fail with many exception
            # types, and one bad item must not abort the whole batch.
            print(f"Error saving file {fname}: {e}")
        else:
            print(f"Saved {fname}")
def load_existing_improvements() -> list[Improvement]:
    """Load every pickled improvement found directly inside ``PICKLE_DIR``.

    Entries that are not regular files are skipped. A file that cannot be
    read or unpickled is reported on stdout and skipped, so a single broken
    file does not prevent the others from loading. Files are visited in
    sorted name order, which makes the result order deterministic.

    Returns:
        The loaded improvements; an empty list when ``PICKLE_DIR`` does not
        exist (for example before anything has been saved).
    """
    load_dir = Path(PICKLE_DIR)
    if not load_dir.is_dir():
        # Nothing has been saved yet; iterdir() would raise here.
        return []

    improvements: list[Improvement] = []
    for fname in sorted(load_dir.iterdir()):
        if not fname.is_file():
            continue
        try:
            with open(fname, "rb") as f:
                # NOTE(review): unpickling can execute arbitrary code. This
                # is only safe while PICKLE_DIR is writable solely by
                # trusted users; a shared /tmp location may not be.
                obj: Improvement = pickle.load(f)
        except (
            OSError,
            pickle.UnpicklingError,
            EOFError,
            AttributeError,
            ImportError,
            IndexError,
        ) as e:
            # These are the failure modes documented for unpickling, plus
            # plain I/O errors such as a file vanishing or being unreadable.
            print(f"Error loading file {fname}: {e}")
        else:
            improvements.append(obj)
    return improvements
def keep_only_new_originals(
additional: list[Original], existing: list[Original] | None = None
):
if not existing:
existing = [e.original for e in load_existing_improvements()]
existing = [e.original for e in repo.get_all()]
existing_hashes = set([e.id for e in existing])
@ -128,7 +98,7 @@ def improve_summary(original_title: str, new_title: str, original_summary: str):
def refresh_articles():
adding = keep_only_new_originals(grab_latest_originals())
improved = improve_originals(adding)
save_new_improvements(improved)
repo.add_all(improved)
print(f"Updated articles. Added {len(improved)} new ones.")
@ -141,7 +111,7 @@ async def fetch_update():
## HTML (& hyperdata) responses
@app.get("/improvements", response_class=HTMLResponse)
def list_improvements():
improved = load_existing_improvements()
improved = repo.get_all()
return (
"""<button hx-get="/originals" hx-target="#content">Originals</button> """
+ "\n".join(
@ -160,7 +130,7 @@ def list_improvements():
@app.get("/originals", response_class=HTMLResponse)
def list_originals():
improved = load_existing_improvements()
improved = repo.get_all()
return (
"""<button hx-get="/improvements" hx-target="#content">Improvements</button> """
+ "\n".join(
@ -228,7 +198,7 @@ if __name__ == "__main__":
# save_new_improvements(improved)
# migrate to newer version
improved = load_existing_improvements()
improved = repo.get_all()
for imp in improved:
imp.original.__post_init__()
print(f"Old Title: {imp.original.title}")
@ -239,4 +209,4 @@ if __name__ == "__main__":
print(f"Summary: {imp.summary}")
print("-" * 50)
save_new_improvements(improved)
repo.add_all(improved)

View file

@ -1,13 +1,19 @@
from typing import Protocol
from prophet.domain.improvement import Improvement
class ImprovementNotFoundError(Exception):
    """Signal that a requested improvement could not be found in a repository."""
class IImprovementRepo(Protocol):
    """Interface for a repository of ``Improvement`` objects.

    The default method bodies raise ``NotImplementedError``, so a class that
    inherits from this protocol explicitly has to override each method.
    """

    def add(self, improvement: Improvement) -> None:
        """Add a single improvement to the repository."""
        raise NotImplementedError

    def add_all(self, improvements: list[Improvement]) -> None:
        """Add every improvement of the given list to the repository."""
        raise NotImplementedError

    def get(self, id: str) -> Improvement:
        """Return the improvement identified by ``id``."""
        raise NotImplementedError

View file

@ -0,0 +1,49 @@
import pickle
from pathlib import Path
from typing import override
from prophet.domain.improvement import Improvement
from prophet.domain.improvement_repo import IImprovementRepo, ImprovementNotFoundError
class ImprovementPickleRepo(IImprovementRepo):
    """Improvement repository that keeps one pickle file per improvement.

    Every improvement is stored in its own file inside ``pickle_dir``; the
    file name is the improvement's ``id``.

    NOTE(review): unpickling can execute arbitrary code, so ``pickle_dir``
    must only be writable by trusted users. The default location below
    ``/tmp`` may not satisfy that on a shared machine.
    """

    # Directory that holds the pickle files.
    pickle_dir: Path

    def __init__(self, pickle_dir: str | Path = "/tmp/pollenprophet") -> None:
        """Create the repository and make sure its directory exists.

        Args:
            pickle_dir: Directory in which the pickle files are kept. It is
                created, including missing parents, when it does not exist.
        """
        self.pickle_dir = Path(pickle_dir)
        self.pickle_dir.mkdir(parents=True, exist_ok=True)

    @override
    def add(self, improvement: Improvement) -> None:
        """Write ``improvement`` to ``pickle_dir/<improvement.id>``.

        An existing file with the same id is overwritten. I/O and pickling
        errors are printed rather than raised, so that ``add_all`` keeps
        going after a single failure.

        Args:
            improvement: The improvement to persist.
        """
        fname = self.pickle_dir / improvement.id
        try:
            # Serialize in memory first: if pickling fails, no empty or
            # truncated file is left behind for get_all() to stumble over.
            payload = pickle.dumps(improvement)
            fname.write_bytes(payload)
        except (OSError, pickle.PicklingError) as e:
            # Opening with a write mode never raises FileExistsError, so
            # the realistic failures (permissions, full disk, unpicklable
            # object) are handled here instead.
            print(f"Error saving file {fname}: {e}")
        else:
            print(f"Saved {fname}")

    @override
    def add_all(self, improvements: list[Improvement]) -> None:
        """Persist every improvement of ``improvements`` via ``add``.

        Args:
            improvements: The improvements to persist.
        """
        for imp in improvements:
            self.add(imp)

    @override
    def get(self, id: str) -> Improvement:
        """Load the improvement stored under ``id``.

        Args:
            id: Identifier of the improvement, i.e. its file name inside
                ``pickle_dir``.

        Returns:
            The unpickled improvement.

        Raises:
            ImprovementNotFoundError: If no file exists for ``id``.
        """
        try:
            with open(self.pickle_dir / id, "rb") as f:
                improvement: Improvement = pickle.load(f)
        except FileNotFoundError as err:
            # Keep the original cause attached for easier debugging.
            raise ImprovementNotFoundError(id) from err
        return improvement

    @override
    def get_all(self) -> list[Improvement]:
        """Load every improvement stored in ``pickle_dir``.

        Entries that are not regular files are skipped. Files that vanish
        in the meantime or cannot be unpickled are reported on stdout and
        skipped. Files are visited in sorted name order, so the result
        order is deterministic.

        Returns:
            All improvements that could be loaded.
        """
        improvements: list[Improvement] = []
        for fname in sorted(self.pickle_dir.iterdir()):
            if not fname.is_file():
                # Opening a directory would raise IsADirectoryError.
                continue
            try:
                improvements.append(self.get(fname.name))
            except (
                ImprovementNotFoundError,
                pickle.UnpicklingError,
                EOFError,
                AttributeError,
                ImportError,
                IndexError,
            ):
                # Besides a missing file, these are the failure modes
                # documented for unpickling corrupt or foreign data.
                print(f"File {fname.absolute()} is not a valid Improvement.")
        return improvements