Create modelling processes
This commit is contained in:
parent
6cf1da6ea2
commit
cb4b633c05
3 changed files with 163 additions and 9 deletions
|
@ -1,5 +1,6 @@
|
|||
import locale
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
from whisper import Whisper
|
||||
from pyannote.audio import Pipeline
|
||||
import torch
|
||||
|
@ -12,17 +13,37 @@ def prep() -> None:
|
|||
# download and add ffmpeg to env
|
||||
static_ffmpeg.add_paths()
|
||||
|
||||
def audiofile(drive_url: str, path: str) -> Path | None:
|
||||
|
||||
def audiofile(drive_url: str, path: Path) -> Path | None:
|
||||
if not drive_url:
|
||||
return None
|
||||
fn = Path.joinpath(Path(path), "interview")
|
||||
gdown.download(drive_url, str(fn))
|
||||
gdown.download(drive_url, "infile")
|
||||
fn = Path.joinpath(path, "interview.wav")
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-i",
|
||||
"{repr(video_path)}",
|
||||
"-vn",
|
||||
"-acodec",
|
||||
"pcm_s16le",
|
||||
"-ar",
|
||||
"16000",
|
||||
"-ac",
|
||||
"1",
|
||||
"-y",
|
||||
fn,
|
||||
]
|
||||
)
|
||||
return fn
|
||||
|
||||
|
||||
def diarization(access_token: str | None) -> Pipeline:
|
||||
return Pipeline.from_pretrained(
|
||||
pipeline = Pipeline.from_pretrained(
|
||||
"pyannote/speaker-diarization", use_auth_token=access_token
|
||||
)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return pipeline.to(device)
|
||||
|
||||
|
||||
def whisper() -> Whisper:
|
||||
|
|
110
verbanote/process.py
Normal file
110
verbanote/process.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
import os
|
||||
import re
|
||||
import json
|
||||
from pathlib import Path
|
||||
from pyannote.audio import Pipeline
|
||||
from pydub import AudioSegment
|
||||
from whisper import Whisper
|
||||
|
||||
MILLISECONDS_TO_SPACE = 2000
|
||||
|
||||
|
||||
def diarize(audiofile: Path, pipeline: Pipeline, output_path: Path) -> Path:
|
||||
audiofile_prepended = _add_audio_silence(audiofile)
|
||||
|
||||
DEMO_FILE = {"uri": "blabla", "audio": audiofile_prepended}
|
||||
dz = pipeline(DEMO_FILE)
|
||||
|
||||
out_file = Path.joinpath(output_path, "diarization.txt")
|
||||
with open(out_file, "w") as text_file:
|
||||
text_file.write(str(dz))
|
||||
|
||||
print("Diarized:")
|
||||
print(*list(dz.itertracks(yield_label=True))[:10], sep="\n")
|
||||
|
||||
return out_file
|
||||
|
||||
|
||||
def transcribe(
|
||||
model: Whisper,
|
||||
diarized_groups: list,
|
||||
output_path: Path,
|
||||
lang: str = "en",
|
||||
word_timestamps: bool = True,
|
||||
):
|
||||
for i in range(len(diarized_groups)):
|
||||
f = {Path.joinpath(output_path, str(i))}
|
||||
audio_f = f"{f}.wav"
|
||||
json_f = f"{f}.json"
|
||||
result = model.transcribe(
|
||||
audio=audio_f, language=lang, word_timestamps=word_timestamps
|
||||
)
|
||||
with open(json_f, "w") as outfile:
|
||||
json.dump(result, outfile, indent=4)
|
||||
|
||||
|
||||
def save_diarized_audio_files(
|
||||
diarization: Path, audiofile: Path, output_path: Path
|
||||
) -> list:
|
||||
groups = _group_speakers(diarization)
|
||||
_save_individual_audio_files(audiofile, groups, output_path)
|
||||
return groups
|
||||
|
||||
|
||||
def _add_audio_silence(audiofile) -> Path:
|
||||
spacermilli = MILLISECONDS_TO_SPACE
|
||||
spacer = AudioSegment.silent(duration=spacermilli)
|
||||
audio = AudioSegment.from_wav(audiofile)
|
||||
audio = spacer.append(audio, crossfade=0)
|
||||
out_file = Path.joinpath(Path(os.path.dirname(audiofile)), "interview_prepend.wav")
|
||||
audio.export(out_file, format="wav")
|
||||
|
||||
return out_file
|
||||
|
||||
|
||||
def _save_individual_audio_files(
|
||||
audiofile: Path, groups: list[str], output_path: Path
|
||||
) -> None:
|
||||
audio = AudioSegment.from_wav(audiofile)
|
||||
gidx = -1
|
||||
for g in groups:
|
||||
start = re.findall(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+", string=g[0])[0]
|
||||
end = re.findall(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+", string=g[-1])[1]
|
||||
start = _millisec(start) # - spacermilli
|
||||
end = _millisec(end) # - spacermilli
|
||||
gidx += 1
|
||||
audio[start:end].export(
|
||||
f"{Path.joinpath(output_path, str(gidx))}.wav", format="wav"
|
||||
)
|
||||
|
||||
|
||||
def _group_speakers(diarization_file: Path) -> list:
|
||||
dzs = open(diarization_file).read().splitlines()
|
||||
|
||||
groups: list = []
|
||||
g = []
|
||||
lastend = 0
|
||||
|
||||
for d in dzs:
|
||||
if g and (g[0].split()[-1] != d.split()[-1]): # same speaker
|
||||
groups.append(g)
|
||||
g = []
|
||||
|
||||
g.append(d)
|
||||
|
||||
end = re.findall(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+", string=d)[1]
|
||||
end = _millisec(end)
|
||||
if lastend > end: # segment engulfed by a previous segment
|
||||
groups.append(g)
|
||||
g = []
|
||||
else:
|
||||
lastend = end
|
||||
if g:
|
||||
groups.append(g)
|
||||
return groups
|
||||
|
||||
|
||||
def _millisec(timeStr):
|
||||
spl = timeStr.split(":")
|
||||
s = (int)((int(spl[0]) * 60 * 60 + int(spl[1]) * 60 + float(spl[2])) * 1000)
|
||||
return s
|
|
@ -2,12 +2,15 @@ from pathlib import Path
|
|||
import runpod
|
||||
from runpod.serverless import os
|
||||
import loaders
|
||||
import process
|
||||
|
||||
|
||||
output_path = os.environ.get("VERBANOTE_OUTPUT_PATH", "/transcriptions")
|
||||
output_path = Path(output_path)
|
||||
input_path = os.environ.get("VERBANOTE_INPUT_PATH", "/audiofiles")
|
||||
input_path = Path(input_path)
|
||||
|
||||
access_token = os.environ.get("VERBANOTE_HF_TOKEN")
|
||||
output_path = os.environ.get("VERBANOTE_OUTPUT_PATH", "/transcriptions")
|
||||
output_path = str(Path(output_path))
|
||||
input_path = os.environ.get("VERBANOTE_INPUT_PATH", "/audiofiles")
|
||||
input_path = str(Path(input_path))
|
||||
|
||||
loaders.prep()
|
||||
diarize_pipeline = loaders.diarization(access_token)
|
||||
|
@ -20,12 +23,32 @@ def handler(job):
|
|||
if not audiofile:
|
||||
return {"error": "missing audio file location"}
|
||||
|
||||
diarized = process.diarize(audiofile, diarize_pipeline, output_path)
|
||||
diarized_groups = process.save_diarized_audio_files(
|
||||
diarized, audiofile, output_path
|
||||
)
|
||||
process.transcribe(
|
||||
model=whisper_model, diarized_groups=diarized_groups, output_path=output_path
|
||||
)
|
||||
|
||||
return {
|
||||
"speaker_timings": "s3-address-to-speakers",
|
||||
"transcription_text": "s3-address-to-transcription",
|
||||
"transcription_page": "web-address-to-deployment",
|
||||
}
|
||||
|
||||
# speakers = {
|
||||
# # speaker, textboxcolor, speaker color
|
||||
# "SPEAKER_00": ("SPEAKER00", "white", "darkgreen"),
|
||||
# "SPEAKER_01": ("SPEAKER01", "white", "darkorange"),
|
||||
# "SPEAKER_02": ("SPEAKER02", "white", "darkred"),
|
||||
# "SPEAKER_03": ("SPEAKER03", "white", "darkblue"),
|
||||
# "SPEAKER_04": ("SPEAKER04", "white", "darkyellow"),
|
||||
# "SPEAKER_05": ("SPEAKER05", "white", "lightgreen"),
|
||||
# "SPEAKER_06": ("SPEAKER06", "white", "lightred"),
|
||||
# "SPEAKER_07": ("SPEAKER07", "white", "lightblue"),
|
||||
# }
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runpod.serverless.start({"handler": handler})
|
||||
|
|
Loading…
Reference in a new issue