Skip to content

Commit 17a17f7

Browse files
Initial precompiled shaders implementation (#7834)
Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
1 parent eb9b2e9 commit 17a17f7

File tree

26 files changed

+383
-315
lines changed

26 files changed

+383
-315
lines changed

CHANGELOG.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,27 @@ We have merged the acceleration structure feature into the `RayQuery` feature. T
5353

5454
By @Vecvec in [#7913](https://github.com/gfx-rs/wgpu/pull/7913).
5555

56+
#### New `EXPERIMENTAL_PRECOMPILED_SHADERS` API
57+
We have added `Features::EXPERIMENTAL_PRECOMPILED_SHADERS`, replacing existing passthrough types with a unified `CreateShaderModuleDescriptorPassthrough` which allows passing multiple shader codes for different backends. By @SupaMaggie70Incorporated in [#7834](https://github.com/gfx-rs/wgpu/pull/7834)
58+
59+
Difference for SPIR-V passthrough:
60+
```diff
61+
- device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV(
62+
- wgpu::ShaderModuleDescriptorSpirV {
63+
- label: None,
64+
- source: spirv_code,
65+
- },
66+
- ))
67+
+ device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
68+
+ entry_point: "main".into(),
69+
+ label: None,
70+
+ spirv: Some(spirv_code),
71+
+ ..Default::default()
72+
})
73+
```
74+
This allows using precompiled shaders without manually checking which backend's code to pass, for example if you have shaders precompiled for both DXIL and SPIR-V.
75+
76+
5677
### New Features
5778

5879
#### General

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

deno_webgpu/webidl.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -419,10 +419,6 @@ pub enum GPUFeatureName {
419419
VertexWritableStorage,
420420
#[webidl(rename = "clear-texture")]
421421
ClearTexture,
422-
#[webidl(rename = "msl-shader-passthrough")]
423-
MslShaderPassthrough,
424-
#[webidl(rename = "spirv-shader-passthrough")]
425-
SpirvShaderPassthrough,
426422
#[webidl(rename = "multiview")]
427423
Multiview,
428424
#[webidl(rename = "vertex-attribute-64-bit")]
@@ -435,6 +431,8 @@ pub enum GPUFeatureName {
435431
ShaderPrimitiveIndex,
436432
#[webidl(rename = "shader-early-depth-test")]
437433
ShaderEarlyDepthTest,
434+
#[webidl(rename = "passthrough-shaders")]
435+
PassthroughShaders,
438436
}
439437

440438
pub fn feature_names_to_features(names: Vec<GPUFeatureName>) -> wgpu_types::Features {
@@ -482,14 +480,13 @@ pub fn feature_names_to_features(names: Vec<GPUFeatureName>) -> wgpu_types::Feat
482480
GPUFeatureName::ConservativeRasterization => Features::CONSERVATIVE_RASTERIZATION,
483481
GPUFeatureName::VertexWritableStorage => Features::VERTEX_WRITABLE_STORAGE,
484482
GPUFeatureName::ClearTexture => Features::CLEAR_TEXTURE,
485-
GPUFeatureName::MslShaderPassthrough => Features::MSL_SHADER_PASSTHROUGH,
486-
GPUFeatureName::SpirvShaderPassthrough => Features::SPIRV_SHADER_PASSTHROUGH,
487483
GPUFeatureName::Multiview => Features::MULTIVIEW,
488484
GPUFeatureName::VertexAttribute64Bit => Features::VERTEX_ATTRIBUTE_64BIT,
489485
GPUFeatureName::ShaderF64 => Features::SHADER_F64,
490486
GPUFeatureName::ShaderI16 => Features::SHADER_I16,
491487
GPUFeatureName::ShaderPrimitiveIndex => Features::SHADER_PRIMITIVE_INDEX,
492488
GPUFeatureName::ShaderEarlyDepthTest => Features::SHADER_EARLY_DEPTH_TEST,
489+
GPUFeatureName::PassthroughShaders => Features::EXPERIMENTAL_PASSTHROUGH_SHADERS,
493490
};
494491
features.set(feature, true);
495492
}
@@ -626,9 +623,6 @@ pub fn features_to_feature_names(features: wgpu_types::Features) -> HashSet<GPUF
626623
if features.contains(wgpu_types::Features::CLEAR_TEXTURE) {
627624
return_features.insert(ClearTexture);
628625
}
629-
if features.contains(wgpu_types::Features::SPIRV_SHADER_PASSTHROUGH) {
630-
return_features.insert(SpirvShaderPassthrough);
631-
}
632626
if features.contains(wgpu_types::Features::MULTIVIEW) {
633627
return_features.insert(Multiview);
634628
}
@@ -648,6 +642,9 @@ pub fn features_to_feature_names(features: wgpu_types::Features) -> HashSet<GPUF
648642
if features.contains(wgpu_types::Features::SHADER_EARLY_DEPTH_TEST) {
649643
return_features.insert(ShaderEarlyDepthTest);
650644
}
645+
if features.contains(wgpu_types::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS) {
646+
return_features.insert(PassthroughShaders);
647+
}
651648

652649
return_features
653650
}

examples/features/src/mesh_shader/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ fn compile_glsl(
2424
let output = cmd.wait_with_output().expect("Error waiting for glslc");
2525
assert!(output.status.success());
2626
unsafe {
27-
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV(
28-
wgpu::ShaderModuleDescriptorSpirV {
29-
label: None,
30-
source: wgpu::util::make_spirv_raw(&output.stdout),
31-
},
32-
))
27+
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
28+
entry_point: "main".into(),
29+
label: None,
30+
spirv: Some(wgpu::util::make_spirv_raw(&output.stdout)),
31+
..Default::default()
32+
})
3333
}
3434
}
3535

@@ -119,7 +119,7 @@ impl crate::framework::Example for Example {
119119
Default::default()
120120
}
121121
fn required_features() -> wgpu::Features {
122-
wgpu::Features::EXPERIMENTAL_MESH_SHADER | wgpu::Features::SPIRV_SHADER_PASSTHROUGH
122+
wgpu::Features::EXPERIMENTAL_MESH_SHADER | wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS
123123
}
124124
fn required_limits() -> wgpu::Limits {
125125
wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()

player/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ log.workspace = true
2626
raw-window-handle.workspace = true
2727
ron.workspace = true
2828
winit = { workspace = true, optional = true }
29+
bytemuck.workspace = true
2930

3031
# Non-Webassembly
3132
#

player/src/lib.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,84 @@ impl GlobalPlay for wgc::global::Global {
315315
println!("shader compilation error:\n---{code}\n---\n{e}");
316316
}
317317
}
318+
Action::CreateShaderModulePassthrough {
319+
id,
320+
data,
321+
entry_point,
322+
label,
323+
num_workgroups,
324+
runtime_checks,
325+
} => {
326+
let spirv = data.iter().find_map(|a| {
327+
if a.ends_with(".spv") {
328+
let data = fs::read(dir.join(a)).unwrap();
329+
assert!(data.len() % 4 == 0);
330+
331+
Some(Cow::Owned(bytemuck::pod_collect_to_vec(&data)))
332+
} else {
333+
None
334+
}
335+
});
336+
let dxil = data.iter().find_map(|a| {
337+
if a.ends_with(".dxil") {
338+
let vec = std::fs::read(dir.join(a)).unwrap();
339+
Some(Cow::Owned(vec))
340+
} else {
341+
None
342+
}
343+
});
344+
let hlsl = data.iter().find_map(|a| {
345+
if a.ends_with(".hlsl") {
346+
let code = fs::read_to_string(dir.join(a)).unwrap();
347+
Some(Cow::Owned(code))
348+
} else {
349+
None
350+
}
351+
});
352+
let msl = data.iter().find_map(|a| {
353+
if a.ends_with(".msl") {
354+
let code = fs::read_to_string(dir.join(a)).unwrap();
355+
Some(Cow::Owned(code))
356+
} else {
357+
None
358+
}
359+
});
360+
let glsl = data.iter().find_map(|a| {
361+
if a.ends_with(".glsl") {
362+
let code = fs::read_to_string(dir.join(a)).unwrap();
363+
Some(Cow::Owned(code))
364+
} else {
365+
None
366+
}
367+
});
368+
let wgsl = data.iter().find_map(|a| {
369+
if a.ends_with(".wgsl") {
370+
let code = fs::read_to_string(dir.join(a)).unwrap();
371+
Some(Cow::Owned(code))
372+
} else {
373+
None
374+
}
375+
});
376+
let desc = wgt::CreateShaderModuleDescriptorPassthrough {
377+
entry_point,
378+
label,
379+
num_workgroups,
380+
runtime_checks,
381+
382+
spirv,
383+
dxil,
384+
hlsl,
385+
msl,
386+
glsl,
387+
wgsl,
388+
};
389+
let (_, error) = unsafe {
390+
self.device_create_shader_module_passthrough(device, &desc, Some(id))
391+
};
392+
if let Some(e) = error {
393+
println!("shader compilation error: {e}");
394+
}
395+
}
318396
Action::DestroyShaderModule(id) => {
319397
self.shader_module_drop(id);
320398
}

tests/tests/wgpu-gpu/mesh_shader/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ fn compile_glsl(
4141
let output = cmd.wait_with_output().expect("Error waiting for glslc");
4242
assert!(output.status.success());
4343
unsafe {
44-
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV(
45-
wgpu::ShaderModuleDescriptorSpirV {
46-
label: None,
47-
source: wgpu::util::make_spirv_raw(&output.stdout),
48-
},
49-
))
44+
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
45+
entry_point: "main".into(),
46+
label: None,
47+
spirv: Some(wgpu::util::make_spirv_raw(&output.stdout)),
48+
..Default::default()
49+
})
5050
}
5151
}
5252

@@ -267,7 +267,7 @@ fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration {
267267
.test_features_limits()
268268
.features(
269269
wgpu::Features::EXPERIMENTAL_MESH_SHADER
270-
| wgpu::Features::SPIRV_SHADER_PASSTHROUGH
270+
| wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS
271271
| match draw_type {
272272
DrawType::Standard | DrawType::Indirect => wgpu::Features::empty(),
273273
DrawType::MultiIndirect => wgpu::Features::MULTI_DRAW_INDIRECT,

wgpu-core/src/device/global.rs

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,36 +1094,27 @@ impl Global {
10941094

10951095
#[cfg(feature = "trace")]
10961096
if let Some(ref mut trace) = *device.trace.lock() {
1097-
let data = trace.make_binary(desc.trace_binary_ext(), desc.trace_data());
1098-
trace.add(trace::Action::CreateShaderModule {
1097+
let mut file_names = Vec::new();
1098+
for (data, ext) in [
1099+
(desc.spirv.as_ref().map(|a| bytemuck::cast_slice(a)), "spv"),
1100+
(desc.dxil.as_deref(), "dxil"),
1101+
(desc.hlsl.as_ref().map(|a| a.as_bytes()), "hlsl"),
1102+
(desc.msl.as_ref().map(|a| a.as_bytes()), "msl"),
1103+
(desc.glsl.as_ref().map(|a| a.as_bytes()), "glsl"),
1104+
(desc.wgsl.as_ref().map(|a| a.as_bytes()), "wgsl"),
1105+
] {
1106+
if let Some(data) = data {
1107+
file_names.push(trace.make_binary(ext, data));
1108+
}
1109+
}
1110+
trace.add(trace::Action::CreateShaderModulePassthrough {
10991111
id: fid.id(),
1100-
desc: match desc {
1101-
pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => {
1102-
pipeline::ShaderModuleDescriptor {
1103-
label: inner.label.clone(),
1104-
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
1105-
}
1106-
}
1107-
pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => {
1108-
pipeline::ShaderModuleDescriptor {
1109-
label: inner.label.clone(),
1110-
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
1111-
}
1112-
}
1113-
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
1114-
pipeline::ShaderModuleDescriptor {
1115-
label: inner.label.clone(),
1116-
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
1117-
}
1118-
}
1119-
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
1120-
pipeline::ShaderModuleDescriptor {
1121-
label: inner.label.clone(),
1122-
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
1123-
}
1124-
}
1125-
},
1126-
data,
1112+
data: file_names,
1113+
1114+
entry_point: desc.entry_point.clone(),
1115+
label: desc.label.clone(),
1116+
num_workgroups: desc.num_workgroups,
1117+
runtime_checks: desc.runtime_checks,
11271118
});
11281119
};
11291120

@@ -1138,7 +1129,7 @@ impl Global {
11381129
return (id, None);
11391130
};
11401131

1141-
let id = fid.assign(Fallible::Invalid(Arc::new(desc.label().to_string())));
1132+
let id = fid.assign(Fallible::Invalid(Arc::new(desc.label.to_string())));
11421133
(id, Some(error))
11431134
}
11441135

0 commit comments

Comments
 (0)