// Per-draw uniform data shared by every fragment.
[[block]]
struct Globals {
    // Only .x is read in this file: the number of active lights
    // (fs_main clamps it to c_max_lights).
    // NOTE(review): .yzw are unused here; presumably padding so the member is a
    // full 16-byte vector in the uniform buffer — confirm against host layout.
    num_lights: vec4<u32>;
};

[[group(0), binding(0)]]
var<uniform> u_globals: Globals;
8
// One light as stored in the s_lights storage buffer.
// Size is 64 (mat4x4<f32>) + 16 (vec4<f32>) + 16 (vec4<f32>) = 96 bytes, which is
// the array stride declared in Lights below; the host must pack lights the same way.
struct Light {
    // Transform applied to the fragment position to get this light's shadow-map
    // clip coordinates (consumed by fetch_shadow).
    proj: mat4x4<f32>;
    // Light position; only .xyz is read by fs_main.
    pos: vec4<f32>;
    // Light color; only .xyz (RGB) is read by fs_main.
    color: vec4<f32>;
};

// Runtime-sized array of lights, bound read-only as s_lights.
[[block]]
struct Lights {
    data: [[stride(96)]] array<Light>;
};
19
// Light list, indexed by light id in fs_main. Read-only in this shader.
[[group(0), binding(1)]]
var<storage> s_lights: [[access(read)]] Lights;
// Shadow maps: one depth layer per light, selected by light id in fetch_shadow.
[[group(0), binding(2)]]
var t_shadow: texture_depth_2d_array;
// Comparison sampler for the shadow depth test in fetch_shadow; the compare
// function itself is configured on the host side.
[[group(0), binding(3)]]
var sampler_shadow: sampler_comparison;
26
// Returns the shadow-comparison result for one light, in [0.0, 1.0]; fs_main
// scales that light's diffuse contribution by this value.
//
// light_id           - light index; also selects the layer of t_shadow.
// homogeneous_coords - the fragment position multiplied by that light's `proj`
//                      matrix (light clip space, before the perspective divide).
//
// Fragments with w <= 0.0 have no valid projection into the shadow map and are
// treated as unshadowed (1.0).
//
// The comparison sample is issued unconditionally. textureSampleCompare uses
// implicit derivatives, which are undefined when the call is reached through
// per-fragment (non-uniform) control flow such as an early return on w. The
// w <= 0.0 case is therefore resolved after sampling; any coordinates produced
// by dividing by a non-positive w are simply discarded.
fn fetch_shadow(light_id: u32, homogeneous_coords: vec4<f32>) -> f32 {
    // flip Y between clip space and texture space while scaling [-1, 1] to [-0.5, 0.5]
    let flip_correction = vec2<f32>(0.5, -0.5);
    // perspective divide, then shift into [0, 1] texture coordinates
    let light_local = homogeneous_coords.xy * flip_correction / homogeneous_coords.w + vec2<f32>(0.5, 0.5);
    // hardware depth comparison against layer `light_id`, at the projected depth
    var visibility: f32 = textureSampleCompare(t_shadow, sampler_shadow, light_local, i32(light_id), homogeneous_coords.z / homogeneous_coords.w);
    if (homogeneous_coords.w <= 0.0) {
        // not in front of the light's projection: no shadow information
        visibility = 1.0;
    }
    return visibility;
}
35
// Constant ambient light; fs_main starts its color accumulation from this value.
let c_ambient: vec3<f32> = vec3<f32>(0.05, 0.05, 0.05);
// Upper bound on lights processed per fragment, regardless of u_globals.num_lights.x.
// NOTE(review): presumably matches the number of layers in t_shadow — confirm against host setup.
let c_max_lights: u32 = 10u;
38
// Fragment entry point: constant ambient term plus, for each active light, a
// Lambertian diffuse term scaled by that light's shadow-map visibility.
//
// raw_normal - interpolated surface normal; re-normalized because interpolation
//              does not preserve unit length.
// position   - interpolated fragment position in the same space as the light
//              `proj` matrices and `pos` fields (presumably world space with
//              w == 1 — TODO confirm against the vertex stage; .xyz is used
//              directly without a divide).
// Returns the accumulated RGB light with alpha fixed at 1.0.
[[stage(fragment)]]
fn fs_main(
    [[location(0)]] raw_normal: vec3<f32>,
    [[location(1)]] position: vec4<f32>
) -> [[location(0)]] vec4<f32> {
    let normal: vec3<f32> = normalize(raw_normal);
    // The light count comes from a uniform, so it is the same for every
    // iteration; compute it once. Clamping to c_max_lights bounds the loop even
    // if the host reports more lights.
    let light_count: u32 = min(u_globals.num_lights.x, c_max_lights);
    // accumulate color, starting from the ambient term
    var color: vec3<f32> = c_ambient;
    var i: u32 = 0u;
    loop {
        if (i >= light_count) {
            break;
        }
        let light = s_lights.data[i];
        // shadow-map visibility for this light
        let shadow = fetch_shadow(i, light.proj * position);
        // Lambertian term: cosine between the normal and the direction to the light, clamped at 0
        let light_dir = normalize(light.pos.xyz - position.xyz);
        let diffuse = max(0.0, dot(normal, light_dir));
        color = color + shadow * diffuse * light.color.xyz;
        continuing {
            i = i + 1u;
        }
    }
    // No material color is applied here: the accumulated light is the output color.
    return vec4<f32>(color, 1.0);
}
64