Update to make basic functionality work

This commit is contained in:
Marty Oehme 2024-03-13 23:51:41 +01:00
parent 193f6b6f0c
commit d3c79b8bf9
Signed by: Marty
GPG key ID: EDBF2ED917B2EF6A
4 changed files with 67 additions and 34 deletions

View file

@ -4,58 +4,69 @@ import requests
import logging
from pathlib import Path
from rich.console import Console
import runpod
from runpod.endpoint import Job
from network_functions import _upload_to_oxo
from job_functions import print_job_status
from job_functions import print_job_status, start_job
from configuration import Config
# TODO turn all this into config style options or @click-style flags/options
logging.basicConfig(level=logging.INFO)
RUNPOD_API_URL = "https://api.runpod.ai/v2/"
@quick.gui_option()
@click.group()
@click.pass_context
@click.option("--endpoint", "-e", help="URL of runpod serverless endpoint to use.")
@click.option("--pod-id", "-i", help="ID of runpod serverless endpoint to use.")
@click.option("--token", "-t", help="Access token for runpod instance.")
# TODO @click.version_option()
def cli(ctx, endpoint, token):
def cli(ctx, pod_id, token):
"""Verbanote
Transcribes any audio file given using OpenAI's whisper AI
and pyannote for speaker detection.
"""
print(f"Token: {token}")
runpod.api_key = token
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {token}",
}
options: Config = Config(
endpoint=endpoint, token=token, headers=headers, console=Console()
endpoint=f"{RUNPOD_API_URL}{pod_id}",
pod_id=pod_id,
token=token,
headers=headers,
console=Console(),
)
ctx.obj = options
@cli.command()
@click.pass_obj
@click.option("--language", "-l", help="Language to use for transcription in 2-letter ISO code (`de` or `fr`). Defaults to `en`.")
@click.argument(
"audiofile",
type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path),
)
def start(config: Config, audiofile: Path) -> None:
def start(config: Config, language: str, audiofile: Path) -> None:
"""Start processing the given audiofile.
Queues a job in the processing queue of the AI api.
"""
endpoint_new_job = f"{config.endpoint}/run"
with config.console.status("[bold green]Uploading data..."):
with config.console.status("[bold green]Uploading data to 0x0.st...\n"):
url = _upload_to_oxo(audiofile)
input_data = {"input": {"url": url}}
config.console.log(f"[green]Requesting new job for[/green] {audiofile}...")
response = requests.post(endpoint_new_job, json=input_data, headers=config.headers)
job_id = response.json()["id"]
config.console.log(f"[green]Job[/green] {job_id} [green]has been queued.[/green]")
print_job_status(config, job_id)
input_data = {"url": url, "lang": language}
config.console.print(
f"[green]Requesting new job for[/green] {audiofile} over {url}..."
)
job = start_job(config, input_data)
config.console.print(
f"[green]Job[/green] {job.job_id} [green]has been queued.[/green]"
)
print_job_status(config, job)
@cli.command()
@ -72,7 +83,8 @@ def health(config: Config) -> None:
@click.pass_obj
@click.argument("job_id")
def job(config: Config, job_id: str) -> None:
print_job_status(config, job_id)
job = Job(config.pod_id, job_id)
print_job_status(config, job, once=True)
def cancel(config: Config, job_id: str) -> None:

View file

@ -4,6 +4,7 @@ from rich.console import Console
@dataclass
class Config:
pod_id: str
endpoint: str
token: str
headers: dict[str, str]

View file

@ -1,10 +1,14 @@
import time
import requests
from datetime import timedelta
from math import floor
from rich.table import Table
from rich.live import Live
from typing import Any
import requests
import runpod
from configuration import Config
from rich.live import Live
from rich.table import Table
from runpod.endpoint import Job
STATUS_MAPPING = {
"IN_QUEUE": "[yellow]queued[/yellow]",
@ -15,26 +19,36 @@ STATUS_MAPPING = {
}
def print_job_status(config: Config, job_id: str, once:bool = False) -> None:
def start_job(config: Config, input: Any) -> Job:
endpoint = runpod.Endpoint(config.pod_id)
return endpoint.run(input)
def print_job_status(config: Config, job: Job, once: bool = False) -> None:
job_id = job.job_id
result = _request_job_state(config, job_id)
if not result:
return
def result_to_values(result: dict) -> dict[str, str]:
output = result.get("output", {})
transcription:str = output.get("transcription_url", "...")
diarization = output.get("diarization_url", "...")
return {
"status": STATUS_MAPPING[result["status"]],
"transcription": result.get("transcription_url", "..."),
"diarization": result.get("diarization_url", "..."),
"transcription": transcription.removeprefix(r"b'").removesuffix(r"\n'"),
"diarization": diarization.removeprefix(r"b'").removesuffix(r"\n'"),
}
values: dict[str, str] = result_to_values(result)
def rebuild_table():
table = Table()
table.add_column("Status")
table.add_column("Time running")
table.add_column("Job ID")
table.add_column("Diarization")
table.add_column("Transcription")
table.add_column("Job ID", overflow='fold')
table.add_column("Diarization", overflow='fold')
table.add_column("Transcription", overflow='fold')
table.add_row(
values.get("status", "unknown"),
str(sw_current),
@ -47,15 +61,21 @@ def print_job_status(config: Config, job_id: str, once:bool = False) -> None:
sw_start: float = time.time()
sw_current: timedelta = timedelta()
with Live(get_renderable=rebuild_table, refresh_per_second=2):
while not once:
while True:
if once:
break
if result["status"] != "IN_QUEUE" and result["status"] != "IN_PROGRESS":
once = True
result = _request_job_state(config, job_id, silent=True)
values: dict[str, str] = result_to_values(result)
values = result_to_values(result)
sw_current = timedelta(seconds=floor(time.time() - sw_start))
if result["status"] != "IN_QUEUE" and result["status"] != "IN_PROGRESS":
break
time.sleep(1)
if "transcription" in values:
config.console.print(f"[green]Transcript:[/green] {values['transcription']}")
def _request_job_state(config: Config, id: str, silent: bool = False) -> dict:
endpoint_health = f"{config.endpoint}/status/{id}"
@ -73,4 +93,3 @@ def _request_job_state(config: Config, id: str, silent: bool = False) -> dict:
if not response.ok:
raise requests.exceptions.HTTPError()
return response.json()

View file

@ -11,6 +11,7 @@ def _upload_to_oxo(file: Path, url: str = "https://0x0.st", expires: int = 2) ->
)
if not resp.ok:
raise requests.exceptions.HTTPError()
console.log(f"Uploaded file {file} to {str(resp.content)}")
return str(resp.content).strip()
url = resp.content.decode().strip()
console.log(f"Uploaded file {file} to {url}")
return url