Commit 58b07f0
float8 inference: fix bmm semantics
Summary:
Fixes the `Float8Tensor` `torch.bmm` override to match the semantics of the
high precision op. Specifically, input 1 is of shape (B, M, K) and input
2 is of shape (B, K, N).
Previously, the override's shape expectation differed from that of
`torch.bmm`, which was confusing.
This is important for quantizing LLaMa 4 MoE variants, which use
`torch.bmm` in the HF implementation.
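For reference, the high precision `torch.bmm` shape contract that the override now matches can be sketched as:

```python
import torch

# torch.bmm contract: input 1 is (B, M, K), input 2 is (B, K, N),
# and the result is (B, M, N).
B, M, K, N = 2, 3, 4, 5
x = torch.randn(B, M, K)
w = torch.randn(B, K, N)
out = torch.bmm(x, w)
print(out.shape)  # torch.Size([2, 3, 5])
```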
Test Plan:
```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -x -k bmm
```
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: 9e16572
ghstack-comment-id: 3493356198
Pull-Request: #32961
Parent: a257166
File tree (2 files changed, +21 −14 lines):
- test/quantization/quantize_/workflows/float8
- torchao/quantization/quantize_/workflows/float8

Lines changed: 14 additions & 8 deletions
Lines changed: 7 additions & 6 deletions