Skip to content

Commit 4e2b80c

Browse files
authored
Add WaveActiveMax tests (#429)
Adds WaveActiveMax tests. Fixes #122
1 parent 9c27399 commit 4e2b80c

File tree

6 files changed

+1892
-0
lines changed

6 files changed

+1892
-0
lines changed
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
#--- source.hlsl
2+
// Number of value sets walked by the outer loop in main.
#define VALUE_SETS 2
3+
// Number of masks walked per value set by the inner loop in main.
#define NUM_MASKS 4
4+
// Thread-group width (see numthreads) and the length of each mask array.
#define NUM_THREADS 4
5+
6+
// One flag per thread in the group. main only performs the wave operations
// and output writes for a thread whose entry is nonzero.
struct MaskStruct {
7+
int mask[NUM_THREADS];
8+
};
9+
10+
// Input values; main reads one half4 per (value set, mask, thread).
StructuredBuffer<half4> In : register(t0);
11+
RWStructuredBuffer<half> Out1 : register(u1); // test scalar
12+
RWStructuredBuffer<half2> Out2 : register(u2); // test half2
13+
RWStructuredBuffer<half4> Out3 : register(u3); // test half3 (half4 storage; main writes only .xyz)
14+
RWStructuredBuffer<half4> Out4 : register(u4); // test half4
15+
RWStructuredBuffer<half4> Out5 : register(u5); // constant folding
16+
// Per-mask thread flags, read in main as Masks[MaskIdx].mask[tid.x].
StructuredBuffer<MaskStruct> Masks : register(t6);
17+
18+
19+
// Compute entry point, run by one group of NUM_THREADS threads.
//
// For every value set and every mask, each thread loads its own half4 from In.
// If Masks flags the thread for that mask, it writes WaveActiveMax of the
// scalar, half2, half3 and half4 views of that value to Out1..Out4 at the
// same flat index. Threads whose flag is zero skip those writes, so this
// shader never touches their output slots.
//
// After the loops every thread stores WaveActiveMax of a literal half4 into
// Out5[0] (the constant folding case).
[numthreads(NUM_THREADS,1,1)]
void main(uint3 tid : SV_GroupThreadID)
{
  for (uint ValueSet = 0; ValueSet < VALUE_SETS; ValueSet++) {
    // Index of the first element that belongs to this value set.
    const uint ValueSetOffset = ValueSet * NUM_MASKS * NUM_THREADS;
    for (uint MaskIdx = 0; MaskIdx < NUM_MASKS; MaskIdx++) {
      // In and Out1..Out4 share one flat layout, so a single index serves
      // both. ValueSetOffset already carries the ValueSet factor; scaling it
      // by ValueSet a second time was only correct for value sets 0 and 1.
      const uint Idx = ValueSetOffset + MaskIdx * NUM_THREADS + tid.x;
      half4 v = In[Idx];
      if (Masks[MaskIdx].mask[tid.x]) {
        Out1[Idx] = WaveActiveMax( v.x );
        Out2[Idx].xy = WaveActiveMax( v.xy );
        Out3[Idx].xyz = WaveActiveMax( v.xyz );
        Out4[Idx] = WaveActiveMax( v );
      }
    }
  }

  // constant folding case
  Out5[0] = WaveActiveMax(half4(1,2,3,4));
}
39+
40+
41+
//--- pipeline.yaml
42+
43+
---
44+
Shaders:
45+
- Stage: Compute
46+
Entry: main
47+
DispatchSize: [1, 1, 1]
48+
Buffers:
49+
- Name: In
50+
Format: Float16
51+
Stride: 8
52+
# 2 value sets
53+
# For each value set,
54+
# and for each specific one of the 4 thread masks in that value set,
55+
# and for each of the 4 threads in that thread mask,
56+
# there will be a unique set of 4 values, such that
57+
# none of the other threads in that thread mask share any values
58+
Data: [
59+
0x2000, 0x2200, 0x2400, 0x2800, # <-- Value set 0, thread mask 0, thread id 0 will read these In values
60+
0x2A00, 0x2C00, 0x2E00, 0x3000, # <-- Value set 0, thread mask 0, thread id 1 will read these In values
61+
0x3200, 0x3400, 0x3600, 0x3800,
62+
0x3900, 0x3A00, 0x3B00, 0x3BC0,
63+
0x2200, 0x2400, 0x2800, 0x2A00, # <-- Value set 0, thread mask 1, thread id 0 will read these In values
64+
0x2C00, 0x2E00, 0x3000, 0x3200,
65+
0x3400, 0x3600, 0x3800, 0x3900,
66+
0x3A00, 0x3B00, 0x3BC0, 0x2000,
67+
0x2400, 0x2800, 0x2A00, 0x2C00,
68+
0x2E00, 0x3000, 0x3200, 0x3400,
69+
0x3600, 0x3800, 0x3900, 0x3A00,
70+
0x3B00, 0x3BC0, 0x2000, 0x2200,
71+
0x2800, 0x2A00, 0x2C00, 0x2E00,
72+
0x3000, 0x3200, 0x3400, 0x3600,
73+
0x3800, 0x3900, 0x3A00, 0x3B00,
74+
0x3BC0, 0x2000, 0x2200, 0x2400,
75+
0x2800, 0x2400, 0x2200, 0x2000, # <-- Value set 1, thread mask 0, thread id 0 will read these In values
76+
0x3000, 0x2E00, 0x2C00, 0x2A00,
77+
0x3800, 0x3600, 0x3400, 0x3200,
78+
0x3BC0, 0x3B00, 0x3A00, 0x3900,
79+
0x2A00, 0x2800, 0x2400, 0x2200,
80+
0x3200, 0x3000, 0x2E00, 0x2C00,
81+
0x3900, 0x3800, 0x3600, 0x3400,
82+
0x2000, 0x3BC0, 0x3B00, 0x3A00,
83+
0x2C00, 0x2A00, 0x2800, 0x2400,
84+
0x3400, 0x3200, 0x3000, 0x2E00,
85+
0x3A00, 0x3900, 0x3800, 0x3600,
86+
0x2200, 0x2000, 0x3BC0, 0x3B00,
87+
0x2E00, 0x2C00, 0x2A00, 0x2800,
88+
0x3600, 0x3400, 0x3200, 0x3000,
89+
0x3B00, 0x3A00, 0x3900, 0x3800,
90+
0x2400, 0x2200, 0x2000, 0x3BC0 ]
91+
92+
- Name: Out1
93+
Format: Float16
94+
Stride: 2
95+
# 1 half is 2 bytes, * 4 halves for 4 threads, * 4 thread masks, * 2 value sets
96+
FillSize: 64
97+
- Name: Out2
98+
Format: Float16
99+
Stride: 4
100+
FillSize: 128
101+
- Name: Out3
102+
Format: Float16
103+
Stride: 8
104+
FillSize: 256
105+
- Name: Out4
106+
Format: Float16
107+
Stride: 8
108+
FillSize: 256
109+
- Name: Out5
110+
Format: Float16
111+
Stride: 8
112+
FillSize: 8
113+
- Name: Masks
114+
Format: Int32
115+
Stride: 16
116+
# 4 active mask sets for threads 0, 1, 2, 3:
117+
# 0 0 0 0
118+
# 1 1 1 1
119+
# 1 0 0 0
120+
# 0 1 1 0
121+
Data: [
122+
0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0]
123+
- Name: ExpectedOut1
124+
Format: Float16
125+
Stride: 8
126+
# 2 value sets, 4 masks per value set, 4 threads per mask, 1 result value per thread
127+
Data: [ 0x0, 0x0, 0x0, 0x0,
128+
0x3A00, 0x3A00, 0x3A00, 0x3A00,
129+
0x2400, 0x0, 0x0, 0x0,
130+
0x0, 0x3800, 0x3800, 0x0,
131+
0x0, 0x0, 0x0, 0x0,
132+
0x3900, 0x3900, 0x3900, 0x3900,
133+
0x2C00, 0x0, 0x0, 0x0,
134+
0x0, 0x3B00, 0x3B00, 0x0 ]
135+
- Name: ExpectedOut2
136+
Format: Float16
137+
Stride: 8
138+
# 2 value sets, 4 masks per value set, 4 threads per mask, 2 result values per thread
139+
Data: [ 0x0, 0x0, 0x0, 0x0,
140+
0x0, 0x0, 0x0, 0x0,
141+
0x3A00, 0x3B00, 0x3A00, 0x3B00,
142+
0x3A00, 0x3B00, 0x3A00, 0x3B00,
143+
0x2400, 0x2800, 0x0, 0x0,
144+
0x0, 0x0, 0x0, 0x0,
145+
0x0, 0x0, 0x3800, 0x3900,
146+
0x3800, 0x3900, 0x0, 0x0,
147+
0x0, 0x0, 0x0, 0x0,
148+
0x0, 0x0, 0x0, 0x0,
149+
0x3900, 0x3BC0, 0x3900, 0x3BC0,
150+
0x3900, 0x3BC0, 0x3900, 0x3BC0,
151+
0x2C00, 0x2A00, 0x0, 0x0,
152+
0x0, 0x0, 0x0, 0x0,
153+
0x0, 0x0, 0x3B00, 0x3A00,
154+
0x3B00, 0x3A00, 0x0, 0x0 ]
155+
- Name: ExpectedOut3
156+
Format: Float16
157+
Stride: 8
158+
# 2 value sets, 4 masks per value set, 4 threads per mask, 4 stored values per thread (3 results + 1 padding element)
159+
# Note, vecs of 3 must be aligned, so the 3 result values are placed into a 4 element vec
160+
Data: [ 0x0, 0x0, 0x0, 0x0,
161+
0x0, 0x0, 0x0, 0x0,
162+
0x0, 0x0, 0x0, 0x0,
163+
0x0, 0x0, 0x0, 0x0,
164+
0x3A00, 0x3B00, 0x3BC0, 0x0,
165+
0x3A00, 0x3B00, 0x3BC0, 0x0,
166+
0x3A00, 0x3B00, 0x3BC0, 0x0,
167+
0x3A00, 0x3B00, 0x3BC0, 0x0,
168+
0x2400, 0x2800, 0x2A00, 0x0,
169+
0x0, 0x0, 0x0, 0x0,
170+
0x0, 0x0, 0x0, 0x0,
171+
0x0, 0x0, 0x0, 0x0,
172+
0x0, 0x0, 0x0, 0x0,
173+
0x3800, 0x3900, 0x3A00, 0x0,
174+
0x3800, 0x3900, 0x3A00, 0x0,
175+
0x0, 0x0, 0x0, 0x0,
176+
0x0, 0x0, 0x0, 0x0,
177+
0x0, 0x0, 0x0, 0x0,
178+
0x0, 0x0, 0x0, 0x0,
179+
0x0, 0x0, 0x0, 0x0,
180+
0x3900, 0x3BC0, 0x3B00, 0x0,
181+
0x3900, 0x3BC0, 0x3B00, 0x0,
182+
0x3900, 0x3BC0, 0x3B00, 0x0,
183+
0x3900, 0x3BC0, 0x3B00, 0x0,
184+
0x2C00, 0x2A00, 0x2800, 0x0,
185+
0x0, 0x0, 0x0, 0x0,
186+
0x0, 0x0, 0x0, 0x0,
187+
0x0, 0x0, 0x0, 0x0,
188+
0x0, 0x0, 0x0, 0x0,
189+
0x3B00, 0x3A00, 0x3900, 0x0,
190+
0x3B00, 0x3A00, 0x3900, 0x0,
191+
0x0, 0x0, 0x0, 0x0 ]
192+
- Name: ExpectedOut4
193+
Format: Float16
194+
Stride: 8
195+
Data: [ 0x0, 0x0, 0x0, 0x0,
196+
0x0, 0x0, 0x0, 0x0,
197+
0x0, 0x0, 0x0, 0x0,
198+
0x0, 0x0, 0x0, 0x0,
199+
0x3A00, 0x3B00, 0x3BC0, 0x3900,
200+
0x3A00, 0x3B00, 0x3BC0, 0x3900,
201+
0x3A00, 0x3B00, 0x3BC0, 0x3900,
202+
0x3A00, 0x3B00, 0x3BC0, 0x3900,
203+
0x2400, 0x2800, 0x2A00, 0x2C00,
204+
0x0, 0x0, 0x0, 0x0,
205+
0x0, 0x0, 0x0, 0x0,
206+
0x0, 0x0, 0x0, 0x0,
207+
0x0, 0x0, 0x0, 0x0,
208+
0x3800, 0x3900, 0x3A00, 0x3B00,
209+
0x3800, 0x3900, 0x3A00, 0x3B00,
210+
0x0, 0x0, 0x0, 0x0,
211+
0x0, 0x0, 0x0, 0x0,
212+
0x0, 0x0, 0x0, 0x0,
213+
0x0, 0x0, 0x0, 0x0,
214+
0x0, 0x0, 0x0, 0x0,
215+
0x3900, 0x3BC0, 0x3B00, 0x3A00,
216+
0x3900, 0x3BC0, 0x3B00, 0x3A00,
217+
0x3900, 0x3BC0, 0x3B00, 0x3A00,
218+
0x3900, 0x3BC0, 0x3B00, 0x3A00,
219+
0x2C00, 0x2A00, 0x2800, 0x2400,
220+
0x0, 0x0, 0x0, 0x0,
221+
0x0, 0x0, 0x0, 0x0,
222+
0x0, 0x0, 0x0, 0x0,
223+
0x0, 0x0, 0x0, 0x0,
224+
0x3B00, 0x3A00, 0x3900, 0x3800,
225+
0x3B00, 0x3A00, 0x3900, 0x3800,
226+
0x0, 0x0, 0x0, 0x0 ]
227+
- Name: ExpectedOut5
228+
Format: Float16
229+
Stride: 8
230+
Data: [ 0x3C00, 0x4000, 0x4200, 0x4400 ]
231+
Results:
232+
- Result: ExpectedOut1
233+
Rule: BufferExact
234+
Actual: Out1
235+
Expected: ExpectedOut1
236+
- Result: ExpectedOut2
237+
Rule: BufferExact
238+
Actual: Out2
239+
Expected: ExpectedOut2
240+
- Result: ExpectedOut3
241+
Rule: BufferExact
242+
Actual: Out3
243+
Expected: ExpectedOut3
244+
- Result: ExpectedOut4
245+
Rule: BufferExact
246+
Actual: Out4
247+
Expected: ExpectedOut4
248+
- Result: ExpectedOut5
249+
Rule: BufferExact
250+
Actual: Out5
251+
Expected: ExpectedOut5
252+
DescriptorSets:
253+
- Resources:
254+
- Name: In
255+
Kind: StructuredBuffer
256+
DirectXBinding:
257+
Register: 0
258+
Space: 0
259+
VulkanBinding:
260+
Binding: 0
261+
- Name: Out1
262+
Kind: RWStructuredBuffer
263+
DirectXBinding:
264+
Register: 1
265+
Space: 0
266+
VulkanBinding:
267+
Binding: 1
268+
- Name: Out2
269+
Kind: RWStructuredBuffer
270+
DirectXBinding:
271+
Register: 2
272+
Space: 0
273+
VulkanBinding:
274+
Binding: 2
275+
- Name: Out3
276+
Kind: RWStructuredBuffer
277+
DirectXBinding:
278+
Register: 3
279+
Space: 0
280+
VulkanBinding:
281+
Binding: 3
282+
- Name: Out4
283+
Kind: RWStructuredBuffer
284+
DirectXBinding:
285+
Register: 4
286+
Space: 0
287+
VulkanBinding:
288+
Binding: 4
289+
- Name: Out5
290+
Kind: RWStructuredBuffer
291+
DirectXBinding:
292+
Register: 5
293+
Space: 0
294+
VulkanBinding:
295+
Binding: 5
296+
- Name: Masks
297+
Kind: StructuredBuffer
298+
DirectXBinding:
299+
Register: 6
300+
Space: 0
301+
VulkanBinding:
302+
Binding: 6
303+
304+
...
305+
#--- end
306+
307+
# Bug https://github.com/llvm/llvm-project/issues/156775
308+
# XFAIL: Clang
309+
310+
# Bug https://github.com/llvm/offload-test-suite/issues/393
311+
# XFAIL: Metal
312+
313+
# RUN: split-file %s %t
314+
# RUN: %dxc_target -enable-16bit-types -T cs_6_5 -Fo %t.o %t/source.hlsl
315+
# RUN: %offloader %t/pipeline.yaml %t.o

0 commit comments

Comments
 (0)