From cb4b633c050975c26597dd0e76147c8f8be4a794 Mon Sep 17 00:00:00 2001
From: Marty Oehme
Date: Sun, 20 Aug 2023 14:29:36 +0200
Subject: [PATCH] Create modelling processes

---
 verbanote/loaders.py    |  29 +++++++++--
 verbanote/process.py    | 110 ++++++++++++++++++++++++++++++++++++++++
 verbanote/rp_handler.py |  33 ++++++++++--
 3 files changed, 163 insertions(+), 9 deletions(-)
 create mode 100644 verbanote/process.py

diff --git a/verbanote/loaders.py b/verbanote/loaders.py
index b6ad740..db8dd7d 100644
--- a/verbanote/loaders.py
+++ b/verbanote/loaders.py
@@ -1,5 +1,6 @@
 import locale
 from pathlib import Path
+import subprocess
 from whisper import Whisper
 from pyannote.audio import Pipeline
 import torch
@@ -12,17 +13,37 @@ def prep() -> None:
     # download and add ffmpeg to env
     static_ffmpeg.add_paths()
 
-def audiofile(drive_url: str, path: str) -> Path | None:
+
+def audiofile(drive_url: str, path: Path) -> Path | None:
     if not drive_url:
         return None
-    fn = Path.joinpath(Path(path), "interview")
-    gdown.download(drive_url, str(fn))
+    infile = gdown.download(drive_url, "infile")
+    fn = Path.joinpath(path, "interview.wav")
+    subprocess.run(
+        [
+            "ffmpeg",
+            "-i",
+            infile,
+            "-vn",
+            "-acodec",
+            "pcm_s16le",
+            "-ar",
+            "16000",
+            "-ac",
+            "1",
+            "-y",
+            str(fn),
+        ]
+    )
     return fn
 
+
 def diarization(access_token: str | None) -> Pipeline:
-    return Pipeline.from_pretrained(
+    pipeline = Pipeline.from_pretrained(
         "pyannote/speaker-diarization", use_auth_token=access_token
     )
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    return pipeline.to(device)
 
 
 def whisper() -> Whisper:
diff --git a/verbanote/process.py b/verbanote/process.py
new file mode 100644
index 0000000..c1534ef
--- /dev/null
+++ b/verbanote/process.py
@@ -0,0 +1,110 @@
+import os
+import re
+import json
+from pathlib import Path
+from pyannote.audio import Pipeline
+from pydub import AudioSegment
+from whisper import Whisper
+
+MILLISECONDS_TO_SPACE = 2000
+
+
+def diarize(audiofile: Path, pipeline: Pipeline, output_path: Path) -> Path:
+    audiofile_prepended = _add_audio_silence(audiofile)
+
+    DEMO_FILE = {"uri": "blabla", "audio": audiofile_prepended}
+    dz = pipeline(DEMO_FILE)
+
+    out_file = Path.joinpath(output_path, "diarization.txt")
+    with open(out_file, "w") as text_file:
+        text_file.write(str(dz))
+
+    print("Diarized:")
+    print(*list(dz.itertracks(yield_label=True))[:10], sep="\n")
+
+    return out_file
+
+
+def transcribe(
+    model: Whisper,
+    diarized_groups: list,
+    output_path: Path,
+    lang: str = "en",
+    word_timestamps: bool = True,
+):
+    for i in range(len(diarized_groups)):
+        f = Path.joinpath(output_path, str(i))
+        audio_f = f"{f}.wav"
+        json_f = f"{f}.json"
+        result = model.transcribe(
+            audio=audio_f, language=lang, word_timestamps=word_timestamps
+        )
+        with open(json_f, "w") as outfile:
+            json.dump(result, outfile, indent=4)
+
+
+def save_diarized_audio_files(
+    diarization: Path, audiofile: Path, output_path: Path
+) -> list:
+    groups = _group_speakers(diarization)
+    _save_individual_audio_files(audiofile, groups, output_path)
+    return groups
+
+
+def _add_audio_silence(audiofile) -> Path:
+    spacermilli = MILLISECONDS_TO_SPACE
+    spacer = AudioSegment.silent(duration=spacermilli)
+    audio = AudioSegment.from_wav(audiofile)
+    audio = spacer.append(audio, crossfade=0)
+    out_file = Path.joinpath(Path(os.path.dirname(audiofile)), "interview_prepend.wav")
+    audio.export(out_file, format="wav")
+
+    return out_file
+
+
+def _save_individual_audio_files(
+    audiofile: Path, groups: list[str], output_path: Path
+) -> None:
+    audio = AudioSegment.from_wav(audiofile)
+    gidx = -1
+    for g in groups:
+        start = re.findall(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+", string=g[0])[0]
+        end = re.findall(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+", string=g[-1])[1]
+        start = _millisec(start)  # - spacermilli
+        end = _millisec(end)  # - spacermilli
+        gidx += 1
+        audio[start:end].export(
+            f"{Path.joinpath(output_path, str(gidx))}.wav", format="wav"
+        )
+
+
+def _group_speakers(diarization_file: Path) -> list:
+    dzs = open(diarization_file).read().splitlines()
+
+    groups: list = []
+    g = []
+    lastend = 0
+
+    for d in dzs:
+        if g and (g[0].split()[-1] != d.split()[-1]):  # speaker changed
+            groups.append(g)
+            g = []
+
+        g.append(d)
+
+        end = re.findall(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+", string=d)[1]
+        end = _millisec(end)
+        if lastend > end:  # segment engulfed by a previous segment
+            groups.append(g)
+            g = []
+        else:
+            lastend = end
+    if g:
+        groups.append(g)
+    return groups
+
+
+def _millisec(timeStr):
+    spl = timeStr.split(":")
+    s = int((int(spl[0]) * 60 * 60 + int(spl[1]) * 60 + float(spl[2])) * 1000)
+    return s
diff --git a/verbanote/rp_handler.py b/verbanote/rp_handler.py
index 6cd93fe..ad42b2e 100644
--- a/verbanote/rp_handler.py
+++ b/verbanote/rp_handler.py
@@ -2,12 +2,15 @@ from pathlib import Path
 import runpod
 from runpod.serverless import os
 import loaders
+import process
+
+
+output_path = os.environ.get("VERBANOTE_OUTPUT_PATH", "/transcriptions")
+output_path = Path(output_path)
+input_path = os.environ.get("VERBANOTE_INPUT_PATH", "/audiofiles")
+input_path = Path(input_path)
 
 access_token = os.environ.get("VERBANOTE_HF_TOKEN")
-output_path = os.environ.get("VERBANOTE_OUTPUT_PATH", "/transcriptions")
-output_path = str(Path(output_path))
-input_path = os.environ.get("VERBANOTE_INPUT_PATH", "/audiofiles")
-input_path = str(Path(input_path))
 
 loaders.prep()
 diarize_pipeline = loaders.diarization(access_token)
@@ -16,16 +19,36 @@ whisper_model = loaders.whisper()
 
 def handler(job):
     input = job["input"]
-    audiofile = loaders.audiofile(input.get("file"), path = input_path)
+    audiofile = loaders.audiofile(input.get("file"), path=input_path)
     if not audiofile:
         return {"error": "missing audio file location"}
 
+    diarized = process.diarize(audiofile, diarize_pipeline, output_path)
+    diarized_groups = process.save_diarized_audio_files(
+        diarized, audiofile, output_path
+    )
+    process.transcribe(
+        model=whisper_model, diarized_groups=diarized_groups, output_path=output_path
+    )
+
     return {
         "speaker_timings": "s3-address-to-speakers",
         "transcription_text": "s3-address-to-transcription",
         "transcription_page": "web-address-to-deployment",
     }
 
+
+# speakers = {
+#     # speaker, textboxcolor, speaker color
+#     "SPEAKER_00": ("SPEAKER00", "white", "darkgreen"),
+#     "SPEAKER_01": ("SPEAKER01", "white", "darkorange"),
+#     "SPEAKER_02": ("SPEAKER02", "white", "darkred"),
+#     "SPEAKER_03": ("SPEAKER03", "white", "darkblue"),
+#     "SPEAKER_04": ("SPEAKER04", "white", "darkyellow"),
+#     "SPEAKER_05": ("SPEAKER05", "white", "lightgreen"),
+#     "SPEAKER_06": ("SPEAKER06", "white", "lightred"),
+#     "SPEAKER_07": ("SPEAKER07", "white", "lightblue"),
+# }
+
 
 if __name__ == "__main__":
     runpod.serverless.start({"handler": handler})
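
Usage note: a minimal local smoke test for the wired-up pipeline could look
like the sketch below. The script name and the Drive URL are illustrative
assumptions, not part of this patch; only the job["input"]["file"] shape
comes from handler() itself.

    # test_local.py (hypothetical helper, placed next to rp_handler.py)
    # importing rp_handler runs loaders.prep() and loads both models up front
    import rp_handler

    # placeholder Drive URL; handler() only reads job["input"]["file"]
    job = {"input": {"file": "https://drive.google.com/uc?id=<file-id>"}}
    print(rp_handler.handler(job))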