import time
from contextlib import nullcontext
from datetime import timedelta
from math import floor
from typing import Any, Optional

import requests
import runpod
from configuration import Config
from rich.live import Live
from rich.table import Table
from runpod.endpoint import Job

# Rich-markup labels shown in the "Status" column for each job status string
# returned by the endpoint's /status route.
STATUS_MAPPING = {
    "IN_QUEUE": "[yellow]queued[/yellow]",
    "IN_PROGRESS": "[blue]running[/blue]",
    "CANCELLED": "[orange1]cancelled[/orange1]",
    "COMPLETED": "[green]complete[/green]",
    "FAILED": "[red]failed[/red]",
    "TIMED_OUT": "[red]timed out[/red]",
}

# Statuses for which the job can still change, so polling should continue.
_ACTIVE_STATUSES = frozenset({"IN_QUEUE", "IN_PROGRESS"})

# Shown in the table while a URL is not (yet) available.
_PLACEHOLDER = "..."

# Seconds to wait for the status endpoint before giving up on one request.
_REQUEST_TIMEOUT = 30.0


def start_job(config: Config, input: Any) -> Job:
    """Submit *input* to the RunPod endpoint ``config.pod_id``.

    Returns the ``Job`` handle produced by ``runpod.Endpoint.run``.
    """
    endpoint = runpod.Endpoint(config.pod_id)
    return endpoint.run(input)


def _strip_bytes_repr(text: str) -> str:
    """Remove a leading ``b'`` and a trailing literal ``\\n'`` from *text*.

    NOTE(review): presumably the worker returns ``str()`` of a bytes value
    that ends in a newline; confirm against the worker implementation.
    """
    return text.removeprefix(r"b'").removesuffix(r"\n'")


def _output_url(result: dict, key: str) -> Optional[str]:
    """Return the cleaned URL stored under *key* in ``result["output"]``.

    Returns None when the output is not a dict or the value is not a string,
    so partial or error payloads do not crash the display.
    """
    output = result.get("output")
    if not isinstance(output, dict):
        return None
    url = output.get(key)
    if not isinstance(url, str):
        return None
    return _strip_bytes_repr(url)


def _result_to_values(result: dict) -> dict[str, str]:
    """Convert a /status payload into the strings shown in the status table."""
    status = result.get("status", "unknown")
    return {
        # Fall back to the raw status so unmapped statuses don't raise KeyError.
        "status": STATUS_MAPPING.get(status, status),
        "transcription": _output_url(result, "transcription_url") or _PLACEHOLDER,
        "diarization": _output_url(result, "diarization_url") or _PLACEHOLDER,
    }


def print_job_status(config: Config, job: Job, once: bool = False) -> None:
    """Show a live-updating status table for *job* until it leaves the queue/run state.

    Polls the endpoint's /status route about once per second while the job is
    queued or running. Afterwards it prints the transcript URL, but only if the
    job's output actually contains one.

    Args:
        config: Supplies the endpoint URL, request headers and console.
        job: The job to watch; only ``job.job_id`` is used.
        once: If True, render a single snapshot without polling.
    """
    job_id = job.job_id
    result = _request_job_state(config, job_id)
    if not result:
        return

    values = _result_to_values(result)
    watch_start = time.time()
    elapsed = timedelta()

    def rebuild_table() -> Table:
        # Called by Live on every refresh; it sees the latest `values` and
        # `elapsed` through the closure.
        table = Table()
        table.add_column("Status")
        table.add_column("Time running")
        table.add_column("Job ID", overflow="fold")
        table.add_column("Diarization", overflow="fold")
        table.add_column("Transcription", overflow="fold")
        table.add_row(
            values.get("status", "unknown"),
            str(elapsed),
            job_id,
            values.get("diarization", _PLACEHOLDER),
            values.get("transcription", _PLACEHOLDER),
        )
        return table

    with Live(get_renderable=rebuild_table, refresh_per_second=2):
        while not once and result.get("status") in _ACTIVE_STATUSES:
            time.sleep(1)
            latest = _request_job_state(config, job_id, silent=True)
            if not latest:
                # Job no longer found (404): keep the last known state on screen.
                break
            result = latest
            values = _result_to_values(result)
            elapsed = timedelta(seconds=floor(time.time() - watch_start))

    transcript = _output_url(result, "transcription_url")
    if transcript:
        config.console.print(f"[green]Transcript:[/green] {transcript}")


def _request_job_state(
    config: Config,
    id: str,
    silent: bool = False,
    timeout: float = _REQUEST_TIMEOUT,
) -> dict:
    """Fetch the status payload for job *id* from ``{config.endpoint}/status/{id}``.

    Args:
        config: Supplies ``endpoint``, ``headers`` and ``console``.
        id: Job identifier.
        silent: If True, skip the console spinner (used while polling inside Live).
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The decoded JSON body, or ``{}`` if the endpoint answered 404.

    Raises:
        requests.HTTPError: For any other 4xx/5xx response.
        requests.Timeout: If no response arrives within *timeout*.
    """
    status_url = f"{config.endpoint}/status/{id}"
    spinner = (
        nullcontext()
        if silent
        else config.console.status(
            f"[bold green]Requesting job[/bold green] {id}"
            " [bold green]status...[/bold green]"
        )
    )
    with spinner:
        response = requests.get(status_url, headers=config.headers, timeout=timeout)

    if response.status_code == 404:
        config.console.log(f"[red]Job[/red] {id} [red]not found on endpoint.[/red]")
        return {}
    # Same condition as `not response.ok`, but the raised HTTPError carries
    # the status code and URL.
    response.raise_for_status()
    return response.json()