# NetraEmbed Gradio Demo

**🚀 Universal Multilingual Multimodal Document Retrieval**

This notebook provides a Gradio interface for testing both BiGemma3 and ColGemma3 models with PDF document upload, automatic conversion to images, and query-based retrieval.

**Available Models:**
- **NetraEmbed (BiGemma3)**: Single-vector embedding with Matryoshka representation - Fast retrieval with cosine similarity
- **ColNetraEmbed (ColGemma3)**: Multi-vector embedding with late interaction - High-quality retrieval with MaxSim scoring and attention heatmaps
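
To make the difference between the two scoring schemes concrete, here is a minimal pure-Python sketch on toy vectors. The function names `cosine_similarity` and `maxsim_score` are illustrative and are not part of `colpali_engine`; the real models produce much larger embeddings and score them on the GPU.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(query: Vector, page: Vector) -> float:
    """Single-vector scoring: one embedding per query and one per page."""
    return dot(query, page) / (math.sqrt(dot(query, query)) * math.sqrt(dot(page, page)))


def maxsim_score(query_tokens: List[Vector], page_patches: List[Vector]) -> float:
    """Late interaction: every query token keeps its best-matching page patch,
    and the per-token maxima are summed into one page score."""
    return sum(max(dot(q, p) for p in page_patches) for q in query_tokens)


single = cosine_similarity([1.0, 0.0], [1.0, 1.0])
multi = maxsim_score(
    query_tokens=[[1.0, 0.0], [0.0, 1.0]],
    page_patches=[[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]],
)
```

The single-vector path compares two vectors per page, which is why it is cheap. The multi-vector path compares every query token with every page patch, which costs more but keeps track of which patches matched, and that is what the heatmaps visualize.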

**Links:**
- 📄 [Paper](https://arxiv.org/abs/2512.03514)
- 💻 [GitHub](https://github.com/adithya-s-k/colpali)
- 🤗 [HuggingFace Model](https://huggingface.co/Cognitive-Lab/ColNetraEmbed)
- 📝 [Blog](https://www.cognitivelab.in/blog/introducing-netraembed)

---

**⚠️ GPU Requirements:**
- **T4 GPU (16GB)**: Can run a single model at a time (use smaller batch sizes)
- **L40S GPU (48GB)**: Can run both models simultaneously
- **A100 GPU (40-80GB)**: Can run both models simultaneously

Go to `Runtime` → `Change runtime type` and select a GPU.

## 1. Setup and Installation

In [1]:
# Install required packages
!pip install -q git+https://github.com/adithya-s-k/colpali.git
!pip install -q gradio pdf2image Pillow matplotlib seaborn einops numpy
!apt-get install -y poppler-utils  # For pdf2image

print("✅ Installation complete!")


In [2]:
# Import libraries
import io
import gc
import os
import math
from typing import Iterator, List, Optional, Tuple

import gradio as gr
import torch
from pdf2image import convert_from_path
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from einops import rearrange

# Import from colpali_engine
from colpali_engine.models import BiGemma3, BiGemmaProcessor3, ColGemma3, ColGemmaProcessor3
from colpali_engine.interpretability import get_similarity_maps_from_embeddings
from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map

# Check GPU availability and memory
device = "cuda" if torch.cuda.is_available() else "cpu"
gpu_memory_gb = 0.0
gpu_name = ""

print(f"Device: {device}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_memory_gb:.2f} GB")

    # Determine GPU capability
    if gpu_memory_gb >= 40:
        print("✅ GPU has sufficient memory (≥40GB) - Can run both models simultaneously")
        can_run_both_models = True
    elif gpu_memory_gb >= 24:
        print("⚠️ GPU has moderate memory (24-40GB) - Can run both models but may need careful memory management")
        can_run_both_models = True
    else:
        print("⚠️ GPU has limited memory (<24GB) - Recommended to run one model at a time")
        can_run_both_models = False
else:
    print("⚠️ WARNING: GPU not available! Please change runtime to GPU.")
    can_run_both_models = False

# Set memory optimization
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"

print("✅ Imports complete!")

Device: cuda
GPU: Tesla T4
GPU Memory: 15.83 GB
⚠️ GPU has limited memory (<24GB) - Recommended to run one model at a time
✅ Imports complete!


## GPU Memory Requirements & Model Selection

The notebook automatically detects your GPU and adjusts settings:

**GPU Configurations:**
- **T4 GPU (16GB)**:
  - Can run **one model at a time** (BiGemma3 OR ColGemma3)
  - Batch size: 1 page at a time
  - Recommended for small PDFs (<20 pages)
  
- **L4 / RTX 4090 (24GB)**:
  - Can run **both models**, but may need careful memory management (the detection code enables "Both" from 24GB upward)
  - Batch size: 2 pages at a time
  - Good for medium PDFs (<50 pages)

- **L40S (48GB) / A100 (40-80GB)**:
  - Can run **both models simultaneously**
  - Batch size: 4 pages at a time
  - Best for large PDFs and running comparisons

The interface will automatically show or hide the "Both" option based on available GPU memory.
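
The tiering can be condensed into a single function. This is a sketch that mirrors the thresholds used in the notebook's code cells, where GPUs reporting at least 24 GB are allowed to load both models; `GpuSettings` and `settings_for_memory` are illustrative names that do not appear in the notebook.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GpuSettings:
    can_run_both_models: bool
    max_batch_size: int
    embedding_batch_size: int


def settings_for_memory(gpu_memory_gb: float) -> GpuSettings:
    """Map total GPU memory (in GB) to the notebook's three tiers."""
    if gpu_memory_gb >= 40:  # A100, L40S
        return GpuSettings(True, 32, 4)
    if gpu_memory_gb >= 24:  # L4, RTX 4090
        return GpuSettings(True, 16, 2)
    return GpuSettings(False, 8, 1)  # T4, or no GPU (0.0 GB)


t4 = settings_for_memory(15.83)
```

Keeping the thresholds in one place makes it harder for the model-selection flag and the batch sizes to drift apart.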

## 2. Initialize Document Index and Model Management

In [3]:
# Global state for models and indexed documents
class DocumentIndex:
    def __init__(self):
        self.images: List[Image.Image] = []
        self.bigemma_embeddings = None
        self.colgemma_embeddings = None
        self.bigemma_model = None
        self.bigemma_processor = None
        self.colgemma_model = None
        self.colgemma_processor = None
        self.models_loaded = {"bigemma": False, "colgemma": False}


doc_index = DocumentIndex()

# Configuration - Adjust batch size based on GPU memory
if gpu_memory_gb >= 40:
    MAX_BATCH_SIZE = 32  # Large GPU (A100, L40S)
    EMBEDDING_BATCH_SIZE = 4  # For embedding generation
elif gpu_memory_gb >= 24:
    MAX_BATCH_SIZE = 16  # Medium GPU (L4, RTX 4090)
    EMBEDDING_BATCH_SIZE = 2
else:
    MAX_BATCH_SIZE = 8  # Small GPU (T4)
    EMBEDDING_BATCH_SIZE = 1  # Process one page at a time

print("✅ Document index initialized!")
print(f"   - Max batch size: {MAX_BATCH_SIZE}")
print(f"   - Embedding batch size: {EMBEDDING_BATCH_SIZE}")

✅ Document index initialized!
   - Max batch size: 8
   - Embedding batch size: 1
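
`EMBEDDING_BATCH_SIZE` is the number of pages encoded per forward pass: the indexing cell later in this notebook walks the page list in steps of that size. The slicing pattern can be sketched on its own as follows. `iter_batches` is a hypothetical helper, and the strings stand in for PIL images.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most `batch_size` items."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start : start + batch_size])


pages = [f"page-{n}" for n in range(1, 6)]  # stand-ins for page images
batches = list(iter_batches(pages, batch_size=2))
```

The last batch may be shorter than the others, so downstream code should not assume a fixed batch dimension.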


In [4]:
# Helper functions for model management
def get_loaded_models() -> List[str]:
    """Get list of currently loaded models."""
    loaded = []
    if doc_index.bigemma_model is not None:
        loaded.append("BiGemma3")
    if doc_index.colgemma_model is not None:
        loaded.append("ColGemma3")
    return loaded


def get_model_choice_from_loaded() -> str:
    """Determine model choice string based on what's loaded."""
    loaded = get_loaded_models()
    if "BiGemma3" in loaded and "ColGemma3" in loaded:
        return "Both"
    elif "BiGemma3" in loaded:
        return "NetraEmbed (BiGemma3)"
    elif "ColGemma3" in loaded:
        return "ColNetraEmbed (ColGemma3)"
    else:
        return ""


def load_bigemma_model():
    """Load BiGemma3 model and processor."""
    if doc_index.bigemma_model is None:
        print("Loading BiGemma3 (NetraEmbed)...")
        try:
            doc_index.bigemma_processor = BiGemmaProcessor3.from_pretrained(
                "Cognitive-Lab/NetraEmbed",
                use_fast=True,
            )
            doc_index.bigemma_model = BiGemma3.from_pretrained(
                "Cognitive-Lab/NetraEmbed",
                torch_dtype=torch.bfloat16,
                device_map=device,
            )
            doc_index.bigemma_model.eval()
            doc_index.models_loaded["bigemma"] = True
            print("✓ BiGemma3 loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load BiGemma3: {str(e)}")
            raise
    return doc_index.bigemma_model, doc_index.bigemma_processor


def load_colgemma_model():
    """Load ColGemma3 model and processor."""
    if doc_index.colgemma_model is None:
        print("Loading ColGemma3 (ColNetraEmbed)...")
        try:
            doc_index.colgemma_model = ColGemma3.from_pretrained(
                "Cognitive-Lab/ColNetraEmbed",
                dtype=torch.bfloat16,
                device_map=device,
            )
            doc_index.colgemma_model.eval()
            doc_index.colgemma_processor = ColGemmaProcessor3.from_pretrained(
                "Cognitive-Lab/ColNetraEmbed",
                use_fast=True,
            )
            doc_index.models_loaded["colgemma"] = True
            print("✓ ColGemma3 loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load ColGemma3: {str(e)}")
            raise
    return doc_index.colgemma_model, doc_index.colgemma_processor


def unload_models():
    """Unload models and free GPU memory."""
    try:
        if doc_index.bigemma_model is not None:
            del doc_index.bigemma_model
            del doc_index.bigemma_processor
            doc_index.bigemma_model = None
            doc_index.bigemma_processor = None
            doc_index.models_loaded["bigemma"] = False

        if doc_index.colgemma_model is not None:
            del doc_index.colgemma_model
            del doc_index.colgemma_processor
            doc_index.colgemma_model = None
            doc_index.colgemma_processor = None
            doc_index.models_loaded["colgemma"] = False

        doc_index.bigemma_embeddings = None
        doc_index.colgemma_embeddings = None
        doc_index.images = []

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        return "✅ Models unloaded and GPU memory cleared"
    except Exception as e:
        return f"❌ Error unloading models: {str(e)}"


def clear_incompatible_embeddings(model_choice: str) -> str:
    """Clear embeddings that are incompatible with currently loading models."""
    cleared = []

    if model_choice == "NetraEmbed (BiGemma3)":
        if doc_index.colgemma_embeddings is not None:
            doc_index.colgemma_embeddings = None
            doc_index.images = []
            cleared.append("ColGemma3")
            print("Cleared ColGemma3 embeddings")

    elif model_choice == "ColNetraEmbed (ColGemma3)":
        if doc_index.bigemma_embeddings is not None:
            doc_index.bigemma_embeddings = None
            doc_index.images = []
            cleared.append("BiGemma3")
            print("Cleared BiGemma3 embeddings")

    if cleared:
        return f"Cleared {', '.join(cleared)} embeddings - please re-index"
    return ""


def pdf_to_images(pdf_path: str) -> List[Image.Image]:
    """Convert PDF to list of PIL Images with error handling."""
    try:
        print(f"Converting PDF to images: {pdf_path}")
        images = convert_from_path(pdf_path, dpi=200)
        print(f"Converted {len(images)} pages")
        return images
    except Exception as e:
        print(f"❌ PDF conversion error: {str(e)}")
        # Chain the original exception so the full traceback is preserved
        raise RuntimeError(f"Failed to convert PDF: {str(e)}") from e


print("✅ Helper functions defined!")

✅ Helper functions defined!


## 3. Document Processing and Query Functions

In [5]:
def generate_colgemma_heatmap(
    image: Image.Image,
    query: str,
    query_embedding: torch.Tensor,
    image_embedding: torch.Tensor,
    model,
    processor,
) -> Image.Image:
    """Generate heatmap overlay for ColGemma3 results."""
    try:
        batch_images = processor.process_images([image]).to(device)

        if "input_ids" in batch_images and hasattr(model.config, "image_token_id"):
            image_token_id = model.config.image_token_id
            image_mask = batch_images["input_ids"] == image_token_id
        else:
            image_mask = torch.ones(image_embedding.shape[0], image_embedding.shape[1], dtype=torch.bool, device=device)

        num_image_tokens = image_mask.sum().item()
        n_side = int(math.sqrt(num_image_tokens))

        if n_side * n_side == num_image_tokens:
            n_patches = (n_side, n_side)
        else:
            n_patches = (16, 16)

        similarity_maps_list = get_similarity_maps_from_embeddings(
            image_embeddings=image_embedding,
            query_embeddings=query_embedding,
            n_patches=n_patches,
            image_mask=image_mask,
        )

        similarity_map = similarity_maps_list[0]

        if similarity_map.dtype == torch.bfloat16:
            similarity_map = similarity_map.float()
        aggregated_map = torch.mean(similarity_map, dim=0)

        img_array = np.array(image.convert("RGBA"))
        similarity_map_array = normalize_similarity_map(aggregated_map).to(torch.float32).cpu().numpy()
        similarity_map_array = rearrange(similarity_map_array, "h w -> w h")

        similarity_map_image = Image.fromarray((similarity_map_array * 255).astype("uint8")).resize(
            image.size, Image.Resampling.BICUBIC
        )

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(img_array)
        ax.imshow(
            similarity_map_image,
            cmap=sns.color_palette("mako", as_cmap=True),
            alpha=0.5,
        )
        ax.set_axis_off()
        plt.tight_layout()

        buffer = io.BytesIO()
        plt.savefig(buffer, format="png", dpi=150, bbox_inches="tight", pad_inches=0)
        buffer.seek(0)
        heatmap_image = Image.open(buffer).copy()
        plt.close(fig)  # Close this specific figure to avoid leaking figures across queries

        return heatmap_image

    except Exception as e:
        print(f"❌ Heatmap generation error: {str(e)}")
        return image


print("✅ Heatmap function defined!")

✅ Heatmap function defined!
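
Two steps of the heatmap function are easy to check in isolation: inferring the patch grid from the number of image tokens, and rescaling a similarity map to the range [0, 1] before it is turned into an 8-bit image. The sketch below uses plain Python lists instead of tensors. `infer_patch_grid` and `min_max_normalize` are hypothetical names, and the min-max rescaling is an assumption about what the normalization step needs to achieve, not a copy of `normalize_similarity_map`.

```python
import math
from typing import List, Tuple


def infer_patch_grid(num_image_tokens: int, fallback: Tuple[int, int] = (16, 16)) -> Tuple[int, int]:
    """Return (n, n) when the token count is a non-zero perfect square, else the fallback."""
    n_side = math.isqrt(num_image_tokens)  # exact integer square root, no float rounding
    if n_side > 0 and n_side * n_side == num_image_tokens:
        return (n_side, n_side)
    return fallback


def min_max_normalize(grid: List[List[float]]) -> List[List[float]]:
    """Rescale values to [0, 1]; a constant map becomes all zeros."""
    flat = [value for row in grid for value in row]
    lowest, highest = min(flat), max(flat)
    span = highest - lowest
    if span == 0:
        return [[0.0 for _ in row] for row in grid]
    return [[(value - lowest) / span for value in row] for row in grid]


grid_shape = infer_patch_grid(1024)
normalized = min_max_normalize([[1.0, 3.0], [2.0, 5.0]])
```

Unlike the notebook's `int(math.sqrt(...))` check, this version also falls back when the token count is zero, which avoids passing an empty `(0, 0)` grid downstream.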


In [6]:
@torch.inference_mode()
def index_document(pdf_file, model_choice: str) -> Iterator[str]:
    """Upload and index a PDF document with progress updates."""
    if pdf_file is None:
        yield "⚠️ Please upload a PDF document first."
        return

    try:
        status_messages = []

        status_messages.append("⏳ Converting PDF to images...")
        yield "\n".join(status_messages)

        doc_index.images = pdf_to_images(pdf_file.name)
        num_pages = len(doc_index.images)

        status_messages.append(f"✓ Converted PDF to {num_pages} images")

        if num_pages > MAX_BATCH_SIZE:
            # Pages are actually encoded EMBEDDING_BATCH_SIZE at a time (see the loops below)
            status_messages.append(f"⚠️ Large PDF ({num_pages} pages). Processing in batches of {EMBEDDING_BATCH_SIZE}...")
            yield "\n".join(status_messages)

        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
            if doc_index.bigemma_model is None:
                status_messages.append("⏳ Loading BiGemma3 model...")
                yield "\n".join(status_messages)
                load_bigemma_model()
                status_messages.append("✓ BiGemma3 loaded")
            else:
                status_messages.append("✓ Using cached BiGemma3 model")

            yield "\n".join(status_messages)

            model, processor = doc_index.bigemma_model, doc_index.bigemma_processor

            status_messages.append("⏳ Encoding images with BiGemma3...")
            yield "\n".join(status_messages)

            # Process in batches to avoid OOM on smaller GPUs
            bigemma_embeddings_list = []
            for i in range(0, num_pages, EMBEDDING_BATCH_SIZE):
                batch = doc_index.images[i : i + EMBEDDING_BATCH_SIZE]
                batch_images = processor.process_images(batch).to(device)
                batch_embeddings = model(**batch_images, embedding_dim=768)
                bigemma_embeddings_list.append(batch_embeddings.cpu())

                # Clear GPU memory
                del batch_images, batch_embeddings
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                if i % (EMBEDDING_BATCH_SIZE * 5) == 0 and i > 0:
                    status_messages[-1] = f"⏳ Encoding images with BiGemma3... ({i}/{num_pages})"
                    yield "\n".join(status_messages)

            doc_index.bigemma_embeddings = torch.cat(bigemma_embeddings_list, dim=0).to(device)

            status_messages[-1] = "✓ Indexed with BiGemma3 (shape: {})".format(doc_index.bigemma_embeddings.shape)
            yield "\n".join(status_messages)

        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
            if doc_index.colgemma_model is None:
                status_messages.append("⏳ Loading ColGemma3 model...")
                yield "\n".join(status_messages)
                load_colgemma_model()
                status_messages.append("✓ ColGemma3 loaded")
            else:
                status_messages.append("✓ Using cached ColGemma3 model")

            yield "\n".join(status_messages)

            model, processor = doc_index.colgemma_model, doc_index.colgemma_processor

            status_messages.append("⏳ Encoding images with ColGemma3...")
            yield "\n".join(status_messages)

            # Process in batches to avoid OOM on smaller GPUs
            colgemma_embeddings_list = []
            for i in range(0, num_pages, EMBEDDING_BATCH_SIZE):
                batch = doc_index.images[i : i + EMBEDDING_BATCH_SIZE]
                batch_images = processor.process_images(batch).to(device)
                batch_embeddings = model(**batch_images)
                colgemma_embeddings_list.append(batch_embeddings.cpu())

                # Clear GPU memory
                del batch_images, batch_embeddings
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                if i % (EMBEDDING_BATCH_SIZE * 5) == 0 and i > 0:
                    status_messages[-1] = f"⏳ Encoding images with ColGemma3... ({i}/{num_pages})"
                    yield "\n".join(status_messages)

            doc_index.colgemma_embeddings = torch.cat(colgemma_embeddings_list, dim=0).to(device)

            status_messages[-1] = "✓ Indexed with ColGemma3 (shape: {})".format(doc_index.colgemma_embeddings.shape)
            yield "\n".join(status_messages)

        final_status = "\n".join(status_messages) + "\n\n✅ Document ready for querying!"
        yield final_status

    except Exception as e:
        import traceback

        error_details = traceback.format_exc()
        print(f"Indexing error: {error_details}")
        yield f"❌ Error indexing document: {str(e)}"


@torch.inference_mode()
def query_documents(
    query: str, model_choice: str, top_k: int, show_heatmap: bool = False
) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[List]]:
    """Query the indexed documents."""
    if not doc_index.images:
        return "⚠️ Please upload and index a document first.", None, None, None

    if not query.strip():
        return "⚠️ Please enter a query.", None, None, None

    try:
        results_bi = None
        results_col = None
        gallery_images_bi = []
        gallery_images_col = []

        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
            if doc_index.bigemma_embeddings is None:
                return "⚠️ Please index the document with BiGemma3 first.", None, None, None

            model, processor = doc_index.bigemma_model, doc_index.bigemma_processor

            batch_query = processor.process_texts([query]).to(device)
            query_embedding = model(**batch_query, embedding_dim=768)

            scores = processor.score(
                qs=query_embedding,
                ps=doc_index.bigemma_embeddings,
            )

            top_k_actual = min(top_k, len(doc_index.images))
            top_indices = scores[0].argsort(descending=True)[:top_k_actual]

            results_bi = "### BiGemma3 (NetraEmbed) Results\n\n"
            for rank, idx in enumerate(top_indices):
                score = scores[0, idx].item()
                results_bi += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.4f}\n"
                gallery_images_bi.append(
                    (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.4f})")
                )

        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
            if doc_index.colgemma_embeddings is None:
                return "⚠️ Please index the document with ColGemma3 first.", None, None, None

            model, processor = doc_index.colgemma_model, doc_index.colgemma_processor

            batch_query = processor.process_queries([query]).to(device)
            query_embedding = model(**batch_query)

            scores = processor.score_multi_vector(
                qs=query_embedding,
                ps=doc_index.colgemma_embeddings,
            )

            top_k_actual = min(top_k, len(doc_index.images))
            top_indices = scores[0].argsort(descending=True)[:top_k_actual]

            results_col = "### ColGemma3 (ColNetraEmbed) Results\n\n"
            for rank, idx in enumerate(top_indices):
                score = scores[0, idx].item()
                results_col += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.2f}\n"

                if show_heatmap:
                    heatmap_image = generate_colgemma_heatmap(
                        image=doc_index.images[idx.item()],
                        query=query,
                        query_embedding=query_embedding,
                        image_embedding=doc_index.colgemma_embeddings[idx.item()].unsqueeze(0),
                        model=model,
                        processor=processor,
                    )
                    gallery_images_col.append(
                        (heatmap_image, f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})")
                    )
                else:
                    gallery_images_col.append(
                        (
                            doc_index.images[idx.item()],
                            f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})",
                        )
                    )

        if model_choice == "NetraEmbed (BiGemma3)":
            return results_bi, None, gallery_images_bi, None
        elif model_choice == "ColNetraEmbed (ColGemma3)":
            return results_col, None, None, gallery_images_col
        else:
            return results_bi, results_col, gallery_images_bi, gallery_images_col

    except Exception as e:
        import traceback

        error_details = traceback.format_exc()
        print(f"Query error: {error_details}")
        return f"❌ Error during query: {str(e)}", None, None, None


print("✅ Index and query functions defined!")

✅ Index and query functions defined!
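
Both query branches finish the same way: clamp `top_k` to the number of indexed pages, sort the scores in descending order, and report 1-based ranks and page numbers. A minimal list-based sketch of that step follows; `top_k_pages` is a hypothetical helper, and the scores are made up for illustration.

```python
from typing import List, Tuple


def top_k_pages(scores: List[float], top_k: int) -> List[Tuple[int, int, float]]:
    """Return (rank, page_number, score) triples, best first, both 1-based."""
    k = max(0, min(top_k, len(scores)))  # never ask for more pages than exist
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return [(rank + 1, idx + 1, scores[idx]) for rank, idx in enumerate(order)]


results = top_k_pages([0.12, 0.87, 0.45], top_k=5)
```

Here `top_k=5` is larger than the three available pages, so all three are returned, with page 2 ranked first.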


## 4. Create and Launch Gradio Interface

In [7]:
def load_models_with_progress(model_choice: str) -> Iterator[Tuple]:
    """Load models with progress updates."""
    if not model_choice:
        yield (
            "❌ Please select a model first.",
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(value="Load model first"),
        )
        return

    # Validate GPU memory for "Both" option
    if model_choice == "Both" and not can_run_both_models:
        yield (
            f"❌ Insufficient GPU memory ({gpu_memory_gb:.1f}GB) to run both models.\nPlease select one model at a time.",
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(value="Load model first"),
        )
        return

    try:
        status_messages = []
        clear_msg = clear_incompatible_embeddings(model_choice)
        if clear_msg:
            status_messages.append(f"⚠️ {clear_msg}")

        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
            status_messages.append("⏳ Loading BiGemma3 (NetraEmbed)...")
            yield (
                "\n".join(status_messages),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(value="Loading models..."),
            )

            load_bigemma_model()
            status_messages[-1] = "✅ BiGemma3 loaded successfully"
            yield (
                "\n".join(status_messages),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(value="Loading models..."),
            )

        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
            status_messages.append("⏳ Loading ColGemma3 (ColNetraEmbed)...")
            yield (
                "\n".join(status_messages),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(value="Loading models..."),
            )

            load_colgemma_model()
            status_messages[-1] = "✅ ColGemma3 loaded successfully"
            yield (
                "\n".join(status_messages),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(value="Loading models..."),
            )

        show_bigemma = model_choice in ["NetraEmbed (BiGemma3)", "Both"]
        show_colgemma = model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]
        show_heatmap_checkbox = model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]

        final_status = "\n".join(status_messages) + "\n\n✅ Ready!"
        yield (
            final_status,
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=show_bigemma),
            gr.update(visible=show_colgemma),
            gr.update(visible=show_heatmap_checkbox),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(value="Ready to index"),
        )

    except Exception as e:
        import traceback

        error_details = traceback.format_exc()
        print(f"Model loading error: {error_details}")
        yield (
            f"❌ Failed to load models: {str(e)}",
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(value="Load model first"),
        )


def unload_models_and_hide_ui():
    """Unload models and hide main UI."""
    status = unload_models()
    return (
        status,
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(value="Load model first"),
    )


def index_with_current_models(pdf_file):
    """Index document with currently loaded models."""
    if pdf_file is None:
        yield "⚠️ Please upload a PDF document first."
        return

    model_choice = get_model_choice_from_loaded()
    if not model_choice:
        yield "⚠️ No models loaded. Please load a model first."
        return

    for status in index_document(pdf_file, model_choice):
        yield status


def query_with_current_models(query, top_k, show_heatmap):
    """Query with currently loaded models."""
    model_choice = get_model_choice_from_loaded()
    if not model_choice:
        return "⚠️ No models loaded. Please load a model first.", None, None, None

    return query_documents(query, model_choice, top_k, show_heatmap)


print("✅ UI helper functions defined!")

✅ UI helper functions defined!
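
The callbacks above repeat the same 12-item tuple of updates many times, with only a few values changing between states. One possible way to reduce that repetition is a small builder function. The sketch below is not part of the notebook: `ui_state` is a hypothetical helper, and it uses plain dictionaries in place of `gr.update` so that it runs without Gradio; passing `update=gr.update` would produce real Gradio updates.

```python
from typing import Any, Callable, Tuple


def ui_state(
    status: str,
    *,
    ready: bool = False,
    show_bigemma: bool = False,
    show_colgemma: bool = False,
    index_status: str = "Load model first",
    update: Callable[..., Any] = dict,  # stand-in for gr.update
) -> Tuple[Any, ...]:
    """Build the 12-item output tuple in the same order the callbacks use."""
    return (
        status,
        update(visible=not ready),                # shown only until models are ready
        update(visible=ready),                    # shown only once models are ready
        update(visible=ready and show_bigemma),   # BiGemma3 results
        update(visible=ready and show_colgemma),  # ColGemma3 results
        update(visible=ready and show_colgemma),  # heatmap checkbox follows ColGemma3
        *(update(interactive=ready) for _ in range(5)),
        update(value=index_status),
    )


loading = ui_state("Loading...", index_status="Loading models...")
done = ui_state("Ready!", ready=True, show_colgemma=True, index_status="Ready to index")
```

With a builder like this, each `yield` in `load_models_with_progress` would shrink to a single call, and a change to the number or order of outputs would only need to be made in one place.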


In [8]:
# Create Gradio interface
with gr.Blocks(title="NetraEmbed Demo") as demo:
    # Header section
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# NetraEmbed")
            gr.HTML(
                """
                <div style="display: flex; gap: 8px; flex-wrap: wrap; margin-bottom: 15px;">
                    <a href="https://arxiv.org/abs/2512.03514" target="_blank">
                        <img src="https://img.shields.io/badge/arXiv-2512.03514-b31b1b.svg" alt="Paper">
                    </a>
                    <a href="https://github.com/adithya-s-k/colpali" target="_blank">
                        <img src="https://img.shields.io/badge/GitHub-colpali-181717?logo=github" alt="GitHub">
                    </a>
                    <a href="https://huggingface.co/Cognitive-Lab/ColNetraEmbed" target="_blank">
                        <img src="https://img.shields.io/badge/🤗%20HuggingFace-Model-yellow" alt="Model">
                    </a>
                    <a href="https://www.cognitivelab.in/blog/introducing-netraembed" target="_blank">
                        <img src="https://img.shields.io/badge/Blog-CognitiveLab-blue" alt="Blog">
                    </a>
                </div>
                """
            )
            gr.Markdown(
                """
                **🚀 Universal Multilingual Multimodal Document Retrieval**

                Upload a PDF document, select your model(s), and query using semantic search.

                **Available Models:**
                - **NetraEmbed (BiGemma3)**: Single-vector embedding - Fast retrieval
                - **ColNetraEmbed (ColGemma3)**: Multi-vector embedding - High-quality retrieval with heatmaps
                """
            )

        with gr.Column(scale=1):
            gr.HTML(
                """
                <div style="text-align: center;">
                    <img src="https://cdn-uploads.huggingface.co/production/uploads/6442d975ad54813badc1ddf7/-fYMikXhSuqRqm-UIdulK.png"
                         alt="NetraEmbed Banner"
                         style="width: 100%; height: auto; border-radius: 8px;">
                </div>
                """
            )

    gr.Markdown("---")

    # Compact 3-column layout
    with gr.Row():
        # Column 1: Model Management
        with gr.Column(scale=1):
            gr.Markdown("### ü§ñ Model Management")

            # Conditionally show model options based on GPU memory
            if can_run_both_models:
                model_choices = ["NetraEmbed (BiGemma3)", "ColNetraEmbed (ColGemma3)", "Both"]
                default_choice = "Both"
                gpu_info_text = f"✅ GPU: {gpu_name} ({gpu_memory_gb:.0f}GB) - Can run both models"
            else:
                model_choices = ["NetraEmbed (BiGemma3)", "ColNetraEmbed (ColGemma3)"]
                default_choice = "NetraEmbed (BiGemma3)"
                gpu_info_text = f"⚠️ GPU: {gpu_name} ({gpu_memory_gb:.0f}GB) - Run one model at a time"

            gr.Markdown(f"**{gpu_info_text}**")

            model_select = gr.Radio(
                choices=model_choices,
                value=default_choice,
                label="Select Model(s)",
            )

            load_model_btn = gr.Button("🔄 Load Model", variant="primary", size="sm")
            unload_model_btn = gr.Button("🗑️ Unload", variant="secondary", size="sm")

            model_status = gr.Textbox(
                label="Status",
                lines=6,
                interactive=False,
                value="Select and load a model",
            )

            loading_info = gr.Markdown(
                f"""
                **GPU Memory:** {gpu_memory_gb:.1f} GB
                **First load:** 2-3 min
                **Cached:** ~30 sec
                **Batch size:** {EMBEDDING_BATCH_SIZE} pages
                """,
                visible=True,
            )

        # Column 2: Document Upload & Indexing
        with gr.Column(scale=1):
            gr.Markdown("### üìÑ Upload & Index")
            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"], interactive=False)
            index_btn = gr.Button("üì• Index Document", variant="primary", size="sm", interactive=False)

            index_status = gr.Textbox(
                label="Indexing Status",
                lines=6,
                interactive=False,
                value="Load model first",
            )

        # Column 3: Query
        with gr.Column(scale=1):
            gr.Markdown("### üîé Query Document")
            query_input = gr.Textbox(
                label="Enter Query",
                placeholder="e.g., financial report, organizational structure...",
                lines=2,
                interactive=False,
            )

            with gr.Row():
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Top K",
                    scale=2,
                    interactive=False,
                )
                heatmap_checkbox = gr.Checkbox(
                    label="Heatmaps",
                    value=False,
                    visible=False,
                    scale=1,
                )

            query_btn = gr.Button("🔍 Search", variant="primary", size="sm", interactive=False)

    gr.Markdown("---")

    # Results section
    with gr.Column(visible=False) as main_interface:
        gr.Markdown("### üìä Results")

        with gr.Row(equal_height=True):
            with gr.Column(scale=1, visible=False) as bigemma_column:
                bigemma_results = gr.Markdown(
                    value="*BiGemma3 results will appear here...*",
                )
                bigemma_gallery = gr.Gallery(
                    label="BiGemma3 - Top Retrieved Pages",
                    show_label=True,
                    columns=2,
                    height="auto",
                    object_fit="contain",
                )
            with gr.Column(scale=1, visible=False) as colgemma_column:
                colgemma_results = gr.Markdown(
                    value="*ColGemma3 results will appear here...*",
                )
                colgemma_gallery = gr.Gallery(
                    label="ColGemma3 - Top Retrieved Pages",
                    show_label=True,
                    columns=2,
                    height="auto",
                    object_fit="contain",
                )

        # Tips
        with gr.Accordion("💡 Tips", open=False):
            gr.Markdown(
                """
                - **Both models**: Compare results side-by-side
                - **Scores**: BiGemma3 uses cosine similarity (-1 to 1), ColGemma3 uses MaxSim (higher is better)
                - **Heatmaps**: Enable to visualize ColGemma3 attention patterns (brighter = higher attention)
                """
            )

    # Event handlers
    load_model_btn.click(
        fn=load_models_with_progress,
        inputs=[model_select],
        outputs=[
            model_status,
            loading_info,
            main_interface,
            bigemma_column,
            colgemma_column,
            heatmap_checkbox,
            pdf_upload,
            index_btn,
            query_input,
            top_k_slider,
            query_btn,
            index_status,
        ],
    )

    unload_model_btn.click(
        fn=unload_models_and_hide_ui,
        outputs=[
            model_status,
            loading_info,
            main_interface,
            bigemma_column,
            colgemma_column,
            heatmap_checkbox,
            pdf_upload,
            index_btn,
            query_input,
            top_k_slider,
            query_btn,
            index_status,
        ],
    )

    index_btn.click(
        fn=index_with_current_models,
        inputs=[pdf_upload],
        outputs=[index_status],
    )

    query_btn.click(
        fn=query_with_current_models,
        inputs=[query_input, top_k_slider, heatmap_checkbox],
        outputs=[bigemma_results, colgemma_results, bigemma_gallery, colgemma_gallery],
    )

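# Launch the demo.
# With debug=True, launch() blocks until the server is stopped,
# so the banner is printed before launching rather than after.
demo.queue(max_size=20)

print("\n" + "=" * 80)
print("🚀 Launching NetraEmbed Gradio Demo...")
print("=" * 80)

demo.launch(debug=True, share=True)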

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://ba6ec03eaef003496b.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Loading BiGemma3 (NetraEmbed)...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


`torch_dtype` is deprecated! Use `dtype` instead!


✓ BiGemma3 loaded successfully
Converting PDF to images: /tmp/gradio/a67d56bcfc4db0b21251ad626c8da167daeb717dc567a00fd6c37e570dc6abd1/1706.03762v7 2.pdf
Converted 15 pages
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://ba6ec03eaef003496b.gradio.live

🚀 NetraEmbed Gradio Demo is now running!
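
The Tips accordion in the interface says that BiGemma3 scores pages with cosine similarity while ColGemma3 uses MaxSim. The sketch below illustrates the difference between the two scores on toy vectors. It is not the scoring code used by the models or the `colpali` package: the function names `cosine_similarity` and `maxsim_score` are hypothetical, and plain Python lists stand in for the embedding tensors.

```python
import math


def cosine_similarity(a, b):
    """Single-vector score: cosine of the angle between two embeddings (range -1 to 1)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def maxsim_score(query_tokens, page_tokens):
    """Multi-vector (late interaction) score.

    For each query token embedding, take its best dot product against all
    page token embeddings, then sum those maxima over the query tokens.
    """
    total = 0.0
    for q in query_tokens:
        total += max(sum(x * y for x, y in zip(q, p)) for p in page_tokens)
    return total


# Toy inputs (illustrative only, not real model outputs).
query_vec = [1.0, 0.0]
page_vec = [0.0, 1.0]
print(cosine_similarity(query_vec, page_vec))

query_tokens = [[1.0, 0.0], [0.0, 1.0]]
page_tokens = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
print(maxsim_score(query_tokens, page_tokens))
```

Because MaxSim sums one maximum per query token, its value grows with query length, so scores from the two models are not directly comparable; only the ranking within each model is meaningful.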

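
Both `load_model_btn.click` and `unload_model_btn.click` list the same twelve components in `outputs`, so each handler has to return twelve values in exactly that order. A minimal sketch of one way to keep that order in a single place follows. The helper `ui_updates`, the `OUTPUT_ORDER` list, and the specific states chosen for each component are assumptions for illustration, not the notebook's actual `load_models_with_progress` or `unload_models_and_hide_ui` logic. Plain dictionaries stand in for the keyword arguments that would be passed to `gr.update(...)` in the notebook.

```python
# Order mirrors the `outputs=[...]` list shared by both click handlers.
OUTPUT_ORDER = [
    "model_status", "loading_info", "main_interface", "bigemma_column",
    "colgemma_column", "heatmap_checkbox", "pdf_upload", "index_btn",
    "query_input", "top_k_slider", "query_btn", "index_status",
]


def ui_updates(loaded, use_bigemma=False, use_colgemma=False):
    """Hypothetical helper: build one update per output component, in order.

    Each dict holds the kwargs that would go to gr.update(**kwargs).
    The chosen states are assumptions, not the notebook's real behaviour.
    """
    updates = {
        "model_status": {"value": "Models loaded" if loaded else "Select and load a model"},
        "loading_info": {"visible": not loaded},
        "main_interface": {"visible": loaded},
        "bigemma_column": {"visible": loaded and use_bigemma},
        "colgemma_column": {"visible": loaded and use_colgemma},
        "heatmap_checkbox": {"visible": loaded and use_colgemma},
        "pdf_upload": {"interactive": loaded},
        "index_btn": {"interactive": loaded},
        "query_input": {"interactive": loaded},
        "top_k_slider": {"interactive": loaded},
        "query_btn": {"interactive": loaded},
        "index_status": {"value": "Upload a PDF to index" if loaded else "Load model first"},
    }
    # Looking names up through OUTPUT_ORDER keeps the tuple aligned with `outputs`.
    return tuple(updates[name] for name in OUTPUT_ORDER)


# Example: state after loading only the multi-vector model.
for name, update in zip(OUTPUT_ORDER, ui_updates(True, use_colgemma=True)):
    print(name, update)
```

Building the tuple from a named mapping means that adding or reordering an output component only requires touching `OUTPUT_ORDER` and one dictionary entry, rather than counting positions in two separate return statements.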