"""Gradio app: Malayalam text-to-speech with a fine-tuned Orpheus-3B LLM and SNAC.

The LLM turns a text prompt into a stream of audio-code tokens; those tokens
are de-offset, regrouped into the three SNAC codebook layers, and decoded to a
waveform that Gradio plays back.
"""

# NOTE: import order is kept exactly as in the original script; unsloth
# patches other libraries at import time, so do not reorder without testing.
import gradio as gr
import torch
from unsloth import FastLanguageModel
from snac import SNAC
import numpy as np

# Set device globally for the app
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models (globally, once when app starts)
model_name = "sachin6624/orpheus-3b-0.1-ft-malayalam-3epoch"
print(f"Loading LLM {model_name} on {device}...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=False,  # Use True for 4-bit loading to reduce memory if needed
)
model.to(device)
FastLanguageModel.for_inference(model)
print("LLM loaded.")

print(f"Loading SNAC decoder on {device}...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to(device)
# Explicitly define sample rate as the model name 'snac_24khz' suggests 24000 Hz
snac_model_sample_rate = 24000
print("SNAC decoder loaded. Assumed sample rate:", snac_model_sample_rate)

# Define tokens on the selected device
start_token = torch.tensor([[128259]], dtype=torch.int64, device=device)  # Start of human
end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)  # End of text, End of human
# Everything up to and including the last occurrence of this token is dropped
# from the generated sequence before audio codes are extracted.
token_to_find = 128257
# Used both as the generation EOS id and as the token filtered from the output.
token_to_remove = 128258

# Value subtracted from each generated token id to obtain a raw audio code.
AUDIO_TOKEN_OFFSET = 128266
# Each of the 7 positions in a frame is shifted by position * this size.
# NOTE(review): assumed to equal the SNAC codebook size -- confirm against the
# snac_24khz model config.
SNAC_CODEBOOK_SIZE = 4096
# Number of consecutive codes that make up one frame.
CODES_PER_FRAME = 7


def redistribute_codes(code_list):
    """Regroup a flat list of audio codes into three layers and decode them.

    Codes are consumed in frames of ``CODES_PER_FRAME``. Within a frame,
    position ``k`` carries an offset of ``k * SNAC_CODEBOOK_SIZE`` which is
    removed here. Positions map to layers as: 0 -> layer 1; 1, 4 -> layer 2;
    2, 3, 5, 6 -> layer 3. Trailing codes that do not fill a frame are ignored.

    Args:
        code_list: List of Python integers (already reduced by
            ``AUDIO_TOKEN_OFFSET``).

    Returns:
        The tensor returned by ``snac_model.decode`` for the three layers.

    Raises:
        ValueError: If ``code_list`` is empty, shorter than one frame, or
            contains a code that falls outside ``[0, SNAC_CODEBOOK_SIZE)``
            once its position offset is removed.
    """
    if not code_list:
        raise ValueError("Input code_list to redistribute_codes is empty.")

    frame_count = len(code_list) // CODES_PER_FRAME
    if frame_count == 0:
        raise ValueError("code_list is too short to form any valid SNAC layers.")

    layer_1 = []
    layer_2 = []
    layer_3 = []
    for frame_idx in range(frame_count):
        base_idx = CODES_PER_FRAME * frame_idx
        frame = [
            code_list[base_idx + slot] - slot * SNAC_CODEBOOK_SIZE
            for slot in range(CODES_PER_FRAME)
        ]
        # Reject out-of-range codes up front: passing them to the decoder
        # would surface as an opaque indexing failure instead.
        if any(not 0 <= code < SNAC_CODEBOOK_SIZE for code in frame):
            raise ValueError(
                f"Frame {frame_idx} contains a code outside the valid range "
                f"[0, {SNAC_CODEBOOK_SIZE}); the model likely emitted a "
                "non-audio token."
            )
        layer_1.append(frame[0])
        layer_2.extend((frame[1], frame[4]))
        layer_3.extend((frame[2], frame[3], frame[5], frame[6]))

    # Convert lists of Python integers to torch tensors on the specified device
    codes = [
        torch.tensor(layer, dtype=torch.long, device=device).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]
    # Decoding is inference-only; skip autograd bookkeeping.
    with torch.inference_mode():
        audio_hat = snac_model.decode(codes)
    return audio_hat


def generate_audio(prompt: str):
    """Generate audio from a text prompt.

    Args:
        prompt: Text to synthesize; must contain non-whitespace characters.

    Returns:
        A ``(sample_rate, waveform)`` tuple where ``waveform`` is a NumPy
        array, suitable for a ``gr.Audio`` output.

    Raises:
        gr.Error: If the prompt is blank, the model produced too few audio
            tokens, or any other failure occurs during generation/decoding.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a valid text prompt.")

    try:
        # Tokenize the prompt and wrap it with the start/end marker tokens
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)

        # Single unpadded sequence, so every position is attended to
        attention_mask = torch.ones_like(
            modified_input_ids, dtype=torch.long, device=device
        )

        with torch.inference_mode():
            generated_ids = model.generate(
                input_ids=modified_input_ids,
                attention_mask=attention_mask,
                max_new_tokens=1200,
                do_sample=True,
                temperature=0.6,
                top_p=0.95,
                repetition_penalty=1.1,
                num_return_sequences=1,
                eos_token_id=token_to_remove,
                use_cache=True,
            )

        # Keep only what follows the last occurrence of token_to_find
        token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
        cropped_tensor = generated_ids
        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]

        # Filter out token_to_remove (EOS token for generation); the boolean
        # mask also flattens the result to 1-D.
        processed_row_tensor = cropped_tensor[cropped_tensor != token_to_remove]

        # Trim to a whole number of frames for redistribution
        row_length = processed_row_tensor.size(0)
        new_length = (row_length // CODES_PER_FRAME) * CODES_PER_FRAME
        if new_length == 0:
            raise gr.Error(
                "Generated response was too short to form valid audio codes. "
                "Try a different prompt or longer text."
            )

        # Apply the offset on-device in one op, then convert to Python ints
        # once (avoids a device sync per token).
        trimmed_row_list = (
            processed_row_tensor[:new_length] - AUDIO_TOKEN_OFFSET
        ).tolist()

        samples = redistribute_codes(trimmed_row_list)
        audio_output = samples.detach().squeeze().to("cpu").numpy()
        return (snac_model_sample_rate, audio_output)
    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"An error occurred during audio generation: {e}") from e


# Gradio Interface setup
iface = gr.Interface(
    fn=generate_audio,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Text Prompt (Malayalam)"),
    outputs=gr.Audio(label="Generated Audio", autoplay=True),
    title="Malayalam Text-to-Speech (Orpheus-3B & SNAC)",
    description="Generate speech from Malayalam text using the fine-tuned Orpheus-3B model and SNAC for audio generation.",
    examples=[["എങ്ങനെയുണ്ട് എന്റെ കുട്ടി?, ."], ["നമസ്കാരം, നിങ്ങൾക്ക് സുഖമാണോ?"]],
)

# Intended to disable flagging via the newer `flagging_mode` name.
# NOTE(review): this is set after the Interface has been constructed, so it
# may not affect the already-built UI; consider passing the option to
# gr.Interface(...) once the target Gradio version is confirmed.
iface.flagging_mode = 'never'

# Launch the Gradio app if the script is run directly
if __name__ == "__main__":
    iface.launch()