diff --git a/verbanote/loaders.py b/verbanote/loaders.py new file mode 100644 index 0000000..b6ad740 --- /dev/null +++ b/verbanote/loaders.py @@ -0,0 +1,31 @@ +import locale +from pathlib import Path +from whisper import Whisper +from pyannote.audio import Pipeline +import torch +import static_ffmpeg +import gdown + + +def prep() -> None: + locale.getpreferredencoding = lambda: "UTF-8" + # download and add ffmpeg to env + static_ffmpeg.add_paths() + +def audiofile(drive_url: str, path: str) -> Path | None: + if not drive_url: + return None + fn = Path.joinpath(Path(path), "interview") + gdown.download(drive_url, str(fn)) + return fn + +def diarization(access_token: str | None) -> Pipeline: + return Pipeline.from_pretrained( + "pyannote/speaker-diarization", use_auth_token=access_token + ) + + +def whisper() -> Whisper: + # LOAD MODEL INTO VRAM + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + return whisper.load_model("large", device=device) diff --git a/verbanote/rp_handler.py b/verbanote/rp_handler.py index a72f565..6cd93fe 100644 --- a/verbanote/rp_handler.py +++ b/verbanote/rp_handler.py @@ -1,17 +1,31 @@ +from pathlib import Path import runpod +from runpod.serverless import os +import loaders -def is_even(job): - job_input = job["input"] - num = job_input["number"] +access_token = os.environ.get("VERBANOTE_HF_TOKEN") +output_path = os.environ.get("VERBANOTE_OUTPUT_PATH", "/transcriptions") +output_path = str(Path(output_path)) +input_path = os.environ.get("VERBANOTE_INPUT_PATH", "/audiofiles") +input_path = str(Path(input_path)) - if not isinstance(num, int): - return {"error": "Integer required."} +loaders.prep() +diarize_pipeline = loaders.diarization(access_token) +whisper_model = loaders.whisper() - if num % 2 == 0: - return True - return False +def handler(job): + input = job["input"] + audiofile = loaders.audiofile(input.get("file"), path = input_path) + if not audiofile: + return {"error": "missing audio file location"} + + return { + "speaker_timings": "s3-address-to-speakers", + "transcription_text": "s3-address-to-transcription", + "transcription_page": "web-address-to-deployment", + } if __name__ == "__main__": - runpod.serverless.start({"handler": is_even}) + runpod.serverless.start({"handler": handler})