verbanote-client/verbanote_client/main.py

197 lines
6.1 KiB
Python

import time
import requests
import logging
from datetime import timedelta
from math import floor
from pathlib import Path
from dataclasses import dataclass
import click
from rich.console import Console
from rich.table import Table
from rich.live import Live
import quick
# TODO turn all this into config style options or @click-style flags/options
# Configure the root logger once at import time; INFO and above are emitted.
logging.basicConfig(level=logging.INFO)
# Shared rich console used by the commands below for status spinners,
# log lines and JSON output.
console = Console()
@dataclass
class Config:
    """Connection settings shared by all subcommands.

    An instance is created by the ``cli`` group callback and stored on the
    click context (``ctx.obj``), from where the subcommands receive it.
    """

    # Base URL of the endpoint; subcommands append paths such as "/run",
    # "/health" or "/status/<id>" to it.
    endpoint: str
    # Access token as given on the command line.
    token: str
    # HTTP headers sent with every API request (content type + bearer auth).
    headers: dict[str, str]
@quick.gui_option()
@click.group()
@click.pass_context
@click.option("--endpoint", "-e", help="URL of runpod serverless endpoint to use.")
@click.option("--token", "-t", help="Access token for runpod instance.")
# TODO @click.version_option()
def cli(ctx: click.Context, endpoint, token) -> None:
    """Verbanote

    Transcribes any audio file given using OpenAI's whisper AI
    and pyannote for speaker detection.
    """
    # NOTE: the docstring above doubles as the user-facing ``--help`` text,
    # so developer notes live in comments instead.
    #
    # The access token is a credential and must never be echoed to the
    # terminal (it would end up in scrollback, logs and screen shares).
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }
    # Hand the connection settings to the subcommands, which pick them up
    # through ``@click.pass_obj``.
    ctx.obj = Config(endpoint=endpoint, token=token, headers=headers)
@cli.command()
@click.pass_obj
@click.argument(
    "audiofile",
    type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path),
)
def start(config: Config, audiofile: Path) -> None:
    """Start processing the given audiofile.

    Queues a job in the processing queue of the AI api.
    """
    endpoint_new_job = f"{config.endpoint}/run"

    # The audio file is first pushed to a public file host; the API only
    # receives the resulting download URL.
    with console.status("[bold green]Uploading data..."):
        url = _upload_to_oxo(audiofile)

    input_data = {"input": {"url": url}}
    console.log(f"[green]Requesting new job for[/green] {audiofile}...")
    response = requests.post(endpoint_new_job, json=input_data, headers=config.headers)
    # Fail with a descriptive HTTPError (status code + URL) instead of an
    # opaque KeyError/JSON error when the endpoint rejects the request.
    response.raise_for_status()

    job_id = response.json()["id"]
    console.log(f"[green]Job[/green] {job_id} [green]has been queued.[/green]")

    # Keep following the freshly queued job until it reaches a final state.
    print_job_status(config, job_id)
@cli.command()
@click.pass_obj
def health(config: Config) -> None:
    """Print the health report returned by the endpoint as JSON."""
    health_url = f"{config.endpoint}/health"
    # Show a spinner while the request is in flight.
    with console.status("[bold green]Requesting health status..."):
        reply = requests.get(health_url, headers=config.headers)
    # Pretty-print whatever JSON document the endpoint answered with.
    console.print_json(data=reply.json())
@cli.command()
@click.pass_obj
@click.argument("job_id")
def job(config: Config, job_id: str) -> None:
    """Show the status of the job identified by JOB_ID."""
    # Thin CLI wrapper: all requesting and rendering happens in
    # print_job_status.
    print_job_status(config, job_id)
def cancel(config: Config, job_id: str) -> None:
    """Placeholder for cancelling a job; not implemented yet.

    The body is empty and the function is not registered as a ``cli``
    command, so it is currently unreachable from the command line.
    """
    ...
# Maps the job status strings reported by the API to the rich-markup label
# shown in the "Status" column of the live job table.
STATUS_MAPPING = {
    # Job is still being worked on (polling continues for these two).
    "IN_QUEUE": "[yellow]queued[/yellow]",
    "IN_PROGRESS": "[blue]running[/blue]",
    # Final states.
    "CANCELLED": "[orange1]cancelled[/orange1]",
    "COMPLETED": "[green]complete[/green]",
    "FAILED": "[red]failed[/red]",
}
def print_job_status(config: Config, job_id: str) -> None:
    """Render a live table for a job and poll until it leaves the queue.

    The job state is requested once up front (with a spinner); if the job
    is unknown to the endpoint nothing is rendered. Otherwise the state is
    re-requested roughly once per second and shown in a live-updating table
    until the status is neither ``IN_QUEUE`` nor ``IN_PROGRESS``.

    Args:
        config: Connection settings (endpoint and request headers).
        job_id: Identifier of the job to follow.

    Raises:
        requests.exceptions.HTTPError: Propagated from the state request
            when the endpoint answers with an error other than 404.
    """
    result = _request_job_state(config, job_id)
    if not result:
        # Job not found; the request helper has already logged that.
        return

    # Statuses for which it is worth polling again.
    active_states = ("IN_QUEUE", "IN_PROGRESS")

    # State shared with the table builder below (read through the closure).
    values: dict[str, str] = {}
    sw_start: float = time.time()
    sw_current: timedelta = timedelta()

    def rebuild_table() -> Table:
        # Called by Live on every refresh; builds the table from the most
        # recently polled values.
        table = Table()
        table.add_column("Status")
        table.add_column("Time running")
        table.add_column("Job ID")
        table.add_column("Diarization")
        table.add_column("Transcription")
        table.add_row(
            values.get("status", "unknown"),
            str(sw_current),
            job_id,
            values.get("diarization", "..."),
            values.get("transcription", "..."),
        )
        return table

    with Live(get_renderable=rebuild_table, refresh_per_second=1):
        while True:
            result = _request_job_state(config, job_id, silent=True)
            # Whole seconds elapsed since this function started watching.
            sw_current = timedelta(seconds=floor(time.time() - sw_start))

            if not result:
                # The job disappeared (404) while polling; stop instead of
                # crashing on the missing "status" key.
                break

            status = result.get("status", "")
            values = {
                # Fall back to the raw status text for states the mapping
                # does not know instead of raising KeyError.
                "status": STATUS_MAPPING.get(status, status or "unknown"),
                "transcription": result.get("transcription_url", "..."),
                "diarization": result.get("diarization_url", "..."),
            }

            if status not in active_states:
                break
            time.sleep(1)
def _request_job_state(config: Config, id: str, silent: bool = False) -> dict:
    """Fetch the current state of a job from the endpoint.

    Args:
        config: Connection settings (endpoint and request headers).
        id: Identifier of the job to look up.
        silent: When true, perform the request without showing a console
            spinner (used while a live table is already being displayed).

    Returns:
        The decoded JSON body of the status response, or an empty dict
        when the endpoint answers with 404 (job not found).

    Raises:
        requests.exceptions.HTTPError: For any other unsuccessful response.
    """
    endpoint_status = f"{config.endpoint}/status/{id}"

    if silent:
        response = requests.get(endpoint_status, headers=config.headers)
    else:
        with console.status(
            f"[bold green]Requesting job[/bold green] {id}"
            " [bold green]status...[/bold green]"
        ):
            response = requests.get(endpoint_status, headers=config.headers)

    if response.status_code == 404:
        console.log(f"[red]Job[/red] {id} [red]not found on endpoint.[/red]")
        return {}

    # Raises HTTPError for the remaining 4xx/5xx answers, carrying the status
    # code, URL and response object so the failure can be diagnosed.
    response.raise_for_status()
    return response.json()
# TODO switch server component to be able to use S3 storage options
def _upload_to_oxo(file: Path, url: str = "https://0x0.st", expires: int = 2) -> str:
    """Upload a file to a 0x0.st-style file host and return its public URL.

    Args:
        file: Path of the local file to upload.
        url: Upload endpoint of the file host.
        expires: Expiry value passed to the host as the ``expires`` form
            field. NOTE(review): 0x0.st documents this as hours — confirm
            that jobs are picked up by the server within that window.

    Returns:
        The URL under which the uploaded file can be downloaded.

    Raises:
        requests.exceptions.HTTPError: If the host rejects the upload.
    """
    # Close the file handle deterministically once the request is done.
    with open(file, "rb") as payload:
        resp = requests.post(
            url=url,
            files={"file": payload},
            # Plain form field; inside ``files`` it would be sent as a
            # second file part and not be seen as a form value by the host.
            data={"expires": str(expires)},
        )
    resp.raise_for_status()

    # The host answers with the URL as plain text followed by a newline.
    # Decode it properly: str() on the raw bytes would yield "b'...\\n'".
    uploaded_url = resp.text.strip()
    console.log(f"Uploaded file {file} to {uploaded_url}")
    return uploaded_url
# def main(args: list[str]) -> None:
# if args[1] == "status":
# elif args[1] == "cancel":
# if len(args) <= 2:
# logging.error("No job id to cancel supplied.")
# sys.exit(1)
# logging.info(f"requesting job {args[2]} cancellation...")
# response = requests.get(f"{status_endpoint}/{args[2]}", headers=headers)
# elif args[1] == "purge":
# logging.info("purging all jobs in queue...")
# response = requests.post(purge_endpoint, headers=headers)
#
# # the json will be similar to
# # {'id': 'e3d2e250-ea81-4074-9838-1c52d006ddcf', 'status': 'IN_QUEUE'}
#
# while "status" in json and (
# json["status"] == "IN_QUEUE" or json["status"] == "IN_PROGRESS"
# ):
# logging.info(f"{json['status']} for job {json['id']}, waiting...")
# time.sleep(3)
# response = requests.get(f"{status_endpoint}/{json['id']}", headers=headers)
# json = response.json()
#
# logging.info(json)
if __name__ == "__main__":
    # Run the click group when executed as a script. The envvar prefix is
    # handed to click so options can also come from VERBANOTE_*-style
    # environment variables (see click's auto_envvar_prefix documentation).
    cli(auto_envvar_prefix="VERBANOTE")