
Commit d1e0d74

zjing14 authored and meta-codesync[bot] committed
Adjust heuristic to skip comparison of batch (#5196)
Summary:
Pull Request resolved: #5196
X-link: https://github.com/facebookresearch/FBGEMM/pull/2193

- Skip comparison of batch size in heuristic

Reviewed By: jerryzh168

Differential Revision: D88674202

fbshipit-source-id: cd63e7706d53af628f1516808d4d8a5c1ac2180a
1 parent 055dbbc commit d1e0d74

File tree

1 file changed (+13, -8 lines)

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv.cu

Lines changed: 13 additions & 8 deletions
@@ -37,7 +37,9 @@ struct ProblemSize {
   std::vector<int64_t> stride;
   std::vector<int64_t> dilation;
   bool operator==(const ProblemSize& ps) const {
-    return activation_shape == ps.activation_shape &&
+    return activation_shape[1] == ps.activation_shape[1] &&
+        activation_shape[2] == ps.activation_shape[2] &&
+        activation_shape[3] == ps.activation_shape[3] &&
         filter_shape == ps.filter_shape;
   }
   void print() const {
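With this change, two problems that differ only in batch size (activation_shape[0]) compare equal, so the heuristic reuses one cached kernel across batch sizes; channels are still constrained indirectly through filter_shape. A minimal sketch of the new equality semantics, using hypothetical NDHWC shapes:

ProblemSize a, b;
a.activation_shape = {16, 8, 32, 32, 64};  // N=16, D=8, H=32, W=32, C=64
b.activation_shape = {64, 8, 32, 32, 64};  // N=64; same spatial dims D/H/W
a.filter_shape = b.filter_shape = {128, 3, 3, 3, 64};
assert(a == b);  // holds after this change: index 0 (batch) is ignored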
@@ -53,8 +55,7 @@ struct ProblemSize {
         << filter_shape[1] << ","
         << filter_shape[2] << ","
         << filter_shape[3] << ","
-        << filter_shape[4] << ","
-        << std::endl;
+        << filter_shape[4] << ",";
     // clang-format on
   }
 };
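Note the companion change here: std::endl moves out of print() so callers can keep composing the same output line, which the warning path in the last hunk below relies on.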
@@ -67,17 +68,20 @@ inline void hash_combine(std::size_t& seed, std::size_t value) {
 struct ProblemSizeHash {
   std::size_t operator()(const ProblemSize& ps) const {
     std::size_t seed = 0;
+    // Only hash spatial dimensions (D, H, W) from activation_shape, not batch
+    // (N) or channels (C)
+    hash_combine(seed, std::hash<int64_t>{}(ps.activation_shape[1]));
+    hash_combine(seed, std::hash<int64_t>{}(ps.activation_shape[2]));
+    hash_combine(seed, std::hash<int64_t>{}(ps.activation_shape[3]));
+    // Hash the entire filter_shape
     auto vec_hash = [](const std::vector<int64_t>& v) {
       std::size_t h = 0;
       for (auto x : v)
         hash_combine(h, std::hash<int64_t>{}(x));
       return h;
     };
-    hash_combine(seed, vec_hash(ps.activation_shape));
     hash_combine(seed, vec_hash(ps.filter_shape));
-    // hash_combine(seed, vec_hash(ps.padding));
-    // hash_combine(seed, vec_hash(ps.stride));
-    // hash_combine(seed, vec_hash(ps.dilation));
+    // Exclude padding, stride, and dilation from hash
     return seed;
   }
 };
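The hash must stay consistent with operator== above: objects that compare equal must produce equal hashes, so dropping batch from the comparison requires dropping it from the hash as well, which is exactly what this hunk does. The hash_combine helper is declared in the hunk header; a common boost-style definition matching that signature (a sketch, not necessarily FBGEMM's exact constants) is:

inline void hash_combine(std::size_t& seed, std::size_t value) {
  // Mix `value` into `seed`; 0x9e3779b9 is the 32-bit golden-ratio constant
  seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}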
@@ -132,8 +136,9 @@ Kernel_f8f8bf16_conv get_kernel_via_heuristic(
   if (it != kernel_map.end()) {
     return it->second;
   } else {
-    std::cout << "warning: not found";
+    std::cout << "warning: not found - ";
     ps.print();
+    std::cout << std::endl;
   }
   // Fallback kernel
   return f8f8bf16_conv_256x256x128_2x1x1;
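For context, the lookup this hunk edits is an unordered_map keyed by ProblemSize with ProblemSizeHash, which is why the equality and hash changes above control when a cached kernel is reused. A simplified, hypothetical sketch of that pattern (kernel type reduced to a function pointer, map contents omitted):

#include <unordered_map>

using Kernel = void (*)();

std::unordered_map<ProblemSize, Kernel, ProblemSizeHash> kernel_map;

Kernel get_kernel(const ProblemSize& ps, Kernel fallback) {
  auto it = kernel_map.find(ps);
  if (it != kernel_map.end()) {
    return it->second;  // cache hit: now shared across batch sizes
  }
  return fallback;  // miss: use the default kernel
}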
