Skip to content

Commit 6088dd8

Browse files
fyellinalmarklein
andauthored
overrideable constants (#579)
* Initial implementation of: overrideable constants; eliding entry_point when only one. * Fix lint and codegen issues. * More test info. Fix leak * Finish up tests for override. --------- Co-authored-by: Almar Klein <almar@almarklein.org>
1 parent dfa41d8 commit 6088dd8

File tree

3 files changed

+261
-19
lines changed

3 files changed

+261
-19
lines changed

tests/test_set_override.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import pytest
2+
3+
import wgpu.utils
4+
from tests.testutils import can_use_wgpu_lib, run_tests
5+
from wgpu import TextureFormat
6+
7+
if not can_use_wgpu_lib:
8+
pytest.skip("Skipping tests that need the wgpu lib", allow_module_level=True)
9+
10+
11+
"""
12+
The vertex shader should be called exactly once, which then calls the fragment shader
13+
exactly once. Alternatively, we call the compute shader exactly once
14+
15+
This copies the values of the four variables a, b, c, and d as seen by each of the shaders
16+
and writes them into a buffer. We can then examine that buffer to see the values of the
17+
constants.
18+
19+
This code is also showing that you no longer need to include the name of a shader when
20+
it is the only shader of that type.
21+
"""
22+
23+
SHADER_SOURCE = """
24+
override a: i32 = 1;
25+
override b: u32 = 2u;
26+
@id(1) override c: f32 = 3.0;
27+
@id(2) override d: bool = false;
28+
29+
// Put the results here
30+
@group(0) @binding(0) var<storage, read_write> data: array<u32>;
31+
32+
struct VertexOutput {
33+
@location(0) values: vec4u,
34+
@builtin(position) position: vec4f,
35+
}
36+
37+
@vertex
38+
fn vertex(@builtin(vertex_index) index: u32) -> VertexOutput {
39+
var output: VertexOutput;
40+
output.position = vec4f(0, 0, 0, 1);
41+
output.values = vec4u(u32(a), u32(b), u32(c), u32(d));
42+
return output;
43+
}
44+
45+
@fragment
46+
fn fragment(output: VertexOutput) -> @location(0) vec4f {
47+
let values1 = output.values;
48+
let values2 = vec4u(u32(a), u32(b), u32(c), u32(d));
49+
write_results(values1, values2);
50+
return vec4f();
51+
}
52+
53+
@compute @workgroup_size(1)
54+
fn computeMain() {
55+
let results = vec4u(u32(a), u32(b), u32(c), u32(d));
56+
write_results(results, results);
57+
}
58+
59+
fn write_results(results1: vec4u, results2: vec4u) {
60+
for (var i = 0; i < 4; i++) {
61+
data[i] = results1[i];
62+
data[i + 4] = results2[i];
63+
}
64+
}
65+
"""
66+
67+
BIND_GROUP_ENTRIES = [
68+
{"binding": 0, "visibility": "FRAGMENT|COMPUTE", "buffer": {"type": "storage"}},
69+
]
70+
71+
72+
class Runner:
    """Shared GPU state for the override-constant tests.

    Holds the device, the shader module built from ``SHADER_SOURCE``, a
    pipeline layout with a single storage-buffer binding, the output buffer
    the shaders write into, and a color attachment for render passes.
    """

    def __init__(self):
        gpu = wgpu.utils.get_default_device()
        self.device = gpu

        # Actual size is immaterial. Could just be 1x1
        self.output_texture = gpu.create_texture(
            size=[128, 128],
            format=TextureFormat.rgba8unorm,
            usage="RENDER_ATTACHMENT|COPY_SRC",
        )
        self.shader = gpu.create_shader_module(code=SHADER_SOURCE)

        layout = gpu.create_bind_group_layout(entries=BIND_GROUP_ENTRIES)
        # Despite its name, this layout is used by the compute pipeline too.
        self.render_pipeline_layout = gpu.create_pipeline_layout(
            bind_group_layouts=[layout],
        )

        # Room for eight 32-bit values: four per shader stage that reports.
        self.output_buffer = gpu.create_buffer(
            size=8 * 4, usage="STORAGE|COPY_SRC"
        )
        buffer_binding = {"binding": 0, "resource": {"buffer": self.output_buffer}}
        self.bind_group = gpu.create_bind_group(
            layout=layout,
            entries=[buffer_binding],
        )

        self.color_attachment = {
            "clear_value": (0, 0, 0, 0),
            "load_op": "clear",
            "store_op": "store",
            "view": self.output_texture.create_view(),
        }

    def create_render_pipeline(self, vertex_constants, fragment_constants):
        """Build a point-list render pipeline with the given override constants.

        No ``entry_point`` is passed for either stage.
        """
        vertex_state = {
            "module": self.shader,
            "constants": vertex_constants,
        }
        fragment_state = {
            "module": self.shader,
            "targets": [{"format": self.output_texture.format}],
            "constants": fragment_constants,
        }
        return self.device.create_render_pipeline(
            layout=self.render_pipeline_layout,
            vertex=vertex_state,
            fragment=fragment_state,
            primitive={"topology": "point-list"},
        )

    def create_compute_pipeline(self, constants):
        """Build a compute pipeline with the given override constants.

        No ``entry_point`` is passed for the compute stage.
        """
        compute_stage = {
            "module": self.shader,
            "constants": constants,
        }
        return self.device.create_compute_pipeline(
            layout=self.render_pipeline_layout,
            compute=compute_stage,
        )

    def run_test(
        self,
        *,
        render: bool = False,
        compute: bool = False,
        vertex_constants=None,
        fragment_constants=None,
        compute_constants=None
    ):
        """Run one render or one compute pass and return the reported values.

        Exactly one of ``render`` and ``compute`` must be true. The output
        buffer is read back as a list of unsigned 32-bit integers; for a
        compute run only the first four values are returned.
        """
        assert render + compute == 1
        gpu = self.device
        command_encoder = gpu.create_command_encoder()

        # The pass is begun before its pipeline is created, for both kinds.
        if render:
            gpu_pass = command_encoder.begin_render_pass(
                color_attachments=[self.color_attachment]
            )
            pipeline = self.create_render_pipeline(
                vertex_constants, fragment_constants
            )
        else:
            gpu_pass = command_encoder.begin_compute_pass()
            pipeline = self.create_compute_pipeline(compute_constants)

        gpu_pass.set_bind_group(0, self.bind_group)
        gpu_pass.set_pipeline(pipeline)
        if render:
            gpu_pass.draw(1)
        else:
            gpu_pass.dispatch_workgroups(1)
        gpu_pass.end()
        gpu.queue.submit([command_encoder.finish()])

        raw = gpu.queue.read_buffer(self.output_buffer)
        values = raw.cast("I").tolist()
        result = values[:4] if compute else values
        print(result)
        return result
161+
162+
163+
@pytest.fixture(scope="module")
def runner():
    """Provide a single ``Runner`` shared by every test in this module."""
    shared_runner = Runner()
    return shared_runner
166+
167+
168+
def test_no_overridden_constants_render(runner):
    """With no overrides, both render stages report the shader defaults."""
    defaults = [1, 2, 3, 0]
    result = runner.run_test(render=True)
    assert result == defaults + defaults
170+
171+
172+
def test_no_constants_compute(runner):
    """With no overrides, the compute shader reports the shader defaults.

    The comparison must be asserted; as a bare expression its result is
    discarded and the test could never fail.
    """
    assert runner.run_test(compute=True) == [1, 2, 3, 0]
174+
175+
176+
def test_override_vertex_constants(runner):
    """Overrides given to the vertex stage leave the fragment stage untouched."""
    # Note that setting "d" to any non-zero value is setting it to True
    overrides = {"a": 21, "b": 22, 1: 23, 2: 24}
    result = runner.run_test(render=True, vertex_constants=overrides)
    # First four values come from the vertex stage, last four from the fragment stage.
    assert [21, 22, 23, 1, 1, 2, 3, 0] == result
182+
183+
184+
def test_override_fragment_constants(runner):
    """Overrides given to the fragment stage leave the vertex stage untouched."""
    # Note that setting "d" to any non-zero value is setting it to True
    overrides = {"a": 21, "b": 22, 1: 23, 2: -1}
    result = runner.run_test(render=True, fragment_constants=overrides)
    # First four values come from the vertex stage, last four from the fragment stage.
    assert [1, 2, 3, 0, 21, 22, 23, 1] == result
190+
191+
192+
def test_override_compute_constants(runner):
    """Overrides given to the compute stage show up in its reported values."""
    # Note that setting "d" to any non-zero value is setting it to True
    overrides = {"a": 21, "b": 22, 1: 23, 2: 24}
    result = runner.run_test(compute=True, compute_constants=overrides)
    assert [21, 22, 23, 1] == result
196+
197+
198+
def test_numbered_constants_must_be_overridden_by_number(runner):
    """Constants declared with ``@id(n)`` are not overridden via their names."""
    overrides = {"c": 23, "d": 24}
    # This does absolutely nothing. It doesn't even error.
    result = runner.run_test(
        render=True,
        vertex_constants=overrides,
        fragment_constants=overrides,
    )
    assert [1, 2, 3, 0, 1, 2, 3, 0] == result
204+
205+
206+
if __name__ == "__main__":
207+
run_tests(globals())

wgpu/backends/wgpu_native/_api.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import ctypes.util
2121
from weakref import WeakKeyDictionary
22-
from typing import List, Dict, Union
22+
from typing import List, Dict, Optional, Union
2323

2424
from ... import classes, flags, enums, structs
2525
from ..._coreutils import str_flag_to_int
@@ -172,6 +172,36 @@ def _tuple_from_color(rgba):
172172
return _tuple_from_tuple_or_dict(rgba, "rgba")
173173

174174

175+
def _get_override_constant_entries(field):
    """Convert the ``"constants"`` entry of a stage dict into C constant entries.

    Returns a ``(c_array, entries)`` pair. When ``field`` has no constants
    (missing or empty) the pair is ``(ffi.NULL, [])``. Otherwise ``c_array``
    is a ``WGPUConstantEntry[]`` and ``entries`` is the list of structs it
    was built from.

    Keys must be ``str`` or ``int`` and are passed as their ``str()`` form;
    values must be ``int``, ``float`` or ``bool`` and are passed as floats.
    """
    overrides = field.get("constants")
    if not overrides:
        return ffi.NULL, []

    entries = []
    for name, number in overrides.items():
        assert isinstance(name, (str, int))
        assert isinstance(number, (int, float, bool))
        # H: nextInChain: WGPUChainedStruct *, key: char *, value: float
        entry = new_struct(
            "WGPUConstantEntry",
            key=to_c_string(str(name)),
            value=float(number),
            # not used: nextInChain
        )
        entries.append(entry)

    # The list of entries is returned alongside the array so that the caller
    # holds onto it; this prevents the C strings from being GC'ed.
    c_array = ffi.new("WGPUConstantEntry[]", entries)
    return c_array, entries
195+
196+
197+
def to_c_string(string: str):
    """Return a new cffi ``char []`` initialised from the encoded *string*."""
    encoded = string.encode()
    return ffi.new("char []", encoded)
199+
200+
201+
def to_c_string_or_null(string: Optional[str]):
    """Return ``ffi.NULL`` for ``None``, else a new cffi ``char []`` for *string*."""
    if string is None:
        return ffi.NULL
    return ffi.new("char []", string.encode())
203+
204+
175205
_empty_label = ffi.new("char []", b"")
176206

177207

@@ -180,7 +210,7 @@ def to_c_label(label):
180210
if not label:
181211
return _empty_label
182212
else:
183-
return ffi.new("char []", label.encode())
213+
return to_c_string(label)
184214

185215

186216
def feature_flag_to_feature_names(flag):
@@ -945,7 +975,7 @@ def canonicalize_limit_name(name):
945975

946976
c_trace_path = ffi.NULL
947977
if trace_path: # no-cover
948-
c_trace_path = ffi.new("char []", trace_path.encode())
978+
c_trace_path = to_c_string(trace_path)
949979

950980
# H: chain: WGPUChainedStruct, tracePath: char *
951981
extras = new_struct_p(
@@ -1485,15 +1515,15 @@ def create_shader_module(
14851515
# H: name: char *, value: char *
14861516
new_struct(
14871517
"WGPUShaderDefine",
1488-
name=ffi.new("char []", "gl_VertexID".encode()),
1489-
value=ffi.new("char []", "gl_VertexIndex".encode()),
1518+
name=ffi.new("char []", b"gl_VertexID"),
1519+
value=ffi.new("char []", b"gl_VertexIndex"),
14901520
)
14911521
)
14921522
c_defines = ffi.new("WGPUShaderDefine []", defines)
14931523
# H: chain: WGPUChainedStruct, stage: WGPUShaderStage, code: char *, defineCount: int, defines: WGPUShaderDefine *
14941524
source_struct = new_struct_p(
14951525
"WGPUShaderModuleGLSLDescriptor *",
1496-
code=ffi.new("char []", code.encode()),
1526+
code=to_c_string(code),
14971527
stage=c_stage,
14981528
defineCount=len(defines),
14991529
defines=c_defines,
@@ -1506,7 +1536,7 @@ def create_shader_module(
15061536
# H: chain: WGPUChainedStruct, code: char *
15071537
source_struct = new_struct_p(
15081538
"WGPUShaderModuleWGSLDescriptor *",
1509-
code=ffi.new("char []", code.encode()),
1539+
code=to_c_string(code),
15101540
# not used: chain
15111541
)
15121542
source_struct[0].chain.next = ffi.NULL
@@ -1558,14 +1588,15 @@ def create_compute_pipeline(
15581588
compute: "structs.ProgrammableStage",
15591589
):
15601590
check_struct("ProgrammableStage", compute)
1591+
c_constants, c_constant_entries = _get_override_constant_entries(compute)
15611592
# H: nextInChain: WGPUChainedStruct *, module: WGPUShaderModule, entryPoint: char *, constantCount: int, constants: WGPUConstantEntry *
15621593
c_compute_stage = new_struct(
15631594
"WGPUProgrammableStageDescriptor",
15641595
module=compute["module"]._internal,
1565-
entryPoint=ffi.new("char []", compute["entry_point"].encode()),
1596+
entryPoint=to_c_string_or_null(compute.get("entry_point")),
1597+
constantCount=len(c_constant_entries),
1598+
constants=c_constants,
15661599
# not used: nextInChain
1567-
# not used: constantCount
1568-
# not used: constants
15691600
)
15701601

15711602
if isinstance(layout, GPUPipelineLayout):
@@ -1643,16 +1674,17 @@ def create_render_pipeline(
16431674
c_vertex_buffer_descriptors_array = ffi.new(
16441675
"WGPUVertexBufferLayout []", c_vertex_buffer_layout_list
16451676
)
1677+
c_vertex_constants, c_vertex_entries = _get_override_constant_entries(vertex)
16461678
# H: nextInChain: WGPUChainedStruct *, module: WGPUShaderModule, entryPoint: char *, constantCount: int, constants: WGPUConstantEntry *, bufferCount: int, buffers: WGPUVertexBufferLayout *
16471679
c_vertex_state = new_struct(
16481680
"WGPUVertexState",
16491681
module=vertex["module"]._internal,
1650-
entryPoint=ffi.new("char []", vertex["entry_point"].encode()),
1682+
entryPoint=to_c_string_or_null(vertex.get("entry_point")),
16511683
buffers=c_vertex_buffer_descriptors_array,
16521684
bufferCount=len(c_vertex_buffer_layout_list),
1685+
constantCount=len(c_vertex_entries),
1686+
constants=c_vertex_constants,
16531687
# not used: nextInChain
1654-
# not used: constantCount
1655-
# not used: constants
16561688
)
16571689

16581690
# H: nextInChain: WGPUChainedStruct *, topology: WGPUPrimitiveTopology, stripIndexFormat: WGPUIndexFormat, frontFace: WGPUFrontFace, cullMode: WGPUCullMode
@@ -1753,16 +1785,19 @@ def create_render_pipeline(
17531785
"WGPUColorTargetState []", c_color_targets_list
17541786
)
17551787
check_struct("FragmentState", fragment)
1788+
c_fragment_constants, c_fragment_entries = _get_override_constant_entries(
1789+
fragment
1790+
)
17561791
# H: nextInChain: WGPUChainedStruct *, module: WGPUShaderModule, entryPoint: char *, constantCount: int, constants: WGPUConstantEntry *, targetCount: int, targets: WGPUColorTargetState *
17571792
c_fragment_state = new_struct_p(
17581793
"WGPUFragmentState *",
17591794
module=fragment["module"]._internal,
1760-
entryPoint=ffi.new("char []", fragment["entry_point"].encode()),
1795+
entryPoint=to_c_string_or_null(fragment.get("entry_point")),
17611796
targets=c_color_targets_array,
17621797
targetCount=len(c_color_targets_list),
1798+
constantCount=len(c_fragment_entries),
1799+
constants=c_fragment_constants,
17631800
# not used: nextInChain
1764-
# not used: constantCount
1765-
# not used: constants
17661801
)
17671802

17681803
if isinstance(layout, GPUPipelineLayout):
@@ -2315,7 +2350,7 @@ def set_bind_group(
23152350
class GPUDebugCommandsMixin(classes.GPUDebugCommandsMixin):
23162351
# whole class is likely going to be solved better: https://github.com/pygfx/wgpu-py/pull/546
23172352
def push_debug_group(self, group_label):
2318-
c_group_label = ffi.new("char []", group_label.encode())
2353+
c_group_label = to_c_string(group_label)
23192354
# H: void wgpuCommandEncoderPushDebugGroup(WGPUCommandEncoder commandEncoder, char const * groupLabel)
23202355
# H: void wgpuComputePassEncoderPushDebugGroup(WGPUComputePassEncoder computePassEncoder, char const * groupLabel)
23212356
# H: void wgpuRenderPassEncoderPushDebugGroup(WGPURenderPassEncoder renderPassEncoder, char const * groupLabel)
@@ -2332,7 +2367,7 @@ def pop_debug_group(self):
23322367
function(self._internal)
23332368

23342369
def insert_debug_marker(self, marker_label):
2335-
c_marker_label = ffi.new("char []", marker_label.encode())
2370+
c_marker_label = to_c_string(marker_label)
23362371
# H: void wgpuCommandEncoderInsertDebugMarker(WGPUCommandEncoder commandEncoder, char const * markerLabel)
23372372
# H: void wgpuComputePassEncoderInsertDebugMarker(WGPUComputePassEncoder computePassEncoder, char const * markerLabel)
23382373
# H: void wgpuRenderPassEncoderInsertDebugMarker(WGPURenderPassEncoder renderPassEncoder, char const * markerLabel)

wgpu/resources/codegen_report.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@
3737
* Wrote 236 enum mappings and 47 struct-field mappings to wgpu_native/_mappings.py
3838
* Validated 131 C function calls
3939
* Not using 72 C functions
40-
* Validated 80 C structs
40+
* Validated 81 C structs

0 commit comments

Comments
 (0)