Update to make basic functionality work
This commit is contained in:
parent
193f6b6f0c
commit
d3c79b8bf9
4 changed files with 67 additions and 34 deletions
|
@ -4,58 +4,69 @@ import requests
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
import runpod
|
||||||
|
from runpod.endpoint import Job
|
||||||
from network_functions import _upload_to_oxo
|
from network_functions import _upload_to_oxo
|
||||||
from job_functions import print_job_status
|
from job_functions import print_job_status, start_job
|
||||||
from configuration import Config
|
from configuration import Config
|
||||||
|
|
||||||
# TODO turn all this into config style options or @click-style flags/options
|
# TODO turn all this into config style options or @click-style flags/options
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
RUNPOD_API_URL = "https://api.runpod.ai/v2/"
|
||||||
|
|
||||||
|
|
||||||
@quick.gui_option()
|
@quick.gui_option()
|
||||||
@click.group()
|
@click.group()
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@click.option("--endpoint", "-e", help="URL of runpod serverless endpoint to use.")
|
@click.option("--pod-id", "-i", help="ID of runpod serverless endpoint to use.")
|
||||||
@click.option("--token", "-t", help="Access token for runpod instance.")
|
@click.option("--token", "-t", help="Access token for runpod instance.")
|
||||||
# TODO @click.version_option()
|
# TODO @click.version_option()
|
||||||
def cli(ctx, endpoint, token):
|
def cli(ctx, pod_id, token):
|
||||||
"""Verbanote
|
"""Verbanote
|
||||||
|
|
||||||
Transcribes any audio file given using OpenAI's whisper AI
|
Transcribes any audio file given using OpenAI's whisper AI
|
||||||
and pyannote for speaker detection.
|
and pyannote for speaker detection.
|
||||||
"""
|
"""
|
||||||
print(f"Token: {token}")
|
runpod.api_key = token
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {token}",
|
"Authorization": f"Bearer {token}",
|
||||||
}
|
}
|
||||||
options: Config = Config(
|
options: Config = Config(
|
||||||
endpoint=endpoint, token=token, headers=headers, console=Console()
|
endpoint=f"{RUNPOD_API_URL}{pod_id}",
|
||||||
|
pod_id=pod_id,
|
||||||
|
token=token,
|
||||||
|
headers=headers,
|
||||||
|
console=Console(),
|
||||||
)
|
)
|
||||||
ctx.obj = options
|
ctx.obj = options
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
|
@click.option("--language", "-l", help="Language to use for transcription in 2-letter ISO code (`de` or `fr`). Defaults to `en`.")
|
||||||
@click.argument(
|
@click.argument(
|
||||||
"audiofile",
|
"audiofile",
|
||||||
type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path),
|
type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path),
|
||||||
)
|
)
|
||||||
def start(config: Config, audiofile: Path) -> None:
|
def start(config: Config, language: str, audiofile: Path) -> None:
|
||||||
"""Start processing the given audiofile.
|
"""Start processing the given audiofile.
|
||||||
|
|
||||||
Queues a job in the processing queue of the AI api.
|
Queues a job in the processing queue of the AI api.
|
||||||
"""
|
"""
|
||||||
endpoint_new_job = f"{config.endpoint}/run"
|
with config.console.status("[bold green]Uploading data to 0x0.st...\n"):
|
||||||
with config.console.status("[bold green]Uploading data..."):
|
|
||||||
url = _upload_to_oxo(audiofile)
|
url = _upload_to_oxo(audiofile)
|
||||||
|
|
||||||
input_data = {"input": {"url": url}}
|
input_data = {"url": url, "lang": language}
|
||||||
config.console.log(f"[green]Requesting new job for[/green] {audiofile}...")
|
config.console.print(
|
||||||
response = requests.post(endpoint_new_job, json=input_data, headers=config.headers)
|
f"[green]Requesting new job for[/green] {audiofile} over {url}..."
|
||||||
job_id = response.json()["id"]
|
)
|
||||||
config.console.log(f"[green]Job[/green] {job_id} [green]has been queued.[/green]")
|
job = start_job(config, input_data)
|
||||||
print_job_status(config, job_id)
|
config.console.print(
|
||||||
|
f"[green]Job[/green] {job.job_id} [green]has been queued.[/green]"
|
||||||
|
)
|
||||||
|
print_job_status(config, job)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
@ -72,7 +83,8 @@ def health(config: Config) -> None:
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
@click.argument("job_id")
|
@click.argument("job_id")
|
||||||
def job(config: Config, job_id: str) -> None:
|
def job(config: Config, job_id: str) -> None:
|
||||||
print_job_status(config, job_id)
|
job = Job(config.pod_id, job_id)
|
||||||
|
print_job_status(config, job, once=True)
|
||||||
|
|
||||||
|
|
||||||
def cancel(config: Config, job_id: str) -> None:
|
def cancel(config: Config, job_id: str) -> None:
|
||||||
|
|
|
@ -4,6 +4,7 @@ from rich.console import Console
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
|
pod_id: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
token: str
|
token: str
|
||||||
headers: dict[str, str]
|
headers: dict[str, str]
|
||||||
|
|
|
@ -1,10 +1,14 @@
|
||||||
import time
|
import time
|
||||||
import requests
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from math import floor
|
from math import floor
|
||||||
from rich.table import Table
|
from typing import Any
|
||||||
from rich.live import Live
|
|
||||||
|
import requests
|
||||||
|
import runpod
|
||||||
from configuration import Config
|
from configuration import Config
|
||||||
|
from rich.live import Live
|
||||||
|
from rich.table import Table
|
||||||
|
from runpod.endpoint import Job
|
||||||
|
|
||||||
STATUS_MAPPING = {
|
STATUS_MAPPING = {
|
||||||
"IN_QUEUE": "[yellow]queued[/yellow]",
|
"IN_QUEUE": "[yellow]queued[/yellow]",
|
||||||
|
@ -15,26 +19,36 @@ STATUS_MAPPING = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def print_job_status(config: Config, job_id: str, once:bool = False) -> None:
|
def start_job(config: Config, input: Any) -> Job:
|
||||||
|
endpoint = runpod.Endpoint(config.pod_id)
|
||||||
|
return endpoint.run(input)
|
||||||
|
|
||||||
|
|
||||||
|
def print_job_status(config: Config, job: Job, once: bool = False) -> None:
|
||||||
|
job_id = job.job_id
|
||||||
result = _request_job_state(config, job_id)
|
result = _request_job_state(config, job_id)
|
||||||
if not result:
|
if not result:
|
||||||
return
|
return
|
||||||
|
|
||||||
def result_to_values(result: dict) -> dict[str, str]:
|
def result_to_values(result: dict) -> dict[str, str]:
|
||||||
|
output = result.get("output", {})
|
||||||
|
transcription:str = output.get("transcription_url", "...")
|
||||||
|
diarization = output.get("diarization_url", "...")
|
||||||
return {
|
return {
|
||||||
"status": STATUS_MAPPING[result["status"]],
|
"status": STATUS_MAPPING[result["status"]],
|
||||||
"transcription": result.get("transcription_url", "..."),
|
"transcription": transcription.removeprefix(r"b'").removesuffix(r"\n'"),
|
||||||
"diarization": result.get("diarization_url", "..."),
|
"diarization": diarization.removeprefix(r"b'").removesuffix(r"\n'"),
|
||||||
}
|
}
|
||||||
|
|
||||||
values: dict[str, str] = result_to_values(result)
|
values: dict[str, str] = result_to_values(result)
|
||||||
|
|
||||||
def rebuild_table():
|
def rebuild_table():
|
||||||
table = Table()
|
table = Table()
|
||||||
table.add_column("Status")
|
table.add_column("Status")
|
||||||
table.add_column("Time running")
|
table.add_column("Time running")
|
||||||
table.add_column("Job ID")
|
table.add_column("Job ID", overflow='fold')
|
||||||
table.add_column("Diarization")
|
table.add_column("Diarization", overflow='fold')
|
||||||
table.add_column("Transcription")
|
table.add_column("Transcription", overflow='fold')
|
||||||
table.add_row(
|
table.add_row(
|
||||||
values.get("status", "unknown"),
|
values.get("status", "unknown"),
|
||||||
str(sw_current),
|
str(sw_current),
|
||||||
|
@ -47,15 +61,21 @@ def print_job_status(config: Config, job_id: str, once:bool = False) -> None:
|
||||||
sw_start: float = time.time()
|
sw_start: float = time.time()
|
||||||
sw_current: timedelta = timedelta()
|
sw_current: timedelta = timedelta()
|
||||||
with Live(get_renderable=rebuild_table, refresh_per_second=2):
|
with Live(get_renderable=rebuild_table, refresh_per_second=2):
|
||||||
while not once:
|
while True:
|
||||||
|
if once:
|
||||||
|
break
|
||||||
|
if result["status"] != "IN_QUEUE" and result["status"] != "IN_PROGRESS":
|
||||||
|
once = True
|
||||||
|
|
||||||
result = _request_job_state(config, job_id, silent=True)
|
result = _request_job_state(config, job_id, silent=True)
|
||||||
values: dict[str, str] = result_to_values(result)
|
values = result_to_values(result)
|
||||||
sw_current = timedelta(seconds=floor(time.time() - sw_start))
|
sw_current = timedelta(seconds=floor(time.time() - sw_start))
|
||||||
|
|
||||||
if result["status"] != "IN_QUEUE" and result["status"] != "IN_PROGRESS":
|
|
||||||
break
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
|
if "transcription" in values:
|
||||||
|
config.console.print(f"[green]Transcript:[/green] {values['transcription']}")
|
||||||
|
|
||||||
|
|
||||||
def _request_job_state(config: Config, id: str, silent: bool = False) -> dict:
|
def _request_job_state(config: Config, id: str, silent: bool = False) -> dict:
|
||||||
endpoint_health = f"{config.endpoint}/status/{id}"
|
endpoint_health = f"{config.endpoint}/status/{id}"
|
||||||
|
@ -73,4 +93,3 @@ def _request_job_state(config: Config, id: str, silent: bool = False) -> dict:
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
raise requests.exceptions.HTTPError()
|
raise requests.exceptions.HTTPError()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ def _upload_to_oxo(file: Path, url: str = "https://0x0.st", expires: int = 2) ->
|
||||||
)
|
)
|
||||||
if not resp.ok:
|
if not resp.ok:
|
||||||
raise requests.exceptions.HTTPError()
|
raise requests.exceptions.HTTPError()
|
||||||
console.log(f"Uploaded file {file} to {str(resp.content)}")
|
url = resp.content.decode().strip()
|
||||||
return str(resp.content).strip()
|
console.log(f"Uploaded file {file} to {url}")
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue