2023-08-23 13:10:41 +00:00
|
|
|
import logging
|
2023-08-20 12:29:36 +00:00
|
|
|
import os
|
|
|
|
import re
|
|
|
|
import json
|
2023-08-23 13:09:53 +00:00
|
|
|
from dataclasses import dataclass
|
2023-08-20 12:29:36 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from pyannote.audio import Pipeline
|
|
|
|
from pydub import AudioSegment
|
2023-08-23 11:22:55 +00:00
|
|
|
from whisper import Whisper
|
2023-08-20 12:29:36 +00:00
|
|
|
|
|
|
|
# Duration, in milliseconds, of the silence that `_add_audio_silence` places in
# front of the audio before it is handed to the diarization pipeline.
MILLISECONDS_TO_SPACE = 2_000
|
|
|
|
|
|
|
|
|
2023-08-23 13:09:53 +00:00
|
|
|
@dataclass
class TxtTranscription:
    """Plain-text transcript together with the file it was saved to.

    Instances are built by ``output_txt`` from the assembled transcript.
    """

    # The full transcript text.
    text: str
    # Location of the text file the transcript was written to.
    file: Path
|
|
|
|
|
|
|
|
|
2023-08-20 12:29:36 +00:00
|
|
|
def diarize(audiofile: Path, pipeline: Pipeline, output_path: Path) -> Path:
    """Diarize ``audiofile`` and save the pipeline's result as text.

    The pipeline does not see ``audiofile`` itself. It sees the copy produced
    by ``_add_audio_silence``, which has ``MILLISECONDS_TO_SPACE`` ms of
    leading silence. Timestamps in the written file are therefore relative
    to that padded audio.

    Args:
        audiofile: WAV file to diarize.
        pipeline: pyannote ``Pipeline``; it is called with a file dict whose
            ``"audio"`` entry is the padded copy.
        output_path: Directory in which ``diarization.txt`` is created.

    Returns:
        Path of the written ``diarization.txt``, which contains ``str()`` of
        the pipeline output.
    """
    padded_audio = _add_audio_silence(audiofile)
    logging.info(f"Beginning diarization of {audiofile}...")

    # The "uri" value looks like a placeholder; only "audio" carries data here.
    annotation = pipeline({"uri": "not-important", "audio": padded_audio})

    diarization_txt = output_path.joinpath("diarization.txt")
    with open(diarization_txt, "w") as handle:
        handle.write(str(annotation))
    logging.info(f"Created diarization in {diarization_txt}.")
    return diarization_txt
|
|
|
|
|
|
|
|
|
2023-08-23 11:22:55 +00:00
|
|
|
def transcribe(
    model: Whisper,
    diarized_groups: list,
    files_path: Path,
    lang: str = "en",
    word_timestamps: bool = True,
) -> None:
    """Run Whisper on every per-group clip and store each raw result as JSON.

    Only the number of entries in ``diarized_groups`` is used, not their
    contents. Clip ``n`` is read from ``files_path / "<n>.wav"``. The dict
    returned by ``model.transcribe`` is written, pretty-printed with indent
    4, to ``files_path / "<n>.json"``.

    Args:
        model: Whisper model whose ``transcribe`` method is called.
        diarized_groups: Speaker groups; their count sets how many clips are
            processed.
        files_path: Directory holding the ``<n>.wav`` clips; the JSON results
            are written alongside them.
        lang: Passed to Whisper as ``language``.
        word_timestamps: Forwarded to Whisper's ``word_timestamps`` option.
    """
    clip_count = len(diarized_groups)
    for clip_no in range(clip_count):
        wav_file = files_path.joinpath(f"{clip_no}.wav")
        json_file = files_path.joinpath(f"{clip_no}.json")
        logging.info(f"Starting transcription of {wav_file}...")

        whisper_result = model.transcribe(
            audio=str(wav_file),
            language=lang,
            word_timestamps=word_timestamps,
        )
        with open(json_file, "w") as sink:
            json.dump(whisper_result, sink, indent=4)
        logging.info(f"Transcription written to {json_file}.")
|
2023-08-23 11:22:55 +00:00
|
|
|
|
|
|
|
|
|
|
|
def output_txt(diarized_groups: list, transcription_path: Path) -> TxtTranscription:
    """Assemble the per-group Whisper results into one speaker-labelled text.

    For each group index ``i``, this loads ``transcription_path / "<i>.json"``
    and reads its ``"segments"`` list. If the list is non-empty, it emits one
    paragraph: ``"[<speaker>] "``, the segments' ``text`` values concatenated
    as-is, then a blank line. The speaker label is the last
    whitespace-separated token of the group's first diarization line. Groups
    with no segments produce no output.

    The combined text is written as UTF-8 to
    ``transcription_path / "transcription_result.txt"``.

    Args:
        diarized_groups: Speaker groups (lists of diarization lines), in the
            same order as the numbered JSON files.
        transcription_path: Directory holding ``<i>.json``; it also receives
            the result file.

    Returns:
        TxtTranscription with the assembled text and the written file's path.
    """
    parts: list[str] = []
    for gidx, g in enumerate(diarized_groups):
        fname = transcription_path / f"{gidx}.json"
        # JSON text is UTF-8 by specification; be explicit instead of relying
        # on the platform's default encoding.
        with open(fname, encoding="utf-8") as f:
            captions = json.load(f)["segments"]
        logging.info(f"Loaded {fname} for transcription...")

        if not captions:
            continue
        speaker = g[0].split()[-1]
        parts.append(f"[{speaker}] ")
        parts.extend(f"{c['text']}" for c in captions)
        parts.append("\n\n")

    output = "".join(parts)
    fname = transcription_path / "transcription_result.txt"
    with open(fname, "w", encoding="utf-8") as file:
        file.write(output)
    logging.info(f"Wrote transcription to output file {fname}.")
    return TxtTranscription(text=output, file=fname)
|
2023-08-20 12:29:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
def save_diarized_audio_files(
    diarization: Path, audiofile: Path, output_path: Path
) -> list:
    """Group a diarization file by speaker and cut one WAV clip per group.

    Args:
        diarization: Text file of diarization lines, parsed by
            ``_group_speakers``.
        audiofile: WAV file the clips are cut from.
        output_path: Directory that receives the numbered ``.wav`` clips.

    Returns:
        The speaker groups (each a list of diarization lines), in clip order.
    """
    speaker_groups = _group_speakers(diarization)
    _save_individual_audio_files(
        audiofile=audiofile,
        groups=speaker_groups,
        output_path=output_path,
    )
    return speaker_groups
|
|
|
|
|
|
|
|
|
|
|
|
def _add_audio_silence(audiofile) -> Path:
    """Export a copy of ``audiofile`` preceded by silence and return its path.

    ``MILLISECONDS_TO_SPACE`` ms of silence are joined in front of the audio
    with no crossfade. The result is exported as WAV to
    ``interview_prepend.wav`` in the directory containing ``audiofile``.
    The output name is fixed, so every call for files in the same directory
    targets the same file.

    Args:
        audiofile: Path (or path string) of the source WAV file.

    Returns:
        Path of the exported ``interview_prepend.wav``.
    """
    leading_silence = AudioSegment.silent(duration=MILLISECONDS_TO_SPACE)
    source_audio = AudioSegment.from_wav(audiofile)
    padded_audio = leading_silence.append(source_audio, crossfade=0)

    target_dir = Path(os.path.dirname(audiofile))
    padded_file = target_dir / "interview_prepend.wav"
    padded_audio.export(padded_file, format="wav")
    logging.info(f"Exported audiofile with silence prepended to {padded_file}.")
    return padded_file
|
2023-08-20 12:29:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _save_individual_audio_files(
    audiofile: Path, groups: list[list[str]], output_path: Path
) -> None:
    """Cut ``audiofile`` into one WAV clip per speaker group.

    Each group is a non-empty list of diarization lines (as returned by
    ``_group_speakers``). The clip for group ``i`` starts at the first
    timestamp on the group's first line. It ends at the second timestamp on
    the group's last line. The clip is exported to
    ``output_path / "<i>.wav"``.

    Args:
        audiofile: WAV file to slice.
        groups: Speaker groups; each is a list of diarization lines, not a
            single string.
        output_path: Directory that receives ``0.wav``, ``1.wav``, ...

    Raises:
        IndexError: if a line lacks the expected ``H:M:S.fff`` timestamps.
    """
    timestamp = re.compile(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+")
    audio = AudioSegment.from_wav(audiofile)
    total = len(groups)
    for gidx, g in enumerate(groups):
        start = _millisec(timestamp.findall(g[0])[0])
        end = _millisec(timestamp.findall(g[-1])[1])
        # NOTE(review): if these lines come from `diarize`, their timestamps
        # are relative to the silence-padded audio. No MILLISECONDS_TO_SPACE
        # offset is removed here, so the clips line up only when `audiofile`
        # is that padded file. Confirm with the caller.
        fname = output_path / f"{gidx}.wav"
        audio[start:end].export(fname, format="wav")
        logging.info(f"Exported audiopart {gidx} of {total} to {fname}.")
|
2023-08-20 12:29:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _group_speakers(diarization_file: Path) -> list:
    """Split diarization lines into runs of consecutive segments per speaker.

    Each non-blank line must end with a speaker label as its last
    whitespace-separated token. It must also contain at least two
    ``H:M:S.fff`` timestamps, of which the second is used as the segment end.
    The text written by ``diarize`` is the expected input.

    A new group starts whenever the speaker differs from that of the current
    group's first line. A group is also closed right after a segment whose
    end time is earlier than the latest end time seen so far, i.e. a segment
    engulfed by a previous one.

    Args:
        diarization_file: Text file with one diarization segment per line.

    Returns:
        List of groups; each group is a list of the raw lines it contains.

    Raises:
        IndexError: if a non-blank line has fewer than two timestamps.
    """
    timestamp = re.compile(r"[0-9]+:[0-9]+:[0-9]+\.[0-9]+")
    # Use a context manager so the file handle is closed deterministically.
    with open(diarization_file) as f:
        dzs = f.read().splitlines()

    groups: list = []
    g: list = []
    lastend = 0
    for d in dzs:
        if not d.strip():
            # Blank lines carry no speaker or timestamps; previously they
            # crashed on `split()[-1]`.
            continue
        if g and g[0].split()[-1] != d.split()[-1]:  # speaker changed
            groups.append(g)
            g = []

        g.append(d)

        end = _millisec(timestamp.findall(d)[1])
        if lastend > end:  # segment engulfed by a previous segment
            groups.append(g)
            g = []
        else:
            lastend = end
    if g:
        groups.append(g)
    return groups
|
|
|
|
|
|
|
|
|
|
|
|
def _millisec(timeStr):
|
|
|
|
spl = timeStr.split(":")
|
|
|
|
s = (int)((int(spl[0]) * 60 * 60 + int(spl[1]) * 60 + float(spl[2])) * 1000)
|
|
|
|
return s
|