Update to make basic functionality work

This commit is contained in:
Marty Oehme 2024-03-13 23:51:41 +01:00
parent 193f6b6f0c
commit d3c79b8bf9
Signed by: Marty
GPG key ID: EDBF2ED917B2EF6A
4 changed files with 67 additions and 34 deletions

View file

@ -4,58 +4,69 @@ import requests
import logging import logging
from pathlib import Path from pathlib import Path
from rich.console import Console from rich.console import Console
import runpod
from runpod.endpoint import Job
from network_functions import _upload_to_oxo from network_functions import _upload_to_oxo
from job_functions import print_job_status from job_functions import print_job_status, start_job
from configuration import Config from configuration import Config
# TODO turn all this into config style options or @click-style flags/options # TODO turn all this into config style options or @click-style flags/options
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
RUNPOD_API_URL = "https://api.runpod.ai/v2/"
@quick.gui_option() @quick.gui_option()
@click.group() @click.group()
@click.pass_context @click.pass_context
@click.option("--endpoint", "-e", help="URL of runpod serverless endpoint to use.") @click.option("--pod-id", "-i", help="ID of runpod serverless endpoint to use.")
@click.option("--token", "-t", help="Access token for runpod instance.") @click.option("--token", "-t", help="Access token for runpod instance.")
# TODO @click.version_option() # TODO @click.version_option()
def cli(ctx, endpoint, token): def cli(ctx, pod_id, token):
"""Verbanote """Verbanote
Transcribes any audio file given using OpenAI's whisper AI Transcribes any audio file given using OpenAI's whisper AI
and pyannote for speaker detection. and pyannote for speaker detection.
""" """
print(f"Token: {token}") runpod.api_key = token
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
} }
options: Config = Config( options: Config = Config(
endpoint=endpoint, token=token, headers=headers, console=Console() endpoint=f"{RUNPOD_API_URL}{pod_id}",
pod_id=pod_id,
token=token,
headers=headers,
console=Console(),
) )
ctx.obj = options ctx.obj = options
@cli.command() @cli.command()
@click.pass_obj @click.pass_obj
@click.option("--language", "-l", help="Language to use for transcription in 2-letter ISO code (`de` or `fr`). Defaults to `en`.")
@click.argument( @click.argument(
"audiofile", "audiofile",
type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path), type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path),
) )
def start(config: Config, audiofile: Path) -> None: def start(config: Config, language: str, audiofile: Path) -> None:
"""Start processing the given audiofile. """Start processing the given audiofile.
Queues a job in the processing queue of the AI api. Queues a job in the processing queue of the AI api.
""" """
endpoint_new_job = f"{config.endpoint}/run" with config.console.status("[bold green]Uploading data to 0x0.st...\n"):
with config.console.status("[bold green]Uploading data..."):
url = _upload_to_oxo(audiofile) url = _upload_to_oxo(audiofile)
input_data = {"input": {"url": url}} input_data = {"url": url, "lang": language}
config.console.log(f"[green]Requesting new job for[/green] {audiofile}...") config.console.print(
response = requests.post(endpoint_new_job, json=input_data, headers=config.headers) f"[green]Requesting new job for[/green] {audiofile} over {url}..."
job_id = response.json()["id"] )
config.console.log(f"[green]Job[/green] {job_id} [green]has been queued.[/green]") job = start_job(config, input_data)
print_job_status(config, job_id) config.console.print(
f"[green]Job[/green] {job.job_id} [green]has been queued.[/green]"
)
print_job_status(config, job)
@cli.command() @cli.command()
@ -72,7 +83,8 @@ def health(config: Config) -> None:
@click.pass_obj @click.pass_obj
@click.argument("job_id") @click.argument("job_id")
def job(config: Config, job_id: str) -> None: def job(config: Config, job_id: str) -> None:
print_job_status(config, job_id) job = Job(config.pod_id, job_id)
print_job_status(config, job, once=True)
def cancel(config: Config, job_id: str) -> None: def cancel(config: Config, job_id: str) -> None:

View file

@ -4,6 +4,7 @@ from rich.console import Console
@dataclass @dataclass
class Config: class Config:
pod_id: str
endpoint: str endpoint: str
token: str token: str
headers: dict[str, str] headers: dict[str, str]

View file

@ -1,10 +1,14 @@
import time import time
import requests
from datetime import timedelta from datetime import timedelta
from math import floor from math import floor
from rich.table import Table from typing import Any
from rich.live import Live
import requests
import runpod
from configuration import Config from configuration import Config
from rich.live import Live
from rich.table import Table
from runpod.endpoint import Job
STATUS_MAPPING = { STATUS_MAPPING = {
"IN_QUEUE": "[yellow]queued[/yellow]", "IN_QUEUE": "[yellow]queued[/yellow]",
@ -15,26 +19,36 @@ STATUS_MAPPING = {
} }
def print_job_status(config: Config, job_id: str, once:bool = False) -> None: def start_job(config: Config, input: Any) -> Job:
endpoint = runpod.Endpoint(config.pod_id)
return endpoint.run(input)
def print_job_status(config: Config, job: Job, once: bool = False) -> None:
job_id = job.job_id
result = _request_job_state(config, job_id) result = _request_job_state(config, job_id)
if not result: if not result:
return return
def result_to_values(result: dict) -> dict[str, str]: def result_to_values(result: dict) -> dict[str, str]:
output = result.get("output", {})
transcription:str = output.get("transcription_url", "...")
diarization = output.get("diarization_url", "...")
return { return {
"status": STATUS_MAPPING[result["status"]], "status": STATUS_MAPPING[result["status"]],
"transcription": result.get("transcription_url", "..."), "transcription": transcription.removeprefix(r"b'").removesuffix(r"\n'"),
"diarization": result.get("diarization_url", "..."), "diarization": diarization.removeprefix(r"b'").removesuffix(r"\n'"),
} }
values: dict[str, str] = result_to_values(result) values: dict[str, str] = result_to_values(result)
def rebuild_table(): def rebuild_table():
table = Table() table = Table()
table.add_column("Status") table.add_column("Status")
table.add_column("Time running") table.add_column("Time running")
table.add_column("Job ID") table.add_column("Job ID", overflow='fold')
table.add_column("Diarization") table.add_column("Diarization", overflow='fold')
table.add_column("Transcription") table.add_column("Transcription", overflow='fold')
table.add_row( table.add_row(
values.get("status", "unknown"), values.get("status", "unknown"),
str(sw_current), str(sw_current),
@ -47,15 +61,21 @@ def print_job_status(config: Config, job_id: str, once:bool = False) -> None:
sw_start: float = time.time() sw_start: float = time.time()
sw_current: timedelta = timedelta() sw_current: timedelta = timedelta()
with Live(get_renderable=rebuild_table, refresh_per_second=2): with Live(get_renderable=rebuild_table, refresh_per_second=2):
while not once: while True:
if once:
break
if result["status"] != "IN_QUEUE" and result["status"] != "IN_PROGRESS":
once = True
result = _request_job_state(config, job_id, silent=True) result = _request_job_state(config, job_id, silent=True)
values: dict[str, str] = result_to_values(result) values = result_to_values(result)
sw_current = timedelta(seconds=floor(time.time() - sw_start)) sw_current = timedelta(seconds=floor(time.time() - sw_start))
if result["status"] != "IN_QUEUE" and result["status"] != "IN_PROGRESS":
break
time.sleep(1) time.sleep(1)
if "transcription" in values:
config.console.print(f"[green]Transcript:[/green] {values['transcription']}")
def _request_job_state(config: Config, id: str, silent: bool = False) -> dict: def _request_job_state(config: Config, id: str, silent: bool = False) -> dict:
endpoint_health = f"{config.endpoint}/status/{id}" endpoint_health = f"{config.endpoint}/status/{id}"
@ -73,4 +93,3 @@ def _request_job_state(config: Config, id: str, silent: bool = False) -> dict:
if not response.ok: if not response.ok:
raise requests.exceptions.HTTPError() raise requests.exceptions.HTTPError()
return response.json() return response.json()

View file

@ -11,6 +11,7 @@ def _upload_to_oxo(file: Path, url: str = "https://0x0.st", expires: int = 2) ->
) )
if not resp.ok: if not resp.ok:
raise requests.exceptions.HTTPError() raise requests.exceptions.HTTPError()
console.log(f"Uploaded file {file} to {str(resp.content)}") url = resp.content.decode().strip()
return str(resp.content).strip() console.log(f"Uploaded file {file} to {url}")
return url