Add rich display and quick gui

Marty Oehme 2023-08-25 22:48:19 +02:00
parent c2710d180b
commit 4441fe3d46
Signed by: Marty
GPG key ID: EDBF2ED917B2EF6A
3 changed files with 331 additions and 64 deletions


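The substance of the change: the hard-coded RunPod constants and sys.argv handling give way to a click group that builds a Config object once and shares it with every subcommand through the context, while quick wraps the same group in a GUI. A minimal sketch of that pattern, assuming only the decorators seen in the diff; the demo group, its options, and the show command are hypothetical names:

# Sketch of the context-object pattern used in the diff below.
# `demo` and `show` are hypothetical; only the decorator stack
# mirrors the real `cli` group.
from dataclasses import dataclass

import click
import quick  # third-party; wraps the click CLI in a GUI


@dataclass
class Config:
    endpoint: str
    token: str


@quick.gui_option()  # as used in the diff; exposes a GUI launcher option
@click.group()
@click.pass_context
@click.option("--endpoint", "-e", help="Service URL.")
@click.option("--token", "-t", help="Access token.")
def demo(ctx, endpoint, token):
    # Whatever lands on ctx.obj is handed to every subcommand
    # decorated with @click.pass_obj.
    ctx.obj = Config(endpoint=endpoint, token=token)


@demo.command()
@click.pass_obj
def show(config: Config) -> None:
    click.echo(config.endpoint)


if __name__ == "__main__":
    demo(auto_envvar_prefix="DEMO")

With auto_envvar_prefix set (the real entry point uses "VERBANOTE"), click should also read the group options from environment variables such as DEMO_ENDPOINT and DEMO_TOKEN.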
@@ -1,67 +1,160 @@
+import time
 import requests
 import logging
-import time
-import sys
-import click
+from datetime import timedelta
+from math import floor
 from pathlib import Path
-args = sys.argv
+from dataclasses import dataclass
+
+import click
+from rich.console import Console
+from rich.table import Table
+from rich.live import Live
+
+import quick
-# TODO turn all this into config style options or @click-style flags/options
-logging.basicConfig(level=logging.INFO)
-pod_id = "7x0b7u16s6vyrc"
-bearer_token = "EIWX9RO18PRXCD0RUSY26MSD062GUF6REQGGV6QB"
-api = f"https://api.runpod.ai/v2/{pod_id}"
-run_endpoint = f"{api}/run"
-status_endpoint = f"{api}/status"
-health_endpoint = f"{api}/health"
-purge_endpoint = f"{api}/purge-queue"
-headers = {
-    "Content-Type": "application/json",
-    "Authorization": f"Bearer {bearer_token}",
-}
+console = Console()
+@dataclass
+class Config:
+    endpoint: str
+    token: str
+    headers: dict[str, str]
+
+
+@quick.gui_option()
 @click.group()
-@click.option("--endpoint", "-e", help="URL of runpod serverless endpoint.")
+@click.pass_context
+@click.option("--endpoint", "-e", help="URL of runpod serverless endpoint to use.")
 @click.option("--token", "-t", help="Access token for runpod instance.")
 # TODO @click.version_option()
-def cli(token):
+def cli(ctx, endpoint, token):
     """Verbanote
     Transcribes any audio file given using OpenAI's whisper AI
     and pyannote for speaker detection.
     """
-    print(f"Token: {token}")
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {token}",
+    }
+    options: Config = Config(endpoint=endpoint, token=token, headers=headers)
+    ctx.obj = options
 @cli.command()
+@click.pass_obj
 @click.argument(
     "audiofile",
     type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path),
 )
-def start(audiofile):
+def start(config: Config, audiofile: Path) -> None:
     """Start processing the given audiofile.
     Queues a job in the processing queue of the AI api.
     """
-    url = _upload_to_oxo(audiofile)
+    endpoint_new_job = f"{config.endpoint}/run"
+    with console.status("[bold green]Uploading data..."):
+        url = _upload_to_oxo(audiofile)
     input_data = {"input": {"url": url}}
-    logging.info(f"Requesting new job for {audiofile}...")
-    response = requests.post(run_endpoint, json=input_data, headers=headers)
-    click.echo(f"Job {response} has been queued.")
+    console.log(f"[green]Requesting new job for[/green] {audiofile}...")
+    response = requests.post(endpoint_new_job, json=input_data, headers=config.headers)
+    job_id = response.json()["id"]
+    console.log(f"[green]Job[/green] {job_id} [green]has been queued.[/green]")
+    print_job_status(config, job_id)
 @cli.command()
-def health():
-    logging.info("requesting health status...")
-    resp = requests.get(health_endpoint, headers=headers)
-    click.echo(resp)
+@click.pass_obj
+def health(config: Config) -> None:
+    endpoint_health = f"{config.endpoint}/health"
+    with console.status("[bold green]Requesting health status..."):
+        resp = requests.get(endpoint_health, headers=config.headers)
+    json = resp.json()
+    console.print_json(data=json)
+@cli.command()
+@click.pass_obj
+@click.argument("job_id")
+def job(config: Config, job_id: str) -> None:
+    print_job_status(config, job_id)
+
+
+def cancel(config: Config, job_id: str) -> None:
+    ...
+STATUS_MAPPING = {
+    "IN_QUEUE": "[yellow]queued[/yellow]",
+    "IN_PROGRESS": "[blue]running[/blue]",
+    "CANCELLED": "[orange1]cancelled[/orange1]",
+    "COMPLETED": "[green]complete[/green]",
+    "FAILED": "[red]failed[/red]",
+}
+
+
+def print_job_status(config: Config, job_id: str) -> None:
+    result = _request_job_state(config, job_id)
+    if not result:
+        return
+
+    values: dict[str, str] = {}
+    sw_start: float = time.time()
+    sw_current: timedelta = timedelta()
+
+    def rebuild_table():
+        table = Table()
+        table.add_column("Status")
+        table.add_column("Time running")
+        table.add_column("Job ID")
+        table.add_column("Diarization")
+        table.add_column("Transcription")
+        table.add_row(
+            values.get("status", "unknown"),
+            str(sw_current),
+            job_id,
+            values.get("diarization", "..."),
+            values.get("transcription", "..."),
+        )
+        return table
+
+    with Live(get_renderable=rebuild_table, refresh_per_second=1):
+        while True:
+            result = _request_job_state(config, job_id, silent=True)
+            sw_current = timedelta(seconds=floor(time.time() - sw_start))
+            values: dict[str, str] = {
+                "status": STATUS_MAPPING[result["status"]],
+                "transcription": result.get("transcription_url", "..."),
+                "diarization": result.get("diarization_url", "..."),
+            }
+            if result["status"] != "IN_QUEUE" and result["status"] != "IN_PROGRESS":
+                break
+            time.sleep(1)
+def _request_job_state(config: Config, id: str, silent: bool = False) -> dict:
+    endpoint_health = f"{config.endpoint}/status/{id}"
+    if silent:
+        response = requests.get(endpoint_health, headers=config.headers)
+    else:
+        with console.status(
+            f"[bold green]Requesting job[/bold green] {id}"
+            " [bold green]status...[/bold green]"
+        ):
+            response = requests.get(endpoint_health, headers=config.headers)
+    if response.status_code == 404:
+        console.log(f"[red]Job[/red] {id} [red]not found on endpoint.[/red]")
+        return {}
+    if not response.ok:
+        raise requests.exceptions.HTTPError()
+    return response.json()
 # TODO switch server component to be able to use S3 storage options
 def _upload_to_oxo(file: Path, url: str = "https://0x0.st", expires: int = 2) -> str:
     resp = requests.post(
         url=url,
@@ -69,42 +162,35 @@ def _upload_to_oxo(file: Path, url: str = "https://0x0.st", expires: int = 2) ->
     )
     if not resp.ok:
         raise requests.exceptions.HTTPError()
-    logging.info(f"Uploaded file {file} to {str(resp.content)}")
+    console.log(f"Uploaded file {file} to {str(resp.content)}")
     return str(resp.content)
-def main(args: list[str]) -> None:
-    if args[1] == "status":
-        if len(args) <= 2:
-            logging.error("No job id to get status from supplied.")
-            sys.exit(1)
-        logging.info(f"requesting job {args[2]} status...")
-        response = requests.get(f"{status_endpoint}/{args[2]}", headers=headers)
-    elif args[1] == "cancel":
-        if len(args) <= 2:
-            logging.error("No job id to cancel supplied.")
-            sys.exit(1)
-        logging.info(f"requesting job {args[2]} cancellation...")
-        response = requests.get(f"{status_endpoint}/{args[2]}", headers=headers)
-    elif args[1] == "purge":
-        logging.info("purging all jobs in queue...")
-        response = requests.post(purge_endpoint, headers=headers)
-    json = response.json()
-    # the json will be similar to
-    # {'id': 'e3d2e250-ea81-4074-9838-1c52d006ddcf', 'status': 'IN_QUEUE'}
-    while "status" in json and (
-        json["status"] == "IN_QUEUE" or json["status"] == "IN_PROGRESS"
-    ):
-        logging.info(f"{json['status']} for job {json['id']}, waiting...")
-        time.sleep(3)
-        response = requests.get(f"{status_endpoint}/{json['id']}", headers=headers)
-        json = response.json()
-    logging.info(json)
+# def main(args: list[str]) -> None:
+#     if args[1] == "status":
+#     elif args[1] == "cancel":
+#         if len(args) <= 2:
+#             logging.error("No job id to cancel supplied.")
+#             sys.exit(1)
+#         logging.info(f"requesting job {args[2]} cancellation...")
+#         response = requests.get(f"{status_endpoint}/{args[2]}", headers=headers)
+#     elif args[1] == "purge":
+#         logging.info("purging all jobs in queue...")
+#         response = requests.post(purge_endpoint, headers=headers)
+#
+#     # the json will be similar to
+#     # {'id': 'e3d2e250-ea81-4074-9838-1c52d006ddcf', 'status': 'IN_QUEUE'}
+#
+#     while "status" in json and (
+#         json["status"] == "IN_QUEUE" or json["status"] == "IN_PROGRESS"
+#     ):
+#         logging.info(f"{json['status']} for job {json['id']}, waiting...")
+#         time.sleep(3)
+#         response = requests.get(f"{status_endpoint}/{json['id']}", headers=headers)
+#         json = response.json()
+#
+#     logging.info(json)
 if __name__ == "__main__":
-    cli(auto_envvar_prefix='VERBANOTE')
+    cli(auto_envvar_prefix="VERBANOTE")
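The other notable addition is the live status table. rich's Live accepts a get_renderable callback and invokes it on every refresh, so print_job_status only has to rebind the values dict that the table builder closes over. A standalone sketch of the same pattern, with a fake poller in place of the HTTP status request:

# Standalone sketch of the Live/get_renderable pattern from
# print_job_status above; _poll fakes the job status request.
import time

from rich.live import Live
from rich.table import Table

_states = iter(["queued", "running", "running", "complete"])
status = "unknown"


def _poll() -> str:
    return next(_states, "complete")


def rebuild_table() -> Table:
    # Built from scratch on each refresh with the latest status.
    table = Table()
    table.add_column("Status")
    table.add_row(status)
    return table


with Live(get_renderable=rebuild_table, refresh_per_second=4):
    while status != "complete":
        status = _poll()
        time.sleep(0.5)

Table rows can only be appended, not edited in place, so rebuilding the whole table on each tick is the usual workaround for a table whose cells change.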