diff --git a/prophet/domain/improvement_repo.py b/prophet/domain/improvement_repo.py index 14b750c..4d1a016 100644 --- a/prophet/domain/improvement_repo.py +++ b/prophet/domain/improvement_repo.py @@ -17,7 +17,7 @@ class IImprovementRepo(Protocol): def get(self, id: str) -> Improvement: raise NotImplementedError - def get_all(self) -> list[Improvement]: + def get_all(self, last_n: int | None = None) -> list[Improvement]: raise NotImplementedError def remove(self, id: str) -> Improvement: diff --git a/prophet/domain/llm.py b/prophet/domain/llm.py index 0f32610..18aa936 100644 --- a/prophet/domain/llm.py +++ b/prophet/domain/llm.py @@ -5,7 +5,9 @@ from prophet.domain.original import Original class LLMClient(Protocol): - def rewrite(self, original: Original) -> Improvement: + def rewrite( + self, original: Original, previous_titles: list[str] | None = None + ) -> Improvement: raise NotImplementedError def rewrite_title( diff --git a/prophet/infra/improvement_supa_repo.py b/prophet/infra/improvement_supa_repo.py index 18f691e..e1b3219 100644 --- a/prophet/infra/improvement_supa_repo.py +++ b/prophet/infra/improvement_supa_repo.py @@ -48,15 +48,22 @@ class ImprovementSupaRepo(IImprovementRepo): ) @override - def get_all(self) -> list[Improvement]: - return [ - self._from_tbl_row(row) - for row in self.client.table(self.config.TABLE) - .select("*") - .order("date_orig_ts", desc=True) - .execute() - .data - ] + def get_all(self, last_n: int | None = None) -> list[Improvement]: + if not last_n: + sql = ( + self.client.table(self.config.TABLE) + .select("*") + .order("date_orig_ts", desc=True) + ) + else: + sql = ( + self.client.table(self.config.TABLE) + .select("*") + .order("date_orig_ts", desc=True) + .limit(last_n) + ) + + return [self._from_tbl_row(row) for row in sql.execute().data] @override def remove(self, id: str) -> Improvement: @@ -110,6 +117,7 @@ class ImprovementSupaRepo(IImprovementRepo): if __name__ == "__main__": # response = supabase.table("improvements").select("*").execute() repo = ImprovementSupaRepo() + print("latest entries:\n- ", "\n- ".join([imp.title for imp in repo.get_all(3)])) # from prophet.app import grab_latest_originals # latest = grab_latest_originals() diff --git a/prophet/infra/llm_groq.py b/prophet/infra/llm_groq.py index f528208..60dc29b 100644 --- a/prophet/infra/llm_groq.py +++ b/prophet/infra/llm_groq.py @@ -21,7 +21,9 @@ class GroqClient(LLMClient): self.client = client if client else Groq(api_key=self.config_ai.API_KEY) @override - def rewrite(self, original: Original) -> Improvement: + def rewrite( + self, original: Original, previous_titles: list[str] | None = None + ) -> Improvement: suggestions = self.get_alternative_title_suggestions(original.title) new_title = self.rewrite_title(original.title, suggestions) new_summary = self.rewrite_summary(original, new_title) @@ -30,18 +32,22 @@ class GroqClient(LLMClient): @override def get_alternative_title_suggestions( - self, original_content: str, custom_prompt: str | None = None + self, + original_content: str, + previous_titles: list[str] | None = None, + custom_prompt: str | None = None, ) -> str: prompt = ( custom_prompt if custom_prompt - else """ + else f""" Political context: We are in the year 2025, Donald Trump is President of the United States again. There has been a crackdown on 'illegal' immigration, with controversial disappearings happening - almost every day. Many are calling the United States an - increasingly fascist state. + almost every day by masked ICE agents. Many view the United States + as an increasingly fascist state, and the disappearings fueled by + racism. You are a comedy writer at a left-leaning satirical newspaper. Improve on the following satirical headline. Your new headline is @@ -49,6 +55,13 @@ class GroqClient(LLMClient): It should be roughly the length of the original headline. Print only new suggestions, with one suggestion on each line. + Do not create a headline naming Trump if more than 2 of the + previous headlines already do so and he is not specifically + referenced in the original headline. + + {"The previous 5 headlines you created are the following:\n- " if previous_titles else ""} + {"\n- ".join(previous_titles) if previous_titles else ""} + """ ) suggestions = self.client.chat.completions.create(