@@ -37,7 +37,9 @@ struct ProblemSize {
3737 std::vector<int64_t > stride;
3838 std::vector<int64_t > dilation;
3939 bool operator ==(const ProblemSize& ps) const {
40- return activation_shape == ps.activation_shape &&
40+ return activation_shape[1 ] == ps.activation_shape [1 ] &&
41+ activation_shape[2 ] == ps.activation_shape [2 ] &&
42+ activation_shape[3 ] == ps.activation_shape [3 ] &&
4143 filter_shape == ps.filter_shape ;
4244 }
4345 void print () const {
@@ -53,8 +55,7 @@ struct ProblemSize {
5355 << filter_shape[1 ] << " ,"
5456 << filter_shape[2 ] << " ,"
5557 << filter_shape[3 ] << " ,"
56- << filter_shape[4 ] << " ,"
57- << std::endl;
58+ << filter_shape[4 ] << " ," ;
5859 // clang-format on
5960 }
6061};
@@ -67,17 +68,20 @@ inline void hash_combine(std::size_t& seed, std::size_t value) {
6768struct ProblemSizeHash {
6869 std::size_t operator ()(const ProblemSize& ps) const {
6970 std::size_t seed = 0 ;
71+ // Only hash spatial dimensions (D, H, W) from activation_shape, not batch
72+ // (N) or channels (C)
73+ hash_combine (seed, std::hash<int64_t >{}(ps.activation_shape [1 ]));
74+ hash_combine (seed, std::hash<int64_t >{}(ps.activation_shape [2 ]));
75+ hash_combine (seed, std::hash<int64_t >{}(ps.activation_shape [3 ]));
76+ // Hash the entire filter_shape
7077 auto vec_hash = [](const std::vector<int64_t >& v) {
7178 std::size_t h = 0 ;
7279 for (auto x : v)
7380 hash_combine (h, std::hash<int64_t >{}(x));
7481 return h;
7582 };
76- hash_combine (seed, vec_hash (ps.activation_shape ));
7783 hash_combine (seed, vec_hash (ps.filter_shape ));
78- // hash_combine(seed, vec_hash(ps.padding));
79- // hash_combine(seed, vec_hash(ps.stride));
80- // hash_combine(seed, vec_hash(ps.dilation));
84+ // Exclude padding, stride, and dilation from hash
8185 return seed;
8286 }
8387};
@@ -132,8 +136,9 @@ Kernel_f8f8bf16_conv get_kernel_via_heuristic(
132136 if (it != kernel_map.end ()) {
133137 return it->second ;
134138 } else {
135- std::cout << " warning: not found" ;
139+ std::cout << " warning: not found - " ;
136140 ps.print ();
141+ std::cout << std::endl;
137142 }
138143 // Fallback kernel
139144 return f8f8bf16_conv_256x256x128_2x1x1;
0 commit comments