Create simple loaders infrastructure

Loaders for models, necessary ffmpeg binaries and input files.
This commit is contained in:
Marty Oehme 2023-08-20 13:29:13 +02:00
parent 540128bc97
commit 7e91b7a1a2
Signed by: Marty
GPG key ID: EDBF2ED917B2EF6A
2 changed files with 54 additions and 9 deletions

31
verbanote/loaders.py Normal file
View file

@ -0,0 +1,31 @@
import locale
from pathlib import Path
from whisper import Whisper
from pyannote.audio import Pipeline
import torch
import static_ffmpeg
import gdown
def prep() -> None:
locale.getpreferredencoding = lambda: "UTF-8"
# download and add ffmpeg to env
static_ffmpeg.add_paths()
def audiofile(drive_url: str, path: str) -> Path | None:
if not drive_url:
return None
fn = Path.joinpath(Path(path), "interview")
gdown.download(drive_url, str(fn))
return fn
def diarization(access_token: str | None) -> Pipeline:
return Pipeline.from_pretrained(
"pyannote/speaker-diarization", use_auth_token=access_token
)
def whisper() -> Whisper:
# LOAD MODEL INTO VRAM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
return whisper.load_model("large", device=device)

View file

@ -1,17 +1,31 @@
from pathlib import Path
import runpod
from runpod.serverless import os
import loaders
def is_even(job):
job_input = job["input"]
num = job_input["number"]
access_token = os.environ.get("VERBANOTE_HF_TOKEN")
output_path = os.environ.get("VERBANOTE_OUTPUT_PATH", "/transcriptions")
output_path = str(Path(output_path))
input_path = os.environ.get("VERBANOTE_INPUT_PATH", "/audiofiles")
input_path = str(Path(input_path))
if not isinstance(num, int):
return {"error": "Integer required."}
loaders.prep()
diarize_pipeline = loaders.diarization(access_token)
whisper_model = loaders.whisper()
if num % 2 == 0:
return True
return False
def handler(job):
input = job["input"]
audiofile = loaders.audiofile(input.get("file"), path = input_path)
if not audiofile:
return {"error": "missing audio file location"}
return {
"speaker_timings": "s3-address-to-speakers",
"transcription_text": "s3-address-to-transcription",
"transcription_page": "web-address-to-deployment",
}
if __name__ == "__main__":
runpod.serverless.start({"handler": is_even})
runpod.serverless.start({"handler": handler})