verbanote-client/verbanote_client/job_functions.py

96 lines
3.1 KiB
Python

import time
from datetime import timedelta
from math import floor
from typing import Any
import requests
import runpod
from configuration import Config
from rich.live import Live
from rich.table import Table
from runpod.endpoint import Job
# Rich-markup label for each job status string returned by the status
# endpoint; used to render the "Status" column of the live table.
STATUS_MAPPING = dict(
    IN_QUEUE="[yellow]queued[/yellow]",
    IN_PROGRESS="[blue]running[/blue]",
    CANCELLED="[orange1]cancelled[/orange1]",
    COMPLETED="[green]complete[/green]",
    FAILED="[red]failed[/red]",
)
def start_job(config: Config, input: Any) -> Job:
    """Submit *input* as a new job on the endpoint named by ``config.pod_id``.

    Returns the job handle produced by ``runpod.Endpoint.run``.
    """
    return runpod.Endpoint(config.pod_id).run(input)
def _clean_output_value(value: object) -> str:
    """Return *value* as display text, or ``"..."`` when it is missing.

    The worker's URL fields appear to arrive as the string form of a bytes
    object, so the wrapper around them is stripped before display.
    """
    if value is None:
        return "..."
    # r"\n'" is a literal backslash, "n", quote: the tail of str(b"...\n").
    return str(value).removeprefix(r"b'").removesuffix(r"\n'")


def _result_to_values(result: dict) -> dict[str, str]:
    """Convert a job-status response into the strings shown in the table.

    Statuses missing from ``STATUS_MAPPING`` are shown verbatim instead of
    raising ``KeyError``. A missing or non-dict ``output`` yields ``"..."``
    placeholders.
    """
    status = result.get("status", "unknown")
    output = result.get("output")
    if not isinstance(output, dict):
        output = {}
    return {
        "status": STATUS_MAPPING.get(status, status),
        "transcription": _clean_output_value(output.get("transcription_url")),
        "diarization": _clean_output_value(output.get("diarization_url")),
    }


def print_job_status(config: Config, job: Job, once: bool = False) -> None:
    """Show a live-updating status table for *job* until it stops running.

    Polls ``_request_job_state`` roughly once per second while the job is
    ``IN_QUEUE`` or ``IN_PROGRESS``. Afterwards it prints the transcription
    URL, but only if one is available.

    Args:
        config: Client configuration. It is passed to ``_request_job_state``,
            and its ``console`` is used for output.
        job: The job to watch; only ``job.job_id`` is read.
        once: If true, render the current state a single time without polling.
    """
    job_id = job.job_id
    result = _request_job_state(config, job_id)
    if not result:
        # _request_job_state returns {} when the endpoint reports 404.
        return

    values: dict[str, str] = _result_to_values(result)
    # Monotonic clock: elapsed time must not jump with wall-clock changes.
    sw_start: float = time.monotonic()
    sw_current: timedelta = timedelta()

    def rebuild_table() -> Table:
        # Invoked by Live on every refresh. It reads the latest `values` and
        # `sw_current` from the enclosing scope, which the loop below rebinds.
        table = Table()
        table.add_column("Status")
        table.add_column("Time running")
        table.add_column("Job ID", overflow='fold')
        table.add_column("Diarization", overflow='fold')
        table.add_column("Transcription", overflow='fold')
        table.add_row(
            values.get("status", "unknown"),
            str(sw_current),
            job_id,
            values.get("diarization", "..."),
            values.get("transcription", "..."),
        )
        return table

    with Live(get_renderable=rebuild_table, refresh_per_second=2):
        # Stop as soon as a non-active status is seen, avoiding an extra poll.
        while not once and result.get("status") in ("IN_QUEUE", "IN_PROGRESS"):
            result = _request_job_state(config, job_id, silent=True)
            if not result:
                # Job disappeared (404): keep showing the last known state.
                break
            values = _result_to_values(result)
            sw_current = timedelta(seconds=floor(time.monotonic() - sw_start))
            time.sleep(1)

    # Print only when an actual URL was produced, not the placeholder.
    transcription = values.get("transcription", "...")
    if transcription != "...":
        config.console.print(f"[green]Transcript:[/green] {transcription}")
def _request_job_state(
    config: Config, id: str, silent: bool = False, timeout: float = 30.0
) -> dict:
    """Fetch the status document for job *id* from ``{config.endpoint}/status/{id}``.

    Args:
        config: Client configuration providing ``endpoint``, ``headers`` and
            ``console``.
        id: Job identifier.
        silent: If true, skip the console spinner (e.g. while a rich Live
            display is already drawing).
        timeout: Seconds ``requests`` waits for the server before raising
            ``requests.exceptions.Timeout``. This prevents an indefinite hang.

    Returns:
        The decoded JSON body, or ``{}`` if the endpoint answers 404.

    Raises:
        requests.exceptions.HTTPError: For any other 4xx/5xx response.
    """
    url = f"{config.endpoint}/status/{id}"

    def fetch():
        return requests.get(url, headers=config.headers, timeout=timeout)

    if silent:
        response = fetch()
    else:
        with config.console.status(
            f"[bold green]Requesting job[/bold green] {id}"
            " [bold green]status...[/bold green]"
        ):
            response = fetch()

    if response.status_code == 404:
        config.console.log(f"[red]Job[/red] {id} [red]not found on endpoint.[/red]")
        return {}
    # raise_for_status raises the same HTTPError type as a bare HTTPError(),
    # but it carries the status code, reason and URL and attaches the response.
    response.raise_for_status()
    return response.json()