Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion melo/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def get_bert(norm_text, word2ph, language, device):
from .spanish_bert import get_bert_feature as sp_bert
from .french_bert import get_bert_feature as fr_bert
from .korean import get_bert_feature as kr_bert
from .turkish import get_bert_feature as tr_bert

lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert, 'ZH_MIX_EN': zh_mix_en_bert,
'FR': fr_bert, 'SP': sp_bert, 'ES': sp_bert, "KR": kr_bert}
'FR': fr_bert, 'SP': sp_bert, 'ES': sp_bert, "KR": kr_bert, "TR": tr_bert}
bert = lang_bert_func_map[language](norm_text, word2ph, device)
return bert

6 changes: 3 additions & 3 deletions melo/text/cleaner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from . import chinese, japanese, english, chinese_mix, korean, french, spanish
from . import chinese, japanese, english, chinese_mix, korean, french, spanish, turkish
from . import cleaned_text_to_sequence
import copy

language_module_map = {"ZH": chinese, "JP": japanese, "EN": english, 'ZH_MIX_EN': chinese_mix, 'KR': korean,
'FR': french, 'SP': spanish, 'ES': spanish}
'FR': french, 'SP': spanish, 'ES': spanish, "TR": turkish}


def clean_text(text, language):
Expand Down Expand Up @@ -33,4 +33,4 @@ def text_to_sequence(text, language):


if __name__ == "__main__":
pass
pass
22 changes: 18 additions & 4 deletions melo/text/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,16 +261,28 @@
]
num_ru_tones = 1

# Turkish symbols (IPA)
# NOTE(review): turkish.g2p iterates IPA strings character-by-character, so
# "ʒ" (produced by tr_to_ipa for the letters 'c' and 'j') must be present as
# its own symbol; it was missing, so every occurrence fell back to "UNK".
tr_symbols = [
    "a", "aː", "b", "d", "dʒ", "e", "f", "ɡ", "h", "ɯ",
    "i", "j", "k", "l", "m", "n", "o", "ø", "p", "r",
    "s", "ʃ", "t", "tʃ", "u", "y", "v", "z", "ʒ", "ː"
]
num_tr_tones = 1

# combine all symbols
normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols + kr_symbols + es_symbols + fr_symbols + de_symbols + ru_symbols))
normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols + kr_symbols +
es_symbols + fr_symbols + de_symbols + ru_symbols + tr_symbols))
symbols = [pad] + normal_symbols + pu_symbols
sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]

# combine all tones
num_tones = num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones + num_fr_tones + num_de_tones + num_ru_tones

num_tones = (num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones +
num_es_tones + num_fr_tones + num_de_tones + num_ru_tones + num_tr_tones)
# language maps
language_id_map = {"ZH": 0, "JP": 1, "EN": 2, "ZH_MIX_EN": 3, 'KR': 4, 'ES': 5, 'SP': 5 ,'FR': 6}
language_id_map = {
"ZH": 0, "JP": 1, "EN": 2, "ZH_MIX_EN": 3, 'KR': 4,
'ES': 5, 'SP': 5, 'FR': 6, 'TR': 7
}
num_languages = len(language_id_map.keys())

language_tone_start_map = {
Expand All @@ -282,9 +294,11 @@
"ES": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones,
"SP": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones,
"FR": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones,
"TR": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones + num_fr_tones
}

if __name__ == "__main__":
a = set(zh_symbols)
b = set(en_symbols)
print(sorted(a & b))

122 changes: 122 additions & 0 deletions melo/text/turkish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import re
from transformers import AutoTokenizer
from . import symbols

def distribute_phone(n_phone, n_word):
    """Spread n_phone phones across n_word words as evenly as possible.

    Earlier words receive the extra phones when the division is uneven,
    matching the original greedy "assign to first minimum" behavior.

    Args:
        n_phone: total number of phones to distribute (>= 0).
        n_word: number of words (> 0).

    Returns:
        List of length n_word whose entries sum to n_phone.
    """
    # divmod replaces the O(n_phone * n_word) repeated min() scan: the first
    # `extra` words get one additional phone, the rest get `base`.
    base, extra = divmod(n_phone, n_word)
    return [base + 1 if i < extra else base for i in range(n_word)]

def text_normalize(text):
    """Normalize Turkish text: lowercase, collapse spaces, strip odd chars.

    Turkish casing is special: dotless "I" lowercases to "ı" and dotted "İ"
    lowercases to "i". Python's str.lower() maps "I" -> "i" and
    "İ" -> "i\u0307" (i + combining dot above), so both are handled
    explicitly before the generic lowercase pass. The previous version
    missed "İ" and also called .lower() twice.
    """
    text = text.replace("I", "ı").replace("İ", "i")
    text = text.lower()
    # Collapse any run of whitespace into a single space
    text = re.sub(r'\s+', ' ', text)
    # Drop everything except word chars, whitespace and basic punctuation
    # (\w already matches Turkish letters; the extra class entries are kept
    # for clarity with the original intent)
    text = re.sub(r'[^\w\s.,!?;:ğıöüşç]', '', text)
    return text.strip()

def post_replace_ph(ph):
    """Normalize punctuation phones and reject symbols outside the inventory.

    Full-width / variant punctuation is mapped to its canonical form; any
    phone that is still not in the global `symbols` list becomes "UNK".
    """
    rep_map = {
        ":": ",",
        ";": ",",
        ",": ",",
        "。": ".",
        "!": "!",
        "?": "?",
        "\n": ".",
        "·": ",",
        "、": ",",
        "...": "…"
    }
    # dict.get replaces the redundant `in rep_map.keys()` test + lookup
    ph = rep_map.get(ph, ph)
    # The original tested `ph in symbols` twice (once directly, once
    # inverted); a single conditional is equivalent.
    return ph if ph in symbols else "UNK"

def refine_ph(phn):
    """Split a trailing tone digit off a phone.

    Returns (lowercased phone without the digit, digit + 1) when the phone
    ends in a digit, otherwise (lowercased phone, 0).
    """
    trailing_digit = re.search(r"\d$", phn)
    if trailing_digit is None:
        return phn.lower(), 0
    tone_value = int(trailing_digit.group()) + 1
    return phn[: trailing_digit.start()].lower(), tone_value

def tr_to_ipa(text):
    """Map Turkish orthography to a rough IPA string, letter by letter.

    Input is lowercased first; characters without a mapping pass through
    unchanged. This is a basic implementation and does not model
    context-dependent Turkish phonology.
    """
    ipa_table = str.maketrans({
        'a': 'a', 'e': 'e', 'ı': 'ɯ', 'i': 'i',
        'o': 'o', 'ö': 'ø', 'u': 'u', 'ü': 'y',
        'b': 'b', 'c': 'dʒ', 'ç': 'tʃ', 'd': 'd',
        'f': 'f', 'g': 'ɡ', 'ğ': 'ː', 'h': 'h',
        'j': 'ʒ', 'k': 'k', 'l': 'l', 'm': 'm',
        'n': 'n', 'p': 'p', 'r': 'r', 's': 's',
        'ş': 'ʃ', 't': 't', 'v': 'v', 'y': 'j',
        'z': 'z'
    })
    return text.lower().translate(ipa_table)

# Initialize the Turkish BERT tokenizer
# NOTE(review): this runs at import time and may download the tokenizer
# files from the Hugging Face hub on first use — confirm this module-import
# side effect (and the implied network dependency) is acceptable.
model_id = 'ytu-ce-cosmos/turkish-base-bert-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_id)

def g2p(text, pad_start_end=True, tokenized=None):
    """Convert Turkish text to phones, tones and a word-to-phone map.

    Args:
        text: normalized Turkish text.
        pad_start_end: when True, wrap the outputs with silence markers
            ("_" phone, tone 0, word2ph 1) so they line up with the BERT
            [CLS]/[SEP] positions.
        tokenized: optional pre-computed WordPiece token list; when None,
            the module-level tokenizer is used.

    Returns:
        (phones, tones, word2ph): tones are all 0 (Turkish is not tonal);
        word2ph[i] is the number of phones assigned to token position i.
    """
    if tokenized is None:
        tokenized = tokenizer.tokenize(text)

    # Re-assemble WordPiece sub-tokens into whole words; a "##" prefix
    # marks a continuation of the previous token.
    ph_groups = []
    for tok in tokenized:
        if tok.startswith("##") and ph_groups:
            ph_groups[-1].append(tok[2:])
        else:
            # Guard: a leading "##" token (malformed input) now starts its
            # own group instead of raising IndexError on ph_groups[-1].
            ph_groups.append([tok[2:] if tok.startswith("##") else tok])

    phones = []
    tones = []
    word2ph = []

    for group in ph_groups:
        word = "".join(group)
        word_len = len(group)
        if word == '[UNK]':
            phone_list = ['UNK']
        else:
            # NOTE(review): iterating the IPA string splits multi-character
            # phonemes such as "dʒ"/"tʃ" into single characters — confirm
            # this is intended given the multi-char entries in tr_symbols.
            phone_list = [p for p in tr_to_ipa(word) if p != " "]

        phones.extend(phone_list)
        tones.extend([0] * len(phone_list))  # Turkish is not a tonal language
        word2ph += distribute_phone(len(phone_list), word_len)

    if pad_start_end:
        phones = ["_"] + phones + ["_"]
        tones = [0] + tones + [0]
        word2ph = [1] + word2ph + [1]
    return phones, tones, word2ph

def get_bert_feature(text, word2ph, device=None):
    """Delegate to the Turkish BERT feature extractor.

    The import is done lazily so loading this module does not pull in torch
    or the BERT model unless features are actually requested.
    """
    # Bug fix: use a package-relative import, consistent with the rest of
    # this package ("from . import ..."); the absolute "from text import"
    # only works when the melo/ directory happens to be on sys.path.
    from . import turkish_bert
    return turkish_bert.get_bert_feature(text, word2ph, device=device)

if __name__ == "__main__":
    # Smoke test: normalize a sample sentence, run G2P, then extract
    # phone-level BERT features (downloads the BERT model on first run).
    text = "Merhaba, nasılsın? Ben iyiyim."
    text = text_normalize(text)
    print(text)
    phones, tones, word2ph = g2p(text)
    bert = get_bert_feature(text, word2ph)
    print(phones)
    print(len(phones), tones, sum(word2ph), bert.shape)
39 changes: 39 additions & 0 deletions melo/text/turkish_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import sys

# NOTE(review): the tokenizer is loaded eagerly at import time (may download
# from the Hugging Face hub); the model itself is loaded lazily on the first
# get_bert_feature call and cached in the module-level `model` global.
model_id = 'ytu-ce-cosmos/turkish-base-bert-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = None

def get_bert_feature(text, word2ph, device=None):
    """Return phone-level BERT features for Turkish text.

    Args:
        text: a normalized sentence.
        word2ph: phones-per-token counts, including the [CLS]/[SEP] slots;
            len(word2ph) must equal the tokenized sequence length.
        device: torch device string; chosen automatically when falsy.

    Returns:
        Tensor of shape (hidden_size, sum(word2ph)) on the CPU.
    """
    global model
    # NOTE(review): this darwin/mps special case promotes an explicit "cpu"
    # request to "mps"; it mirrors the other *_bert modules — confirm the
    # intent before changing it.
    if (
        sys.platform == "darwin"
        and torch.backends.mps.is_available()
        and device == "cpu"
    ):
        device = "mps"
    if not device:
        # Bug fix: fall back to CPU instead of unconditionally assuming
        # CUDA, so CPU-only machines do not crash on .to("cuda").
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if model is None:
        # Load the model once and keep it on the chosen device.
        model = AutoModelForMaskedLM.from_pretrained(model_id).to(device)
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for key in inputs:
            inputs[key] = inputs[key].to(device)
        res = model(**inputs, output_hidden_states=True)
        # Third-to-last hidden layer for the single batch element;
        # torch.cat over the one-element slice [-3:-2] was a no-op.
        res = res["hidden_states"][-3][0].cpu()

    assert inputs["input_ids"].shape[-1] == len(word2ph)
    # Repeat each token's feature vector once per phone assigned to it.
    phone_level_feature = [
        res[i].repeat(count, 1) for i, count in enumerate(word2ph)
    ]
    phone_level_feature = torch.cat(phone_level_feature, dim=0)

    return phone_level_feature.T
2 changes: 1 addition & 1 deletion melo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
if language_str == "ZH":
bert = bert
ja_bert = torch.zeros(768, len(phone))
elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU']:
elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU', 'TR']:
ja_bert = bert
bert = torch.zeros(1024, len(phone))
else:
Expand Down