verbanote-client/verbanote_client/main.py

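"""Verbanote command-line client.

Uploads an audio file, queues a transcription job on a runpod serverless
endpoint (OpenAI's Whisper for transcription, pyannote for speaker detection)
and follows the job status until it finishes.

Example (the endpoint and token can also be supplied through environment
variables with the ``VERBANOTE_`` prefix, see ``auto_envvar_prefix`` below):

    python main.py -e <endpoint-url> -t <token> start <audiofile>
    python main.py -e <endpoint-url> -t <token> job <job-id>
"""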
import time
import requests
import logging
from datetime import timedelta
from math import floor
from pathlib import Path
from dataclasses import dataclass
import click
from rich.console import Console
from rich.table import Table
from rich.live import Live
import quick
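# quick can expose the click CLI below as a GUI (wired up via the @quick.gui_option decorator).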
# TODO turn all this into config style options or @click-style flags/options
logging.basicConfig(level=logging.INFO)

console = Console()


@dataclass
class Config:
    """Options shared by all subcommands, set on the top-level CLI group."""

    endpoint: str
    token: str
    headers: dict[str, str]


@quick.gui_option()
@click.group()
@click.pass_context
@click.option("--endpoint", "-e", help="URL of runpod serverless endpoint to use.")
@click.option("--token", "-t", help="Access token for runpod instance.")
# TODO @click.version_option()
def cli(ctx, endpoint, token):
    """Verbanote

    Transcribes a given audio file with OpenAI's Whisper and uses pyannote
    for speaker detection.
    """
    print(f"Token: {token}")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }
    options: Config = Config(endpoint=endpoint, token=token, headers=headers)
    ctx.obj = options  # shared Config, handed to subcommands via @click.pass_obj


@cli.command()
@click.pass_obj
@click.argument(
    "audiofile",
    type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path),
)
def start(config: Config, audiofile: Path) -> None:
    """Start processing the given audiofile.

    Uploads the file, queues a new job on the transcription API and then
    follows the job status until it finishes.
    """
    endpoint_new_job = f"{config.endpoint}/run"
    with console.status("[bold green]Uploading data..."):
        url = _upload_to_oxo(audiofile)
    input_data = {"input": {"url": url}}
    console.log(f"[green]Requesting new job for[/green] {audiofile}...")
    response = requests.post(endpoint_new_job, json=input_data, headers=config.headers)
    job_id = response.json()["id"]
    console.log(f"[green]Job[/green] {job_id} [green]has been queued.[/green]")
    print_job_status(config, job_id)


@cli.command()
@click.pass_obj
def health(config: Config) -> None:
    """Print the health status of the runpod endpoint as JSON."""
    endpoint_health = f"{config.endpoint}/health"
    with console.status("[bold green]Requesting health status..."):
        resp = requests.get(endpoint_health, headers=config.headers)
    json = resp.json()
    console.print_json(data=json)


@cli.command()
@click.pass_obj
@click.argument("job_id")
def job(config: Config, job_id: str) -> None:
    """Show the live status of a previously queued job."""
    print_job_status(config, job_id)


def cancel(config: Config, job_id: str) -> None:
    """Cancel a queued or running job. Not implemented yet."""
    # TODO: implement cancellation and register it as a CLI command.
    ...


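# Maps runpod job states to rich-formatted labels for the status table.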
STATUS_MAPPING = {
"IN_QUEUE": "[yellow]queued[/yellow]",
"IN_PROGRESS": "[blue]running[/blue]",
"CANCELLED": "[orange1]cancelled[/orange1]",
"COMPLETED": "[green]complete[/green]",
"FAILED": "[red]failed[/red]",
}


def print_job_status(config: Config, job_id: str) -> None:
    """Poll the job until it finishes, rendering a live status table."""
    result = _request_job_state(config, job_id)
    if not result:
        return
    values: dict[str, str] = {}
    sw_start: float = time.time()
    sw_current: timedelta = timedelta()

    def rebuild_table():
        # Called by rich's Live display on every refresh; reads the latest
        # values captured from the status endpoint via the enclosing scope.
        table = Table()
        table.add_column("Status")
        table.add_column("Time running")
        table.add_column("Job ID")
        table.add_column("Diarization")
        table.add_column("Transcription")
        table.add_row(
            values.get("status", "unknown"),
            str(sw_current),
            job_id,
            values.get("diarization", "..."),
            values.get("transcription", "..."),
        )
        return table

    with Live(get_renderable=rebuild_table, refresh_per_second=1):
        while True:
            result = _request_job_state(config, job_id, silent=True)
            sw_current = timedelta(seconds=floor(time.time() - sw_start))
            values = {
                "status": STATUS_MAPPING[result["status"]],
                "transcription": result.get("transcription_url", "..."),
                "diarization": result.get("diarization_url", "..."),
            }
            if result["status"] not in ("IN_QUEUE", "IN_PROGRESS"):
                break
            time.sleep(1)


def _request_job_state(config: Config, job_id: str, silent: bool = False) -> dict:
    """Fetch the current state of a job from the endpoint's status route."""
    endpoint_status = f"{config.endpoint}/status/{job_id}"
    if silent:
        response = requests.get(endpoint_status, headers=config.headers)
    else:
        with console.status(
            f"[bold green]Requesting job[/bold green] {job_id}"
            " [bold green]status...[/bold green]"
        ):
            response = requests.get(endpoint_status, headers=config.headers)
    if response.status_code == 404:
        console.log(f"[red]Job[/red] {job_id} [red]not found on endpoint.[/red]")
        return {}
    response.raise_for_status()
    return response.json()


# TODO switch server component to be able to use S3 storage options
def _upload_to_oxo(file: Path, url: str = "https://0x0.st", expires: int = 2) -> str:
    """Upload the audio file to the 0x0.st file host and return its download URL."""
    with open(file, "rb") as audio:
        resp = requests.post(
            url=url,
            files={"file": audio, "expires": str(expires)},
        )
    resp.raise_for_status()
    link = resp.text.strip()
    console.log(f"Uploaded file {file} to {link}")
    return link


# def main(args: list[str]) -> None:
# if args[1] == "status":
# elif args[1] == "cancel":
# if len(args) <= 2:
# logging.error("No job id to cancel supplied.")
# sys.exit(1)
# logging.info(f"requesting job {args[2]} cancellation...")
# response = requests.get(f"{status_endpoint}/{args[2]}", headers=headers)
# elif args[1] == "purge":
# logging.info("purging all jobs in queue...")
# response = requests.post(purge_endpoint, headers=headers)
#
# # the json will be similar to
# # {'id': 'e3d2e250-ea81-4074-9838-1c52d006ddcf', 'status': 'IN_QUEUE'}
#
# while "status" in json and (
# json["status"] == "IN_QUEUE" or json["status"] == "IN_PROGRESS"
# ):
# logging.info(f"{json['status']} for job {json['id']}, waiting...")
# time.sleep(3)
# response = requests.get(f"{status_endpoint}/{json['id']}", headers=headers)
# json = response.json()
#
# logging.info(json)


if __name__ == "__main__":
    cli(auto_envvar_prefix="VERBANOTE")