Extract data classes and llm class
This commit is contained in:
parent
f96b6413e2
commit
c537b1e750
3 changed files with 118 additions and 108 deletions
115
prophet/app.py
115
prophet/app.py
|
|
@ -1,20 +1,16 @@
|
|||
import hashlib
|
||||
import json
|
||||
import pickle
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import feedparser
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from groq import Groq
|
||||
|
||||
from prophet.config import AiConfig
|
||||
from prophet.data import Improvement, Original
|
||||
from prophet.llm import LLMClient
|
||||
|
||||
BEE_FEED = "https://babylonbee.com/feed"
|
||||
BEE_FEED_TEST = "test/resources/feed_short.atom" # NOTE: Switch out when done testing
|
||||
|
|
@ -23,45 +19,7 @@ PICKLE_DIR = "/tmp/pollenprophet"
|
|||
|
||||
REFRESH_PERIOD = 3600 # between fetching articles, in seconds
|
||||
|
||||
config_ai: AiConfig = AiConfig.from_env()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Original: # BadJoke: Sting
|
||||
title: str
|
||||
summary: str
|
||||
link: str
|
||||
date: datetime
|
||||
image_link: str | None = None
|
||||
id: str = field(init=False)
|
||||
|
||||
def _extract_img(self, s: str) -> tuple[str, str]: # [img_link, rest of string]
|
||||
img: str
|
||||
m = re.match(r'<img src="(?P<img>.+?)"', s)
|
||||
try:
|
||||
img = m.group("img")
|
||||
except (IndexError, NameError):
|
||||
return ("", s)
|
||||
|
||||
if img:
|
||||
rest = re.sub(r"<img src=.+?>", "", s)
|
||||
return (img, rest)
|
||||
|
||||
def __post_init__(self):
|
||||
self.id = hashlib.sha256(self.link.encode()).hexdigest()
|
||||
|
||||
extracted = self._extract_img(self.summary)
|
||||
if extracted[0]:
|
||||
self.image_link = extracted[0]
|
||||
self.summary = extracted[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Improvement: # GoodJoke: Queen
|
||||
original: Original
|
||||
title: str
|
||||
summary: str
|
||||
id: str = str(uuid4())
|
||||
llm: LLMClient = LLMClient()
|
||||
|
||||
|
||||
def grab_latest_originals() -> list[Original]:
|
||||
|
|
@ -126,8 +84,8 @@ def keep_only_new_originals(
|
|||
def improve_originals(originals: list[Original]) -> list[Improvement]:
|
||||
improvements: list[Improvement] = []
|
||||
for orig in originals:
|
||||
new_title = rewrite_title_with_groq(orig.title)
|
||||
new_summary = rewrite_summary_with_groq(orig, new_title)
|
||||
new_title = llm.rewrite_title_with_groq(orig.title)
|
||||
new_summary = llm.rewrite_summary_with_groq(orig, new_title)
|
||||
|
||||
improvements.append(
|
||||
Improvement(original=orig, title=new_title, summary=new_summary)
|
||||
|
|
@ -135,65 +93,6 @@ def improve_originals(originals: list[Original]) -> list[Improvement]:
|
|||
return improvements
|
||||
|
||||
|
||||
def rewrite_title_with_groq(original_content: str) -> str:
|
||||
client = Groq(api_key=config_ai.API_KEY)
|
||||
|
||||
suggestions = client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a comedy writer at a satirical newspaper. Improve on the following satirical headline. Your new headline is funny, can involve current political events and has an edge to it. Print only the suggestions, with one suggestion on each line.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": original_content,
|
||||
},
|
||||
],
|
||||
model="llama-3.3-70b-versatile",
|
||||
)
|
||||
suggestions_str = suggestions.choices[0].message.content
|
||||
if not suggestions_str:
|
||||
raise ValueError
|
||||
print("Suggestions: ", suggestions_str)
|
||||
winner = client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an editor at a satirical newspaper. Improve on the following satirical headline. For a given headline, you diligently evaluate: (1) Whether the headline is funny; (2) Whether the headline follows a clear satirical goal; (3) Whether the headline has sufficient substance and bite. Based on the outcomes of your review, you pick your favorite headline from the given suggestions and you make targeted revisions to it. Your output consists solely of the revised headline.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": suggestions_str,
|
||||
},
|
||||
],
|
||||
model="llama-3.3-70b-versatile",
|
||||
)
|
||||
print("Winner: ", winner.choices[0].message.content)
|
||||
winner_str = winner.choices[0].message.content
|
||||
if not winner_str:
|
||||
raise ValueError
|
||||
return winner_str.strip(" \"'")
|
||||
|
||||
|
||||
def rewrite_summary_with_groq(orig: Original, improved_title: str) -> str:
|
||||
client = Groq(api_key=config_ai.API_KEY)
|
||||
|
||||
summary = client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Below there is an original title and an original summary. Then follows an improved title. Write an improved summary based on the original summary which fits to the improved title. Only output the improved summary.\n\nTitle:{orig.title}\nSummary:{orig.summary}\n---\nTitle:{improved_title}\nSummary:",
|
||||
}
|
||||
],
|
||||
model="llama-3.3-70b-versatile",
|
||||
)
|
||||
summary_str = summary.choices[0].message.content
|
||||
if not summary_str:
|
||||
raise ValueError
|
||||
print("Improved summary", summary_str)
|
||||
return summary_str.strip(" \"'")
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
origins = [
|
||||
|
|
@ -212,7 +111,7 @@ app.add_middleware(
|
|||
|
||||
@app.get("/improve-title")
|
||||
def improve_headline(content: str):
|
||||
return rewrite_title_with_groq(content)
|
||||
return llm.rewrite_title_with_groq(content)
|
||||
|
||||
|
||||
@app.get("/improve-summary")
|
||||
|
|
@ -220,7 +119,7 @@ def improve_summary(original_title: str, new_title: str, original_summary: str):
|
|||
o = Original(
|
||||
title=original_title, summary=original_summary, link="", date=datetime.now()
|
||||
)
|
||||
return rewrite_summary_with_groq(o, new_title)
|
||||
return llm.rewrite_summary_with_groq(o, new_title)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
|
|
|
|||
43
prophet/data.py
Normal file
43
prophet/data.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import hashlib
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
@dataclass
|
||||
class Original: # BadJoke: Sting
|
||||
title: str
|
||||
summary: str
|
||||
link: str
|
||||
date: datetime
|
||||
image_link: str | None = None
|
||||
id: str = field(init=False)
|
||||
|
||||
def _extract_img(self, s: str) -> tuple[str, str]: # [img_link, rest of string]
|
||||
img: str
|
||||
m = re.match(r'<img src="(?P<img>.+?)"', s)
|
||||
try:
|
||||
img = m.group("img")
|
||||
except (IndexError, NameError):
|
||||
return ("", s)
|
||||
|
||||
if img:
|
||||
rest = re.sub(r"<img src=.+?>", "", s)
|
||||
return (img, rest)
|
||||
|
||||
def __post_init__(self):
|
||||
self.id = hashlib.sha256(self.link.encode()).hexdigest()
|
||||
|
||||
extracted = self._extract_img(self.summary)
|
||||
if extracted[0]:
|
||||
self.image_link = extracted[0]
|
||||
self.summary = extracted[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Improvement: # GoodJoke: Queen
|
||||
original: Original
|
||||
title: str
|
||||
summary: str
|
||||
id: str = str(uuid4())
|
||||
68
prophet/llm.py
Normal file
68
prophet/llm.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from groq import Groq
|
||||
|
||||
from prophet.data import Original
|
||||
from prophet.config import AiConfig
|
||||
|
||||
|
||||
class LLMClient:
|
||||
config_ai: AiConfig
|
||||
client: Groq
|
||||
|
||||
def __init__(
|
||||
self, config_ai: AiConfig | None = None, client: Groq | None = None
|
||||
) -> None:
|
||||
self.config_ai = config_ai if config_ai else AiConfig.from_env()
|
||||
self.client = client if client else Groq(api_key=self.config_ai.API_KEY)
|
||||
|
||||
def rewrite_title_with_groq(self, original_content: str) -> str:
|
||||
suggestions = self.client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a comedy writer at a satirical newspaper. Improve on the following satirical headline. Your new headline is funny, can involve current political events and has an edge to it. Print only the suggestions, with one suggestion on each line.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": original_content,
|
||||
},
|
||||
],
|
||||
model="llama-3.3-70b-versatile",
|
||||
)
|
||||
suggestions_str = suggestions.choices[0].message.content
|
||||
if not suggestions_str:
|
||||
raise ValueError
|
||||
print("Suggestions: ", suggestions_str)
|
||||
winner = self.client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an editor at a satirical newspaper. Improve on the following satirical headline. For a given headline, you diligently evaluate: (1) Whether the headline is funny; (2) Whether the headline follows a clear satirical goal; (3) Whether the headline has sufficient substance and bite. Based on the outcomes of your review, you pick your favorite headline from the given suggestions and you make targeted revisions to it. Your output consists solely of the revised headline.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": suggestions_str,
|
||||
},
|
||||
],
|
||||
model="llama-3.3-70b-versatile",
|
||||
)
|
||||
print("Winner: ", winner.choices[0].message.content)
|
||||
winner_str = winner.choices[0].message.content
|
||||
if not winner_str:
|
||||
raise ValueError
|
||||
return winner_str.strip(" \"'")
|
||||
|
||||
def rewrite_summary_with_groq(self, orig: Original, improved_title: str) -> str:
|
||||
summary = self.client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Below there is an original title and an original summary. Then follows an improved title. Write an improved summary based on the original summary which fits to the improved title. Only output the improved summary.\n\nTitle:{orig.title}\nSummary:{orig.summary}\n---\nTitle:{improved_title}\nSummary:",
|
||||
}
|
||||
],
|
||||
model="llama-3.3-70b-versatile",
|
||||
)
|
||||
summary_str = summary.choices[0].message.content
|
||||
if not summary_str:
|
||||
raise ValueError
|
||||
print("Improved summary", summary_str)
|
||||
return summary_str.strip(" \"'")
|
||||
Loading…
Add table
Add a link
Reference in a new issue