diff --git a/prophet/domain/llm.py b/prophet/domain/llm.py index 01ff319..0f32610 100644 --- a/prophet/domain/llm.py +++ b/prophet/domain/llm.py @@ -1,10 +1,11 @@ from typing import Protocol +from prophet.domain.improvement import Improvement from prophet.domain.original import Original class LLMClient(Protocol): - def get_alternative_title_suggestions(self, original_content: str) -> str: + def rewrite(self, original: Original) -> Improvement: raise NotImplementedError def rewrite_title( @@ -16,3 +17,6 @@ class LLMClient(Protocol): self, original: Original, improved_title: str | None = None ) -> str: raise NotImplementedError + + def get_alternative_title_suggestions(self, original_content: str) -> str: + raise NotImplementedError diff --git a/prophet/infra/llm_groq.py b/prophet/infra/llm_groq.py index f90263e..e91d8f3 100644 --- a/prophet/infra/llm_groq.py +++ b/prophet/infra/llm_groq.py @@ -20,6 +20,14 @@ class GroqClient(LLMClient): self.config_ai = config_ai if config_ai else AiConfig.from_env() self.client = client if client else Groq(api_key=self.config_ai.API_KEY) + @override + def rewrite(self, original: Original) -> Improvement: + suggestions = self.get_alternative_title_suggestions(original.title) + new_title = self.rewrite_title(original.title, suggestions) + new_summary = self.rewrite_summary(original, new_title) + + return Improvement(original=original, title=new_title, summary=new_summary) + @override def get_alternative_title_suggestions( self, original_content: str, custom_prompt: str | None = None