Reinstate Whisper model transcription
This commit is contained in:
parent
3246469be2
commit
30fd269cd9
6 changed files with 483 additions and 177 deletions
|
|
@ -1,11 +1,12 @@
|
|||
import locale
|
||||
from pathlib import Path
|
||||
# from whisper import Whisper
|
||||
from whisper import Whisper, load_model
|
||||
from pyannote.audio import Pipeline
|
||||
import torch
|
||||
import static_ffmpeg
|
||||
import file_operations
|
||||
|
||||
|
||||
def prep() -> None:
|
||||
locale.getpreferredencoding = lambda: "UTF-8"
|
||||
# download and add ffmpeg to env
|
||||
|
|
@ -26,8 +27,8 @@ def diarization(access_token: str | None) -> Pipeline:
|
|||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return pipeline.to(device)
|
||||
|
||||
#
|
||||
# def whisper() -> Whisper:
|
||||
# # LOAD MODEL INTO VRAM
|
||||
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# return whisper.load_model("large", device=device)
|
||||
|
||||
def whispermodel() -> Whisper:
    """Load the large Whisper model into VRAM when a GPU is present.

    Falls back to CPU when CUDA is unavailable.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    return load_model("large", device=torch.device(target))
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
from pathlib import Path
|
||||
from pyannote.audio import Pipeline
|
||||
from pydub import AudioSegment
|
||||
# from whisper import Whisper
|
||||
from whisper import Whisper
|
||||
|
||||
MILLISECONDS_TO_SPACE = 2000
|
||||
|
||||
|
|
@ -25,22 +25,53 @@ def diarize(audiofile: Path, pipeline: Pipeline, output_path: Path) -> Path:
|
|||
return out_file
|
||||
|
||||
|
||||
# def transcribe(
|
||||
# model: Whisper,
|
||||
# diarized_groups: list,
|
||||
# output_path: Path,
|
||||
# lang: str = "en",
|
||||
# word_timestamps: bool = True,
|
||||
# ):
|
||||
# for i in range(len(diarized_groups)):
|
||||
# f = {Path.joinpath(output_path, str(i))}
|
||||
# audio_f = f"{f}.wav"
|
||||
# json_f = f"{f}.json"
|
||||
# result = model.transcribe(
|
||||
# audio=audio_f, language=lang, word_timestamps=word_timestamps
|
||||
# )
|
||||
# with open(json_f, "w") as outfile:
|
||||
# json.dump(result, outfile, indent=4)
|
||||
def transcribe(
    model: "Whisper",  # quoted: keeps the def evaluable even if `whisper` imports lazily
    diarized_groups: list,
    output_path: Path,
    lang: str = "en",
    word_timestamps: bool = True,
) -> None:
    """Transcribe every diarized audio chunk and dump each result as JSON.

    For each group index ``i`` this reads ``<output_path>/<i>.wav`` and
    writes the Whisper transcription dict to ``<output_path>/<i>.json``
    (the files that ``output_txt`` later reads back).

    :param model: loaded Whisper model (must expose ``.transcribe``)
    :param diarized_groups: diarized speaker groups; only the count is used here
    :param output_path: directory holding the per-group ``<i>.wav`` chunks
    :param lang: language code passed straight to Whisper
    :param word_timestamps: forwarded to Whisper's ``word_timestamps`` option
    """
    for i in range(len(diarized_groups)):
        # BUG FIX: the original wrote `f = {Path.joinpath(...)}` — a set
        # literal — so the f-strings below produced names like
        # "{PosixPath('/out/0')}.wav", which output_txt() could never find.
        stem = Path.joinpath(output_path, str(i))
        audio_f = f"{stem}.wav"
        json_f = f"{stem}.json"
        result = model.transcribe(
            audio=audio_f, language=lang, word_timestamps=word_timestamps
        )
        with open(json_f, "w") as outfile:
            json.dump(result, outfile, indent=4)
|
||||
|
||||
|
||||
# TODO clean up this mess
def output_txt(diarized_groups: list, transcription_path: Path) -> str:
    """Assemble a speaker-labelled transcript from the per-group JSON files.

    Reads ``<transcription_path>/<i>.json`` for every diarized group,
    prefixes each caption with its speaker tag, writes the full text to
    ``<transcription_path>/capspeaker.txt`` (UTF-8) and returns it.

    :param diarized_groups: groups whose first element embeds the start
        timestamp and the speaker label as its last whitespace-split token
    :param transcription_path: directory containing the ``<i>.json`` files
    :return: the complete transcript that was written to disk
    """
    txt: list[str] = []  # was `list("")`, which is just an empty list
    for gidx, g in enumerate(diarized_groups):  # replaces manual gidx counter
        start = re.findall(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+", string=g[0])[0]
        # Start time of this group in the original video, padded backwards.
        # NOTE(review): `shift` is computed but never used below — confirm
        # whether the per-caption timestamps were meant to be shifted by it.
        shift = max(_millisec(start) - MILLISECONDS_TO_SPACE, 0)

        with open(f"{Path.joinpath(transcription_path, str(gidx))}.json") as f:
            captions = json.load(f)["segments"]

        if captions:
            speaker = g[0].split()[-1]
            for c in captions:
                txt.append(f"[{speaker}] {c['text']}\n")
            txt.append("\n")  # blank line separating speaker groups

    output = "".join(txt)
    with open(
        Path.joinpath(transcription_path, "capspeaker.txt"), "w", encoding="utf-8"
    ) as file:
        file.write(output)
    return output
|
||||
|
||||
|
||||
def save_diarized_audio_files(
|
||||
|
|
|
|||
|
|
@ -16,12 +16,14 @@ access_token: str = os.environ.get("VERBANOTE_HF_TOKEN", "")
|
|||
|
||||
loaders.prep()
|
||||
diarize_pipeline = loaders.diarization(access_token)
|
||||
# whisper_model = loaders.whisper()
|
||||
whisper_model = loaders.whispermodel()
|
||||
|
||||
|
||||
def handler(job):
|
||||
input: dict = job["input"]
|
||||
url: str | None = input.get("url")
|
||||
lang: str | None = input.get("lang")
|
||||
word_timestamps: str | None = input.get("word_timestamps")
|
||||
|
||||
if not url:
|
||||
return {"error": "no file link provided"}
|
||||
|
|
@ -35,17 +37,21 @@ def handler(job):
|
|||
diarized_groups = process.save_diarized_audio_files(
|
||||
diarized, audiofile, output_path
|
||||
)
|
||||
uploaded_file: str = file_operations.upload_to_oxo(file=diarized, expires=1)
|
||||
# process.transcribe(
|
||||
# model=whisper_model, diarized_groups=diarized_groups, output_path=output_path
|
||||
# )
|
||||
uploaded_diarization: str = file_operations.upload_to_oxo(file=diarized, expires=1)
|
||||
process.transcribe(
|
||||
model=whisper_model,
|
||||
diarized_groups=diarized_groups,
|
||||
output_path=output_path,
|
||||
lang=lang or "fr",
|
||||
word_timestamps=word_timestamps or True,
|
||||
)
|
||||
transcription = process.output_txt(diarized_groups, output_path)
|
||||
|
||||
return {
|
||||
"speaker_timings": "s3-address-to-speakers",
|
||||
"transcription_text": "s3-address-to-transcription",
|
||||
"transcription_page": "web-address-to-deployment",
|
||||
"audiofile_path": str(audiofile),
|
||||
"audio_url": uploaded_file,
|
||||
"audiofile": str(audiofile),
|
||||
"diarization_url": uploaded_diarization,
|
||||
"diarization": diarized_groups,
|
||||
"transcription_text": transcription,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue