Move pickle functions to PickleRepo
This commit is contained in:
parent
c216c2d0d7
commit
121e678c7d
3 changed files with 65 additions and 40 deletions
|
|
@ -1,7 +1,5 @@
|
||||||
import json
|
import json
|
||||||
import pickle
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import feedparser
|
import feedparser
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
@ -10,21 +8,21 @@ from fastapi.responses import HTMLResponse
|
||||||
from fastapi_utils.tasks import repeat_every
|
from fastapi_utils.tasks import repeat_every
|
||||||
|
|
||||||
from prophet.domain.improvement import Improvement
|
from prophet.domain.improvement import Improvement
|
||||||
|
from prophet.domain.improvement_repo import IImprovementRepo
|
||||||
from prophet.domain.original import Original
|
from prophet.domain.original import Original
|
||||||
|
from prophet.infra.improvement_pickle_repo import ImprovementPickleRepo
|
||||||
from prophet.llm import LLMClient
|
from prophet.llm import LLMClient
|
||||||
|
|
||||||
BEE_FEED = "https://babylonbee.com/feed"
|
BEE_FEED = "https://babylonbee.com/feed"
|
||||||
BEE_FEED_TEST = "test/resources/feed_short.atom" # NOTE: Switch out when done testing
|
BEE_FEED_TEST = "test/resources/feed_short.atom" # NOTE: Switch out when done testing
|
||||||
|
|
||||||
PICKLE_DIR = "/tmp/pollenprophet"
|
|
||||||
|
|
||||||
REFRESH_PERIOD = 3600 # between fetching articles, in seconds
|
REFRESH_PERIOD = 3600 # between fetching articles, in seconds
|
||||||
|
|
||||||
llm: LLMClient = LLMClient()
|
llm: LLMClient = LLMClient()
|
||||||
|
repo: IImprovementRepo = ImprovementPickleRepo()
|
||||||
|
|
||||||
|
|
||||||
def grab_latest_originals() -> list[Original]:
|
def grab_latest_originals() -> list[Original]:
|
||||||
# TODO: Implement skipping any we already have
|
|
||||||
feed: feedparser.FeedParserDict = feedparser.parse(BEE_FEED) # noqa: F841
|
feed: feedparser.FeedParserDict = feedparser.parse(BEE_FEED) # noqa: F841
|
||||||
results: list[Original] = []
|
results: list[Original] = []
|
||||||
for entry in feed.entries:
|
for entry in feed.entries:
|
||||||
|
|
@ -38,39 +36,11 @@ def grab_latest_originals() -> list[Original]:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def save_new_improvements(improvements: list[Improvement]) -> None:
|
|
||||||
save_dir = Path(PICKLE_DIR)
|
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
for imp in improvements:
|
|
||||||
fname = save_dir / f"{int(imp.original.date.timestamp())}_{imp.id}"
|
|
||||||
try:
|
|
||||||
with open(fname, "wb") as f:
|
|
||||||
pickle.dump(imp, f)
|
|
||||||
print(f"Saved {fname}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving file {fname}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_existing_improvements() -> list[Improvement]:
|
|
||||||
improvements: list[Improvement] = []
|
|
||||||
for fname in Path(PICKLE_DIR).iterdir():
|
|
||||||
if not fname.is_file():
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(fname, "rb") as f:
|
|
||||||
obj: Improvement = pickle.load(f)
|
|
||||||
improvements.append(obj)
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
print(f"Error loading file {fname}: {e}")
|
|
||||||
return improvements
|
|
||||||
|
|
||||||
|
|
||||||
def keep_only_new_originals(
|
def keep_only_new_originals(
|
||||||
additional: list[Original], existing: list[Original] | None = None
|
additional: list[Original], existing: list[Original] | None = None
|
||||||
):
|
):
|
||||||
if not existing:
|
if not existing:
|
||||||
existing = [e.original for e in load_existing_improvements()]
|
existing = [e.original for e in repo.get_all()]
|
||||||
|
|
||||||
existing_hashes = set([e.id for e in existing])
|
existing_hashes = set([e.id for e in existing])
|
||||||
|
|
||||||
|
|
@ -128,7 +98,7 @@ def improve_summary(original_title: str, new_title: str, original_summary: str):
|
||||||
def refresh_articles():
|
def refresh_articles():
|
||||||
adding = keep_only_new_originals(grab_latest_originals())
|
adding = keep_only_new_originals(grab_latest_originals())
|
||||||
improved = improve_originals(adding)
|
improved = improve_originals(adding)
|
||||||
save_new_improvements(improved)
|
repo.add_all(improved)
|
||||||
print(f"Updated articles. Added {len(improved)} new ones.")
|
print(f"Updated articles. Added {len(improved)} new ones.")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -141,7 +111,7 @@ async def fetch_update():
|
||||||
## HTML (& hyperdata) responses
|
## HTML (& hyperdata) responses
|
||||||
@app.get("/improvements", response_class=HTMLResponse)
|
@app.get("/improvements", response_class=HTMLResponse)
|
||||||
def list_improvements():
|
def list_improvements():
|
||||||
improved = load_existing_improvements()
|
improved = repo.get_all()
|
||||||
return (
|
return (
|
||||||
"""<button hx-get="/originals" hx-target="#content">Originals</button> """
|
"""<button hx-get="/originals" hx-target="#content">Originals</button> """
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
|
|
@ -160,7 +130,7 @@ def list_improvements():
|
||||||
|
|
||||||
@app.get("/originals", response_class=HTMLResponse)
|
@app.get("/originals", response_class=HTMLResponse)
|
||||||
def list_originals():
|
def list_originals():
|
||||||
improved = load_existing_improvements()
|
improved = repo.get_all()
|
||||||
return (
|
return (
|
||||||
"""<button hx-get="/improvements" hx-target="#content">Improvements</button> """
|
"""<button hx-get="/improvements" hx-target="#content">Improvements</button> """
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
|
|
@ -228,7 +198,7 @@ if __name__ == "__main__":
|
||||||
# save_new_improvements(improved)
|
# save_new_improvements(improved)
|
||||||
|
|
||||||
# migrate to newer version
|
# migrate to newer version
|
||||||
improved = load_existing_improvements()
|
improved = repo.get_all()
|
||||||
for imp in improved:
|
for imp in improved:
|
||||||
imp.original.__post_init__()
|
imp.original.__post_init__()
|
||||||
print(f"Old Title: {imp.original.title}")
|
print(f"Old Title: {imp.original.title}")
|
||||||
|
|
@ -239,4 +209,4 @@ if __name__ == "__main__":
|
||||||
print(f"Summary: {imp.summary}")
|
print(f"Summary: {imp.summary}")
|
||||||
|
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
save_new_improvements(improved)
|
repo.add_all(improved)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,19 @@
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from prophet.domain.improvement import Improvement
|
from prophet.domain.improvement import Improvement
|
||||||
|
|
||||||
|
|
||||||
|
class ImprovementNotFoundError(Exception):
    """Signal that a requested improvement is not present in a repository."""
|
||||||
|
|
||||||
|
|
||||||
class IImprovementRepo(Protocol):
|
class IImprovementRepo(Protocol):
|
||||||
def add(self, improvement: Improvement) -> None:
|
def add(self, improvement: Improvement) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def add_all(self, improvements: list[Improvement]) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def get(self, id: str) -> Improvement:
|
def get(self, id: str) -> Improvement:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
||||||
49
prophet/infra/improvement_pickle_repo.py
Normal file
49
prophet/infra/improvement_pickle_repo.py
Normal file
|
|
@ -0,0 +1,49 @@
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import override
|
||||||
|
|
||||||
|
from prophet.domain.improvement import Improvement
|
||||||
|
from prophet.domain.improvement_repo import IImprovementRepo, ImprovementNotFoundError
|
||||||
|
|
||||||
|
|
||||||
|
class ImprovementPickleRepo(IImprovementRepo):
    """Improvement repository that keeps one pickle file per improvement.

    Each file lives directly inside ``pickle_dir`` and is named after the
    improvement's ``id``.
    """

    pickle_dir: Path

    def __init__(self, pickle_dir: str | Path = "/tmp/pollenprophet") -> None:
        """Remember the storage directory, creating it if it does not exist.

        Args:
            pickle_dir: Directory that holds the pickle files.
        """
        self.pickle_dir = Path(pickle_dir)
        self.pickle_dir.mkdir(parents=True, exist_ok=True)

    @override
    def add(self, improvement: Improvement) -> None:
        """Pickle ``improvement`` to ``pickle_dir / improvement.id``.

        An existing file with the same id is overwritten. Saving is
        best-effort: I/O and pickling failures are reported on stdout
        instead of being raised.

        Args:
            improvement: The improvement to persist.
        """
        fname = self.pickle_dir / improvement.id
        try:
            # Serialize before touching the file so a pickling failure cannot
            # leave a truncated file behind for get_all() to trip over.
            payload = pickle.dumps(improvement)
            fname.write_bytes(payload)
        except (OSError, pickle.PicklingError) as e:
            # Opening for writing never raises FileExistsError; the failures
            # that actually occur here are OS-level or pickling errors.
            print(f"Error saving file {fname}: {e}")
        else:
            print(f"Saved {fname}")

    @override
    def add_all(self, improvements: list[Improvement]) -> None:
        """Persist every improvement in ``improvements`` via :meth:`add`."""
        for imp in improvements:
            self.add(imp)

    @override
    def get(self, id: str) -> Improvement:
        """Load the improvement stored under ``id``.

        Args:
            id: File name of the improvement inside ``pickle_dir``.

        Returns:
            The unpickled improvement.

        Raises:
            ImprovementNotFoundError: If no file named ``id`` exists.
        """
        try:
            with open(self.pickle_dir / id, "rb") as f:
                # NOTE(security): unpickling can execute arbitrary code, so
                # pickle_dir must only ever contain files this app wrote.
                improvement: Improvement = pickle.load(f)
        except FileNotFoundError as err:
            raise ImprovementNotFoundError(id) from err

        return improvement

    @override
    def get_all(self) -> list[Improvement]:
        """Load every improvement stored in ``pickle_dir``.

        Entries that are not regular files are ignored. Files that cannot
        be unpickled are reported on stdout and skipped, so one bad file
        does not abort the whole listing.

        Returns:
            All improvements that could be loaded, in directory-iteration
            order.
        """
        improvements: list[Improvement] = []
        for fname in self.pickle_dir.iterdir():
            if not fname.is_file():
                # A subdirectory would make open() raise IsADirectoryError.
                continue
            try:
                improvements.append(self.get(fname.name))
            except (
                ImprovementNotFoundError,
                pickle.UnpicklingError,
                EOFError,
                AttributeError,
                ImportError,
                IndexError,
            ):
                # These are the errors the pickle docs list for bad data.
                print(f"File {fname.absolute()} is not a valid Improvement.")
        return improvements
|
||||||
Loading…
Add table
Add a link
Reference in a new issue