Update modeling_dots_ocr_vllm.py
Browse files- modeling_dots_ocr_vllm.py +11 -0
modeling_dots_ocr_vllm.py
CHANGED
|
@@ -99,6 +99,7 @@ class DotsOCRProcessingInfo(Qwen2_5_VLProcessingInfo):
|
|
| 99 |
size: Optional[dict[str, int]] = None,
|
| 100 |
**kwargs: object,
|
| 101 |
) -> Qwen2VLProcessor:
|
|
|
|
| 102 |
processor = self.ctx.get_hf_processor(
|
| 103 |
Qwen2VLProcessor,
|
| 104 |
image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size),
|
|
@@ -166,6 +167,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
|
|
| 166 |
)
|
| 167 |
_tp_plan = {}
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 170 |
super().__init__()
|
| 171 |
|
|
@@ -409,6 +415,10 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
|
|
| 409 |
|
| 410 |
|
| 411 |
def patch_vllm_chat_placeholder():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker
|
| 413 |
|
| 414 |
ori = BaseMultiModalItemTracker._placeholder_str
|
|
@@ -426,4 +436,5 @@ ModelRegistry.register_model(
|
|
| 426 |
"DotsOCRForCausalLM", DotsOCRForCausalLM,
|
| 427 |
)
|
| 428 |
|
|
|
|
| 429 |
patch_vllm_chat_placeholder()
|
|
|
|
| 99 |
size: Optional[dict[str, int]] = None,
|
| 100 |
**kwargs: object,
|
| 101 |
) -> Qwen2VLProcessor:
|
| 102 |
+
self.get_tokenizer().image_token = "<|imgpad|>" # Ensure image token is set
|
| 103 |
processor = self.ctx.get_hf_processor(
|
| 104 |
Qwen2VLProcessor,
|
| 105 |
image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size),
|
|
|
|
| 167 |
)
|
| 168 |
_tp_plan = {}
|
| 169 |
|
| 170 |
+
@classmethod
|
| 171 |
+
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
| 172 |
+
if modality in ("image",):
|
| 173 |
+
return "<|img|><|imgpad|><|endofimg|>"
|
| 174 |
+
|
| 175 |
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 176 |
super().__init__()
|
| 177 |
|
|
|
|
| 415 |
|
| 416 |
|
| 417 |
def patch_vllm_chat_placeholder():
|
| 418 |
+
import vllm
|
| 419 |
+
# return when vllm version > 0.9.1
|
| 420 |
+
if not (vllm.__version_tuple__[0]==0 and vllm.__version_tuple__[1] <= 9 and vllm.__version_tuple__[2] <= 1):
|
| 421 |
+
return
|
| 422 |
from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker
|
| 423 |
|
| 424 |
ori = BaseMultiModalItemTracker._placeholder_str
|
|
|
|
| 436 |
"DotsOCRForCausalLM", DotsOCRForCausalLM,
|
| 437 |
)
|
| 438 |
|
| 439 |
+
|
| 440 |
patch_vllm_chat_placeholder()
|