95 lines
3.1 KiB
Python
95 lines
3.1 KiB
Python
import time
|
|
from datetime import timedelta
|
|
from math import floor
|
|
from typing import Any
|
|
|
|
import requests
|
|
import runpod
|
|
from configuration import Config
|
|
from rich.live import Live
|
|
from rich.table import Table
|
|
from runpod.endpoint import Job
|
|
|
|
# Rich-markup label displayed in the status table for each job status string
# found in the endpoint's status response (``result["status"]``).
STATUS_MAPPING: dict[str, str] = dict(
    IN_QUEUE="[yellow]queued[/yellow]",
    IN_PROGRESS="[blue]running[/blue]",
    CANCELLED="[orange1]cancelled[/orange1]",
    COMPLETED="[green]complete[/green]",
    FAILED="[red]failed[/red]",
)
|
|
|
|
|
|
def start_job(config: Config, input: Any) -> Job:
    """Submit *input* as a job to the RunPod endpoint identified by ``config.pod_id``.

    Args:
        config: Supplies ``pod_id``, the identifier handed to ``runpod.Endpoint``.
        input: Request payload, forwarded unchanged to ``Endpoint.run``.

    Returns:
        The ``Job`` handle produced by ``Endpoint.run``.
    """
    return runpod.Endpoint(config.pod_id).run(input)
|
|
|
|
|
|
# Job states in which the remote job may still change; any other status is
# treated as final and ends polling.
_ACTIVE_STATUSES = ("IN_QUEUE", "IN_PROGRESS")


def print_job_status(config: Config, job: Job, once: bool = False) -> None:
    """Show a live status table for *job*, then print its transcript URL if any.

    The job state is fetched once up front; if the endpoint reports the job as
    not found (empty result) nothing is displayed. Unless *once* is True, the
    endpoint is then polled roughly once a second while the job is queued or
    running, and the table is refreshed from each poll.

    Args:
        config: Provides the endpoint URL, request headers and console.
        job: Job handle whose ``job_id`` is watched.
        once: Render a single snapshot instead of polling until the job ends.
    """
    job_id = job.job_id
    result = _request_job_state(config, job_id)
    if not result:
        return

    values = _result_to_values(result)
    started = time.time()
    elapsed = timedelta()

    # The lambda looks up ``values`` and ``elapsed`` when Live renders, so each
    # refresh shows whatever the most recent poll stored (late binding is intended).
    with Live(
        get_renderable=lambda: _build_status_table(job_id, values, elapsed),
        refresh_per_second=2,
    ):
        while not once and result.get("status") in _ACTIVE_STATUSES:
            time.sleep(1)
            latest = _request_job_state(config, job_id, silent=True)
            if not latest:
                # Job vanished (404) mid-watch; keep the last known state on screen
                # instead of crashing on the empty payload.
                break
            result = latest
            values = _result_to_values(result)
            elapsed = timedelta(seconds=floor(time.time() - started))

    # ``_result_to_values`` only sets this key when the output actually holds a
    # transcription URL, so failed/cancelled jobs print nothing here.
    if "transcription" in values:
        config.console.print(f"[green]Transcript:[/green] {values['transcription']}")


def _result_to_values(result: dict) -> dict[str, str]:
    """Turn a raw job-status payload into the strings shown in the status table.

    ``"status"`` is always present: the Rich label from ``STATUS_MAPPING``, or the
    raw status text when the mapping has no entry (rather than raising KeyError).
    ``"transcription"`` and ``"diarization"`` are included only when the job
    output contains the matching ``*_url`` string, so callers can tell
    "not available yet" apart from a real value.
    """
    output = result.get("output")
    if not isinstance(output, dict):
        # The payload may omit ``output`` or carry null / a non-dict value;
        # treat all of those as "no URLs yet" instead of failing on ``.get``.
        output = {}

    status = str(result.get("status", "unknown"))
    values = {"status": STATUS_MAPPING.get(status, status)}

    url_fields = (
        ("transcription", "transcription_url"),
        ("diarization", "diarization_url"),
    )
    for key, field in url_fields:
        url = output.get(field)
        if isinstance(url, str):
            values[key] = _strip_bytes_repr(url)
    return values


def _strip_bytes_repr(text: str) -> str:
    """Drop a leading ``b'`` and a trailing backslash-n-quote from *text*.

    NOTE(review): this looks like it undoes ``str(some_bytes)`` applied on the
    worker side to a URL ending in a newline; confirm against the handler.
    """
    return text.removeprefix("b'").removesuffix(r"\n'")


def _build_status_table(job_id: str, values: dict[str, str], elapsed: timedelta) -> Table:
    """Build the one-row Rich table displayed by ``print_job_status``.

    Args:
        job_id: Identifier shown in the "Job ID" column.
        values: Display strings from ``_result_to_values``; missing keys fall
            back to placeholders.
        elapsed: Time since ``print_job_status`` started watching (not since
            the job was submitted).
    """
    table = Table()
    table.add_column("Status")
    table.add_column("Time running")
    table.add_column("Job ID", overflow="fold")
    table.add_column("Diarization", overflow="fold")
    table.add_column("Transcription", overflow="fold")
    table.add_row(
        values.get("status", "unknown"),
        str(elapsed),
        job_id,
        values.get("diarization", "..."),
        values.get("transcription", "..."),
    )
    return table
|
|
|
|
|
|
def _request_job_state(
    config: Config, id: str, silent: bool = False, timeout: float = 30.0
) -> dict:
    """Fetch the current state of job *id* from the endpoint's ``/status`` route.

    Args:
        config: Supplies ``endpoint`` (base URL), ``headers`` and ``console``.
        id: Job identifier to look up.
        silent: When False, show a console status spinner during the request.
        timeout: Passed to ``requests.get``; seconds to wait for the connection
            and for each read before ``requests.exceptions.Timeout`` is raised.
            Without it a stalled connection would block the caller forever.

    Returns:
        The decoded JSON body, or ``{}`` when the endpoint answers 404 (the
        miss is logged to the console).

    Raises:
        requests.exceptions.HTTPError: for any other 4xx/5xx response; the
            exception carries the status code, URL and response object.
    """
    url = f"{config.endpoint}/status/{id}"

    def fetch() -> requests.Response:
        return requests.get(url, headers=config.headers, timeout=timeout)

    if silent:
        response = fetch()
    else:
        with config.console.status(
            f"[bold green]Requesting job[/bold green] {id}"
            " [bold green]status...[/bold green]"
        ):
            response = fetch()

    if response.status_code == 404:
        config.console.log(f"[red]Job[/red] {id} [red]not found on endpoint.[/red]")
        return {}

    # Raises exactly when ``response.ok`` is False, but with a useful message
    # instead of an empty HTTPError().
    response.raise_for_status()
    return response.json()
|