Reinsert whisper model transcribing

Marty Oehme 2023-08-23 13:22:55 +02:00
parent 3246469be2
commit 30fd269cd9
Signed by: Marty
GPG key ID: EDBF2ED917B2EF6A
6 changed files with 483 additions and 177 deletions

View file

@@ -1,11 +1,12 @@
 import locale
 from pathlib import Path
-# from whisper import Whisper
+from whisper import Whisper, load_model
 from pyannote.audio import Pipeline
 import torch
 import static_ffmpeg
 import file_operations

 def prep() -> None:
     locale.getpreferredencoding = lambda: "UTF-8"
     # download and add ffmpeg to env
@@ -26,8 +27,8 @@ def diarization(access_token: str | None) -> Pipeline:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     return pipeline.to(device)

-#
-# def whisper() -> Whisper:
-#     # LOAD MODEL INTO VRAM
-#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-#     return whisper.load_model("large", device=device)
+
+def whispermodel() -> Whisper:
+    # LOAD MODEL INTO VRAM
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    return load_model("large", device=device)
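
For reference, a minimal usage sketch of the reinstated loader, assuming the openai-whisper API (load_model returns a Whisper instance whose transcribe() accepts a file path, language, and word_timestamps flag); the audio filename is a placeholder:

    import loaders

    model = loaders.whispermodel()  # puts the "large" weights on GPU when available
    result = model.transcribe(
        audio="segment.wav",  # hypothetical diarized segment file
        language="en",
        word_timestamps=True,
    )
    print(result["text"])  # plain transcription; result["segments"] holds timings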

View file

@@ -4,7 +4,7 @@ import json
 from pathlib import Path

 from pyannote.audio import Pipeline
 from pydub import AudioSegment
-# from whisper import Whisper
+from whisper import Whisper

 MILLISECONDS_TO_SPACE = 2000
@@ -25,22 +25,53 @@ def diarize(audiofile: Path, pipeline: Pipeline, output_path: Path) -> Path:
     return out_file

-# def transcribe(
-#     model: Whisper,
-#     diarized_groups: list,
-#     output_path: Path,
-#     lang: str = "en",
-#     word_timestamps: bool = True,
-# ):
-#     for i in range(len(diarized_groups)):
-#         f = {Path.joinpath(output_path, str(i))}
-#         audio_f = f"{f}.wav"
-#         json_f = f"{f}.json"
-#         result = model.transcribe(
-#             audio=audio_f, language=lang, word_timestamps=word_timestamps
-#         )
-#         with open(json_f, "w") as outfile:
-#             json.dump(result, outfile, indent=4)
+def transcribe(
+    model: Whisper,
+    diarized_groups: list,
+    output_path: Path,
+    lang: str = "en",
+    word_timestamps: bool = True,
+) -> None:
+    for i in range(len(diarized_groups)):
+        # numbered segment basename shared by the .wav input and .json output
+        f = Path.joinpath(output_path, str(i))
+        audio_f = f"{f}.wav"
+        json_f = f"{f}.json"
+        result = model.transcribe(
+            audio=audio_f, language=lang, word_timestamps=word_timestamps
+        )
+        with open(json_f, "w") as outfile:
+            json.dump(result, outfile, indent=4)
+
+# TODO clean up this mess
+def output_txt(diarized_groups: list, transcription_path: Path) -> str:
+    txt: list[str] = []
+    gidx = -1
+    for g in diarized_groups:
+        shift = re.findall(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+", string=g[0])[0]
+        shift = (
+            _millisec(shift) - MILLISECONDS_TO_SPACE
+        )  # the start time in the original video
+        shift = max(shift, 0)
+        gidx += 1
+        with open(f"{Path.joinpath(transcription_path, str(gidx))}.json") as f:
+            captions = json.load(f)["segments"]
+        if captions:
+            speaker = g[0].split()[-1]
+            for c in captions:
+                txt.append(f"[{speaker}] {c['text']}\n")
+            txt.append("\n")
+    output = "".join(txt)
+    with open(
+        Path.joinpath(transcription_path, "capspeaker.txt"), "w", encoding="utf-8"
+    ) as file:
+        file.write(output)
+    return output

 def save_diarized_audio_files(
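
A sketch of how the two new functions chain together, assuming (as this diff implies) that save_diarized_audio_files has already written numbered segment files 0.wav, 1.wav, ... into the output directory; transcribe() drops a matching 0.json, 1.json, ... beside them, and output_txt() merges their "segments" into one speaker-labelled capspeaker.txt:

    from pathlib import Path

    import loaders
    import process

    model = loaders.whispermodel()
    output_path = Path("transcriptions")  # hypothetical working directory
    # stand-in; really the return value of process.save_diarized_audio_files(...)
    diarized_groups: list = []
    process.transcribe(model, diarized_groups, output_path, lang="en")
    text = process.output_txt(diarized_groups, output_path)
    print(text)  # e.g. "[SPEAKER_00] Hello everyone.\n"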

View file

@@ -16,12 +16,14 @@ access_token: str = os.environ.get("VERBANOTE_HF_TOKEN", "")
 loaders.prep()
 diarize_pipeline = loaders.diarization(access_token)
-# whisper_model = loaders.whisper()
+whisper_model = loaders.whispermodel()

 def handler(job):
     input: dict = job["input"]
     url: str | None = input.get("url")
+    lang: str | None = input.get("lang")
+    word_timestamps: str | None = input.get("word_timestamps")

     if not url:
         return {"error": "no file link provided"}
@@ -35,17 +37,21 @@ def handler(job):
     diarized_groups = process.save_diarized_audio_files(
         diarized, audiofile, output_path
     )
-    uploaded_file: str = file_operations.upload_to_oxo(file=diarized, expires=1)
-    # process.transcribe(
-    #     model=whisper_model, diarized_groups=diarized_groups, output_path=output_path
-    # )
+    uploaded_diarization: str = file_operations.upload_to_oxo(file=diarized, expires=1)
+    process.transcribe(
+        model=whisper_model,
+        diarized_groups=diarized_groups,
+        output_path=output_path,
+        lang=lang or "fr",
+        word_timestamps=word_timestamps or True,
+    )
+    transcription = process.output_txt(diarized_groups, output_path)

     return {
         "speaker_timings": "s3-address-to-speakers",
-        "transcription_text": "s3-address-to-transcription",
         "transcription_page": "web-address-to-deployment",
-        "audiofile_path": str(audiofile),
-        "audio_url": uploaded_file,
+        "audiofile": str(audiofile),
+        "diarization_url": uploaded_diarization,
+        "diarization": diarized_groups,
+        "transcription_text": transcription,
     }
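
Finally, a sketch of the payload the reworked handler now consumes and returns, with field names taken from this diff (the URL is a placeholder; lang falls back to "fr" and word_timestamps to True when omitted):

    job = {
        "input": {
            "url": "https://example.com/meeting.ogg",  # hypothetical audio link
            "lang": "en",
            "word_timestamps": True,
        }
    }
    response = handler(job)
    print(response["transcription_text"])  # speaker-labelled transcript
    print(response["diarization_url"])     # uploaded diarization file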