Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import torch
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
from triton_kernels.tensor import bitwidth
from typing import Callable

# Signature of a callable that can be supplied in place of a fixed split_k
# value: it takes four ints (batch_size, m, n, k) and an output dtype, and
# returns an int.
CallableSplitK = Callable[[int, int, int, int, torch.dtype], int]


@dataclass
Expand All @@ -33,6 +37,19 @@ def __post_init__(self):
raise ValueError("Not supported")


def get_split_k_from_constraints(
    constraints_split_k: int | CallableSplitK,
    batch_size: int,
    m: int,
    n: int,
    k: int,
    out_dtype: torch.dtype,
) -> int:
    """Resolve a ``split_k`` constraint to a concrete value.

    The constraint is either a fixed value or a callable that computes the
    value from the arguments given here.

    Args:
        constraints_split_k: A fixed ``split_k`` value, or a callable taking
            ``(batch_size, m, n, k, out_dtype)`` and returning the value.
        batch_size: Forwarded unchanged to the callable.
        m: Forwarded unchanged to the callable.
        n: Forwarded unchanged to the callable.
        k: Forwarded unchanged to the callable.
        out_dtype: Forwarded unchanged to the callable.

    Returns:
        The callable's result when ``constraints_split_k`` is callable,
        otherwise ``constraints_split_k`` itself.
    """
    # Dispatch on callability rather than on ``isinstance(..., int)``:
    # integer-like values that are not ``int`` subclasses (for example numpy
    # integers) would otherwise be invoked as functions and fail with an
    # unrelated "object is not callable" TypeError.
    if callable(constraints_split_k):
        return constraints_split_k(batch_size, m, n, k, out_dtype)

    return constraints_split_k


def make_default_opt_flags_amd(
out_dtype,
lhs_dtype,
Expand Down Expand Up @@ -91,7 +108,8 @@ def make_default_opt_flags_amd(
is_persistent = constraints.get("is_persistent", False)
# split_k:
if constraints.get("split_k", None) is not None:
split_k = constraints["split_k"]
split_k = get_split_k_from_constraints(
constraints["split_k"], batch_size, m, n, k, out_dtype)
elif is_persistent or enforce_bitwise_invariance:
split_k = 1
else:
Expand Down Expand Up @@ -221,7 +239,8 @@ def make_default_opt_flags_nvidia(
block_n, block_k = block_k, block_n
# split_k
if constraints.get("split_k", None) is not None:
split_k = constraints["split_k"]
split_k = get_split_k_from_constraints(
constraints["split_k"], batch_size, m, n, k, out_dtype)
elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
split_k = 1
else:
Expand Down Expand Up @@ -293,7 +312,7 @@ def make_default_opt_flags_nvidia(
_opt_flags_constraints: dict = dict()
_opt_flags: OptFlags | None = None

def update_opt_flags_constraints(constraints: dict[str, int | Callable]):
    """Merge ``constraints`` into the module-level constraints dictionary.

    Keys already present are overwritten by the new values; keys not named
    in ``constraints`` are left as they are.
    """
    # The dictionary is mutated in place and never rebound, so no ``global``
    # declaration is required here.
    _opt_flags_constraints.update(constraints)

Expand Down