import gradio as gr
import spaces
import uuid
import torch
from datetime import timedelta
from lhotse import Recording
from lhotse.dataset import DynamicCutSampler
from nemo.collections.speechlm2 import SALM
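
# Run on GPU when available; otherwise fall back to CPU (much slower inference).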
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
SAMPLE_RATE = 16000  # Hz
MAX_AUDIO_MINUTES = 120  # won't try to transcribe if longer than this
CHUNK_SECONDS = 40.0  # max audio length seen by the model
BATCH_SIZE = 192  # for parallel transcription of audio longer than CHUNK_SECONDS
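
# Load Canary-Qwen-2.5B once at startup: bfloat16 weights, eval mode, on the selected device.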
model = SALM.from_pretrained("nvidia/canary-qwen-2.5b").bfloat16().eval().to(device)
def timestamp(idx: int):
    b = str(timedelta(seconds=idx * CHUNK_SECONDS))
    e = str(timedelta(seconds=(idx + 1) * CHUNK_SECONDS))
    return f"[{b} - {e}]"
def as_batches(audio_filepath, utt_id):
    rec = Recording.from_file(audio_filepath, recording_id=utt_id)
    if rec.duration / 60.0 > MAX_AUDIO_MINUTES:
        raise gr.Error(
            f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
            "If you wish, you may trim the audio using the Audio viewer in Step 1 "
            "(click on the scissors icon to start trimming audio)."
        )
    cut = rec.resample(SAMPLE_RATE).to_cut()
    if cut.num_channels > 1:
        cut = cut.to_mono(mono_downmix=True)
    return DynamicCutSampler(cut.cut_into_windows(CHUNK_SECONDS), max_cuts=BATCH_SIZE)
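
# ASR mode: transcribe each audio chunk batch by batch and return both a timestamped
# transcript (one entry per chunk) and the raw transcript joined into a single string.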
def transcribe(audio_filepath):
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
    utt_id = uuid.uuid4()
    pred_text = []
    pred_text_ts = []
    chunk_idx = 0
    for batch in as_batches(audio_filepath, str(utt_id)):
        audio, audio_lens = batch.load_audio(collate=True)
        with torch.inference_mode():
            output_ids = model.generate(
                prompts=[[{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}]] * len(batch),
                audios=torch.as_tensor(audio).to(device, non_blocking=True),
                audio_lens=torch.as_tensor(audio_lens).to(device, non_blocking=True),
                max_new_tokens=256,
            )
        texts = [model.tokenizer.ids_to_text(oids) for oids in output_ids.cpu()]
        for t in texts:
            pred_text.append(t)
            pred_text_ts.append(f"{timestamp(chunk_idx)} {t}\n\n")
            chunk_idx += 1
    return ''.join(pred_text_ts), ' '.join(pred_text)
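
# LLM mode: temporarily disable the speech adapter (model.llm.disable_adapter()) so the
# underlying Qwen LLM answers the text prompt about the transcript; any <think>...</think>
# block is split off and returned separately.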
def postprocess(transcript, prompt):
    with torch.inference_mode(), model.llm.disable_adapter():
        output_ids = model.generate(
            prompts=[[{"role": "user", "content": f"{prompt}\n\n{transcript}"}]],
            max_new_tokens=2048,
        )
    ans = model.tokenizer.ids_to_text(output_ids[0].cpu())
    ans = ans.split("<|im_start|>assistant")[-1]  # get rid of the prompt
    if "<think>" in ans:
        ans = ans.split("<think>")[-1]
        thoughts, ans = ans.split("</think>")  # get rid of the thinking
    else:
        thoughts = ""
    return ans.strip(), thoughts
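
# Helpers to grey out both buttons while a job is running and re-enable them afterwards.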
def disable_buttons():
    return gr.update(interactive=False), gr.update(interactive=False)


def enable_buttons():
    return gr.update(interactive=True), gr.update(interactive=True)
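
# Gradio UI: Step 1 upload/record audio, Step 2 transcribe, Step 3 enter a prompt, Step 4 view the LLM output.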
with gr.Blocks(
    title="NeMo Canary-Qwen-2.5B Model",
    css="""
    textarea { font-size: 18px;}
    #transcript_box span {
        font-size: 18px;
        font-weight: bold;
    }
    """,
    theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg),  # make text slightly bigger (default is text_md)
) as demo:

    gr.HTML(
        "<h1 style='text-align: center'>NeMo Canary-Qwen-2.5B model: Transcribe and prompt</h1>"
        "<p>Canary-Qwen-2.5B is an ASR model capable of transcribing speech to text (ASR mode) and using its inner Qwen3-1.7B LLM for answering questions about the transcript (LLM mode).</p>"
    )

    with gr.Row():
        with gr.Column():
            gr.HTML(
                "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
                "<p style='color: #A0A0A0;'>This demo supports audio files up to 2 hours long.</p>"
            )
            audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")

        with gr.Column():
            gr.HTML("<p><b>Step 2:</b> Transcribe the audio.</p>")
            asr_button = gr.Button(
                value="Run model",
                variant="primary",  # make "primary" so it stands out (default is "secondary")
            )
            transcript_box = gr.Textbox(
                label="Model Transcript",
                elem_id="transcript_box",
            )
            raw_transcript = gr.State()

    with gr.Row():
        with gr.Column():
            gr.HTML("<p><b>Step 3:</b> Prompt the model.</p>")
            prompt_box = gr.Textbox(
                "Give me a TL;DR:",
                label="Prompt",
                elem_id="prompt_box",
            )

        with gr.Column():
            gr.HTML("<p><b>Step 4:</b> See the outcome!</p>")
            llm_button = gr.Button(
                value="Apply the prompt",
                variant="primary",  # make "primary" so it stands out (default is "secondary")
            )
            magic_box = gr.Textbox(
                label="Assistant's Response",
                elem_id="magic_box",
            )
            think_box = gr.Textbox(
                label="Assistant's Thinking",
                elem_id="think_box",
            )

    with gr.Row():
        gr.HTML(
            "<p style='text-align: center'>"
            "🐤 <a href='https://huggingface.co/nvidia/canary-qwen-2.5b' target='_blank'>Canary-Qwen-2.5B model</a> | "
            "🧑💻 <a href='https://github.com/NVIDIA/NeMo' target='_blank'>NeMo Repository</a>"
            "</p>"
        )
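
    # Wire up the buttons: disable both while the model runs, then re-enable them when done.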
    asr_button.click(
        disable_buttons,
        outputs=[asr_button, llm_button],
        trigger_mode="once",
    ).then(
        fn=transcribe,
        inputs=[audio_file],
        outputs=[transcript_box, raw_transcript],
    ).then(
        enable_buttons,
        outputs=[asr_button, llm_button],
    )

    llm_button.click(
        disable_buttons,
        outputs=[asr_button, llm_button],
        trigger_mode="once",
    ).then(
        fn=postprocess,
        inputs=[raw_transcript, prompt_box],
        outputs=[magic_box, think_box],
    ).then(
        enable_buttons,
        outputs=[asr_button, llm_button],
    )

demo.queue()
demo.launch()