Skip to content

Conversation

@Chenhao-Guan
Copy link

Resolves #41862

Hi @zucchini-nlp and @Rocketknight1,

Following your guidance in the issue, this PR re-implements the InternVL-Flash model as a completely separate model (instead of using an if flag in the existing InternVL class).

Implementation Details
Created a new, independent model directory: src/transformers/models/internvl_flash/.

Used the transformers add-new-model-like script to scaffold the new model, as you suggested.

Implemented the model logic in modular_internvl_flash.py (including Gating, CrossAttentionPooling, etc.) and converted it using the modular script.

Testing
All local tests are passing:

make fixup (style, quality, and repository consistency checks all pass)

pytest tests/models/internvl_flash/test_modeling_internvl_flash.py

Thank you for the guidance!

Before submitting
[ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).

[x] Did you read the contributor guideline, Pull Request section?

[x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case. (Link: #41862)

[x] Did you make sure to update the documentation with your changes? (Added docs/source/en/model_doc/internvl_flash.md and updated _toctree.yml)

[x] Did you write any new necessary tests? (Added tests/models/internvl_flash/test_modeling_internvl_flash.py)

Who can review?
@zucchini-nlp @Rocketknight1

@zucchini-nlp
Copy link
Member

zucchini-nlp commented Nov 13, 2025

Taking a look tomorrow-Monday, thanks for making a new model class

Copy link
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @Chenhao-Guan , thanks for making a separate PR for the model!

I have a few major comments:

  1. The model doesn't seem to support batch size > 1 currently. Prob the official release works with a single batch size which is oke. However we need to enable batched inference before merging this PR
  2. I see that you kept 'flash" and "non-flash" paths. We have to delete the "non-flash" code path as it is not needed to run with InternVLFlash released checkpoint
  3. A few minor issues like naming and sticking to transformers standards 👇🏻

)


class Gating(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InternVLFlashGating or similar naming since it's recommended to have model's name explicit in layer names

Comment on lines 71 to 76
if self.use_checkpoint:
x = x + cp.checkpoint(self.block1, x)
x = x + cp.checkpoint(self.block2, x)
x = x + cp.checkpoint(self.block3, x)
x = x + cp.checkpoint(self.block4, x)
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think it, GC can be toggled on by PreTrainedModel if needed

Comment on lines 54 to 63
def mlp_block(in_dim, out_dim):
return nn.Sequential(
nn.Linear(in_dim, out_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(out_dim, in_dim),
nn.Dropout(dropout),
nn.LayerNorm(in_dim),
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separate layers are more preferred than a Sequential. Let's create an "nn.Module" with the following and call it smth like InternVLFlashMLP

self.block2 = mlp_block(hidden_size, mid_dim)
self.block3 = mlp_block(hidden_size, mid_dim)
self.block4 = mlp_block(hidden_size, mid_dim)
self.gate = nn.Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, 2)) # 2 experts
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, let's have them separated

return probs


class CrossAttentionPooling(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment for naming

Comment on lines 295 to 307
flag_idx = 0
for s, e, l, num_blocks in zip(starts.tolist(), ends.tolist(), lengths.tolist(), block_counts):
for i in range(num_blocks):
block_start = s + i * 256
block_end = block_start + 256

compress = gate_result[flag_idx]
flag_idx += 1

if compress:
keep_mask[block_start + 64 : block_end] = False
delete_flags[block_start + 64 : block_end] = 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could it be vectorized?

Comment on lines 311 to 313
mask_idx = mask_idx.squeeze(0)
updated_mask_idx = mask_idx - cumulative_deletes[mask_idx.to(cumulative_deletes.device)].to(mask_idx.device)
updated_mask_idx = updated_mask_idx.unsqueeze(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new attention mask is not used as I can see from current commit, so we have to fix it first. Then, why can't we attention_mask = attention_mask[keep_mask]

Comment on lines 43 to 44
if is_vision_available():
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dummy import, can be deleted

Comment on lines 199 to 203
@unittest.skip(
reason="Failing with `torch._inductor.exc.InductorError: RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_tem_fused_0 Required: 147456 Hardware limit:101376 Reducing block sizes or `num_stages` may help.`"
)
def test_flex_attention_with_grads(self):
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you delete skip on it? Prob it was failing on your local hardware, might pass with CI runners

Comment on lines +217 to +219
@unittest.skip("Skipping compilation test: fails with batch_size=0 reshape error")
def test_generate_compile_model_forward_fullgraph(self):
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to support any batch size for the model before merging

@Chenhao-Guan
Copy link
Author

@zucchini-nlp Thank you for the advice. This is my first time submitting a PR, and Gonna working to resolve the test failures related to batch_size > 1 support. My initial intention in adding non-Flash methods was specifically to bypass these failing tests temporarily. I will continue working to implement a full solution.

@github-actions
Copy link
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, internvl_flash

@Chenhao-Guan
Copy link
Author

@zucchini-nlp I've finished the requested modifications. Please let me know if there are any other points to discuss before we merge.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Request for InternVL3_5_Flash

2 participants