Create simple loaders infrastructure
Loaders for models, necessary ffmpeg binaries and input files.
This commit is contained in:
parent
540128bc97
commit
7e91b7a1a2
2 changed files with 54 additions and 9 deletions
31
verbanote/loaders.py
Normal file
31
verbanote/loaders.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
import locale
|
||||||
|
from pathlib import Path
|
||||||
|
from whisper import Whisper
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
import torch
|
||||||
|
import static_ffmpeg
|
||||||
|
import gdown
|
||||||
|
|
||||||
|
|
||||||
|
def prep() -> None:
|
||||||
|
locale.getpreferredencoding = lambda: "UTF-8"
|
||||||
|
# download and add ffmpeg to env
|
||||||
|
static_ffmpeg.add_paths()
|
||||||
|
|
||||||
|
def audiofile(drive_url: str, path: str) -> Path | None:
|
||||||
|
if not drive_url:
|
||||||
|
return None
|
||||||
|
fn = Path.joinpath(Path(path), "interview")
|
||||||
|
gdown.download(drive_url, str(fn))
|
||||||
|
return fn
|
||||||
|
|
||||||
|
def diarization(access_token: str | None) -> Pipeline:
|
||||||
|
return Pipeline.from_pretrained(
|
||||||
|
"pyannote/speaker-diarization", use_auth_token=access_token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def whisper() -> Whisper:
|
||||||
|
# LOAD MODEL INTO VRAM
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
return whisper.load_model("large", device=device)
|
|
@ -1,17 +1,31 @@
|
||||||
|
from pathlib import Path
|
||||||
import runpod
|
import runpod
|
||||||
|
from runpod.serverless import os
|
||||||
|
import loaders
|
||||||
|
|
||||||
def is_even(job):
|
access_token = os.environ.get("VERBANOTE_HF_TOKEN")
|
||||||
job_input = job["input"]
|
output_path = os.environ.get("VERBANOTE_OUTPUT_PATH", "/transcriptions")
|
||||||
num = job_input["number"]
|
output_path = str(Path(output_path))
|
||||||
|
input_path = os.environ.get("VERBANOTE_INPUT_PATH", "/audiofiles")
|
||||||
|
input_path = str(Path(input_path))
|
||||||
|
|
||||||
if not isinstance(num, int):
|
loaders.prep()
|
||||||
return {"error": "Integer required."}
|
diarize_pipeline = loaders.diarization(access_token)
|
||||||
|
whisper_model = loaders.whisper()
|
||||||
|
|
||||||
if num % 2 == 0:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
def handler(job):
|
||||||
|
input = job["input"]
|
||||||
|
audiofile = loaders.audiofile(input.get("file"), path = input_path)
|
||||||
|
if not audiofile:
|
||||||
|
return {"error": "missing audio file location"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"speaker_timings": "s3-address-to-speakers",
|
||||||
|
"transcription_text": "s3-address-to-transcription",
|
||||||
|
"transcription_page": "web-address-to-deployment",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
runpod.serverless.start({"handler": is_even})
|
runpod.serverless.start({"handler": handler})
|
||||||
|
|
Loading…
Reference in a new issue