From d3c79b8bf98487bbf8cd3933caaa64a210356f79 Mon Sep 17 00:00:00 2001 From: Marty Oehme Date: Wed, 13 Mar 2024 23:51:41 +0100 Subject: [PATCH] Update to make basic functionality work --- verbanote_client/cli.py | 42 +++++++++++++-------- verbanote_client/configuration.py | 1 + verbanote_client/job_functions.py | 53 ++++++++++++++++++--------- verbanote_client/network_functions.py | 5 ++- 4 files changed, 67 insertions(+), 34 deletions(-) diff --git a/verbanote_client/cli.py b/verbanote_client/cli.py index 9cb3435..72da313 100644 --- a/verbanote_client/cli.py +++ b/verbanote_client/cli.py @@ -4,58 +4,69 @@ import requests import logging from pathlib import Path from rich.console import Console +import runpod +from runpod.endpoint import Job from network_functions import _upload_to_oxo -from job_functions import print_job_status +from job_functions import print_job_status, start_job from configuration import Config # TODO turn all this into config style options or @click-style flags/options logging.basicConfig(level=logging.INFO) +RUNPOD_API_URL = "https://api.runpod.ai/v2/" + @quick.gui_option() @click.group() @click.pass_context -@click.option("--endpoint", "-e", help="URL of runpod serverless endpoint to use.") +@click.option("--pod-id", "-i", help="ID of runpod serverless endpoint to use.") @click.option("--token", "-t", help="Access token for runpod instance.") # TODO @click.version_option() -def cli(ctx, endpoint, token): +def cli(ctx, pod_id, token): """Verbanote Transcribes any audio file given using OpenAI's whisper AI and pyannote for speaker detection. """ - print(f"Token: {token}") + runpod.api_key = token headers = { "Content-Type": "application/json", "Authorization": f"Bearer {token}", } options: Config = Config( - endpoint=endpoint, token=token, headers=headers, console=Console() + endpoint=f"{RUNPOD_API_URL}{pod_id}", + pod_id=pod_id, + token=token, + headers=headers, + console=Console(), ) ctx.obj = options @cli.command() @click.pass_obj +@click.option("--language", "-l", help="Language to use for transcription in 2-letter ISO code (`de` or `fr`). Defaults to `en`.") @click.argument( "audiofile", type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path), ) -def start(config: Config, audiofile: Path) -> None: +def start(config: Config, language: str, audiofile: Path) -> None: """Start processing the given audiofile. Queues a job in the processing queue of the AI api. """ - endpoint_new_job = f"{config.endpoint}/run" - with config.console.status("[bold green]Uploading data..."): + with config.console.status("[bold green]Uploading data to 0x0.st...\n"): url = _upload_to_oxo(audiofile) - input_data = {"input": {"url": url}} - config.console.log(f"[green]Requesting new job for[/green] {audiofile}...") - response = requests.post(endpoint_new_job, json=input_data, headers=config.headers) - job_id = response.json()["id"] - config.console.log(f"[green]Job[/green] {job_id} [green]has been queued.[/green]") - print_job_status(config, job_id) + input_data = {"url": url, "lang": language} + config.console.print( + f"[green]Requesting new job for[/green] {audiofile} over {url}..." + ) + job = start_job(config, input_data) + config.console.print( + f"[green]Job[/green] {job.job_id} [green]has been queued.[/green]" + ) + print_job_status(config, job) @cli.command() @@ -72,7 +83,8 @@ def health(config: Config) -> None: @click.pass_obj @click.argument("job_id") def job(config: Config, job_id: str) -> None: - print_job_status(config, job_id) + job = Job(config.pod_id, job_id) + print_job_status(config, job, once=True) def cancel(config: Config, job_id: str) -> None: diff --git a/verbanote_client/configuration.py b/verbanote_client/configuration.py index e57113c..88209d0 100644 --- a/verbanote_client/configuration.py +++ b/verbanote_client/configuration.py @@ -4,6 +4,7 @@ from rich.console import Console @dataclass class Config: + pod_id: str endpoint: str token: str headers: dict[str, str] diff --git a/verbanote_client/job_functions.py b/verbanote_client/job_functions.py index 42b032a..e38ba2d 100644 --- a/verbanote_client/job_functions.py +++ b/verbanote_client/job_functions.py @@ -1,10 +1,14 @@ import time -import requests from datetime import timedelta from math import floor -from rich.table import Table -from rich.live import Live +from typing import Any + +import requests +import runpod from configuration import Config +from rich.live import Live +from rich.table import Table +from runpod.endpoint import Job STATUS_MAPPING = { "IN_QUEUE": "[yellow]queued[/yellow]", @@ -15,26 +19,36 @@ STATUS_MAPPING = { } -def print_job_status(config: Config, job_id: str, once:bool = False) -> None: +def start_job(config: Config, input: Any) -> Job: + endpoint = runpod.Endpoint(config.pod_id) + return endpoint.run(input) + + +def print_job_status(config: Config, job: Job, once: bool = False) -> None: + job_id = job.job_id result = _request_job_state(config, job_id) if not result: return - def result_to_values(result:dict)-> dict[str,str]: + def result_to_values(result: dict) -> dict[str, str]: + output = result.get("output", {}) + transcription:str = output.get("transcription_url", "...") + diarization = output.get("diarization_url", "...") return { - "status": STATUS_MAPPING[result["status"]], - "transcription": result.get("transcription_url", "..."), - "diarization": result.get("diarization_url", "..."), - } + "status": STATUS_MAPPING[result["status"]], + "transcription": transcription.removeprefix(r"b'").removesuffix(r"\n'"), + "diarization": diarization.removeprefix(r"b'").removesuffix(r"\n'"), + } + values: dict[str, str] = result_to_values(result) def rebuild_table(): table = Table() table.add_column("Status") table.add_column("Time running") - table.add_column("Job ID") - table.add_column("Diarization") - table.add_column("Transcription") + table.add_column("Job ID", overflow='fold') + table.add_column("Diarization", overflow='fold') + table.add_column("Transcription", overflow='fold') table.add_row( values.get("status", "unknown"), str(sw_current), @@ -47,15 +61,21 @@ def print_job_status(config: Config, job_id: str, once:bool = False) -> None: sw_start: float = time.time() sw_current: timedelta = timedelta() with Live(get_renderable=rebuild_table, refresh_per_second=2): - while not once: + while True: + if once: + break + if result["status"] != "IN_QUEUE" and result["status"] != "IN_PROGRESS": + once = True + result = _request_job_state(config, job_id, silent=True) - values: dict[str, str] = result_to_values(result) + values = result_to_values(result) sw_current = timedelta(seconds=floor(time.time() - sw_start)) - if result["status"] != "IN_QUEUE" and result["status"] != "IN_PROGRESS": - break time.sleep(1) + if "transcription" in values: + config.console.print(f"[green]Transcript:[/green] {values['transcription']}") + def _request_job_state(config: Config, id: str, silent: bool = False) -> dict: endpoint_health = f"{config.endpoint}/status/{id}" @@ -73,4 +93,3 @@ def _request_job_state(config: Config, id: str, silent: bool = False) -> dict: if not response.ok: raise requests.exceptions.HTTPError() return response.json() - diff --git a/verbanote_client/network_functions.py b/verbanote_client/network_functions.py index 29b56d3..0786657 100644 --- a/verbanote_client/network_functions.py +++ b/verbanote_client/network_functions.py @@ -11,6 +11,7 @@ def _upload_to_oxo(file: Path, url: str = "https://0x0.st", expires: int = 2) -> ) if not resp.ok: raise requests.exceptions.HTTPError() - console.log(f"Uploaded file {file} to {str(resp.content)}") - return str(resp.content).strip() + url = resp.content.decode().strip() + console.log(f"Uploaded file {file} to {url}") + return url