diff options
-rw-r--r-- | src/compiler/nir/nir_intrinsics.py | 2 | ||||
-rw-r--r-- | src/compiler/nir/nir_lower_system_values.c | 45 |
2 files changed, 29 insertions, 18 deletions
diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index 1d388c64fc9..d53b26c88d6 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -532,7 +532,7 @@ system_value("subgroup_lt_mask", 0, bit_sizes=[32, 64]) system_value("num_subgroups", 1) system_value("subgroup_id", 1) system_value("local_group_size", 3) -system_value("global_invocation_id", 3) +system_value("global_invocation_id", 3, bit_sizes=[32, 64]) system_value("work_dim", 1) # Driver-specific viewport scale/offset parameters. # diff --git a/src/compiler/nir/nir_lower_system_values.c b/src/compiler/nir/nir_lower_system_values.c index 68b0ea89c8d..de5ccab0f38 100644 --- a/src/compiler/nir/nir_lower_system_values.c +++ b/src/compiler/nir/nir_lower_system_values.c @@ -29,7 +29,7 @@ #include "nir_builder.h" static nir_ssa_def* -build_local_group_size(nir_builder *b) +build_local_group_size(nir_builder *b, unsigned bit_size) { nir_ssa_def *local_size; @@ -40,6 +40,8 @@ build_local_group_size(nir_builder *b) if (b->shader->info.cs.local_size_variable) { local_size = nir_load_local_group_size(b); } else { + /* using a 32 bit constant is safe here as no device/driver needs more + * than 32 bits for the local size */ nir_const_value local_size_const; memset(&local_size_const, 0, sizeof(local_size_const)); local_size_const.u32[0] = b->shader->info.cs.local_size[0]; @@ -48,12 +50,15 @@ build_local_group_size(nir_builder *b) local_size = nir_build_imm(b, 3, 32, local_size_const); } - return local_size; + return nir_u2u(b, local_size, bit_size); } static nir_ssa_def * -build_local_invocation_id(nir_builder *b) +build_local_invocation_id(nir_builder *b, unsigned bit_size) { + /* If lower_cs_local_id_from_index is true, then we derive the local + * id from the local index. 
+ */ if (b->shader->options->lower_cs_local_id_from_index) { /* We lower gl_LocalInvocationID from gl_LocalInvocationIndex based * on this formula: @@ -73,8 +78,12 @@ build_local_invocation_id(nir_builder *b) * large so it can safely be omitted. */ nir_ssa_def *local_index = nir_load_local_invocation_index(b); - nir_ssa_def *local_size = build_local_group_size(b); + nir_ssa_def *local_size = build_local_group_size(b, 32); + /* Because no hardware supports a local workgroup size greater than + * about 1K, this calculation can be done in 32-bit and can save some + * 64-bit arithmetic. + */ nir_ssa_def *id_x, *id_y, *id_z; id_x = nir_umod(b, local_index, nir_channel(b, local_size, 0)); @@ -84,9 +93,9 @@ build_local_invocation_id(nir_builder *b) id_z = nir_udiv(b, local_index, nir_imul(b, nir_channel(b, local_size, 0), nir_channel(b, local_size, 1))); - return nir_vec3(b, id_x, id_y, id_z); + return nir_u2u(b, nir_vec3(b, id_x, id_y, id_z), bit_size); } else { - return nir_load_local_invocation_id(b); + return nir_u2u(b, nir_load_local_invocation_id(b), bit_size); } } @@ -120,6 +129,7 @@ convert_block(nir_block *block, nir_builder *b) b->cursor = nir_after_instr(&load_deref->instr); + unsigned bit_size = nir_dest_bit_size(load_deref->dest); nir_ssa_def *sysval = NULL; switch (var->data.location) { case SYSTEM_VALUE_GLOBAL_INVOCATION_ID: { @@ -128,9 +138,9 @@ convert_block(nir_block *block, nir_builder *b) * "The value of gl_GlobalInvocationID is equal to * gl_WorkGroupID * gl_WorkGroupSize + gl_LocalInvocationID" */ - nir_ssa_def *group_size = build_local_group_size(b); - nir_ssa_def *group_id = nir_load_work_group_id(b); - nir_ssa_def *local_id = build_local_invocation_id(b); + nir_ssa_def *group_size = build_local_group_size(b, bit_size); + nir_ssa_def *group_id = nir_u2u(b, nir_load_work_group_id(b), bit_size); + nir_ssa_def *local_id = build_local_invocation_id(b, bit_size); sysval = nir_iadd(b, nir_imul(b, group_id, group_size), local_id); break; @@ -157,24 
+167,25 @@ convert_block(nir_block *block, nir_builder *b) nir_ssa_def *size_y = nir_imm_int(b, b->shader->info.cs.local_size[1]); + /* Because no hardware supports a local workgroup size greater than + * about 1K, this calculation can be done in 32-bit and can save some + * 64-bit arithmetic. + */ sysval = nir_imul(b, nir_channel(b, local_id, 2), nir_imul(b, size_x, size_y)); sysval = nir_iadd(b, sysval, nir_imul(b, nir_channel(b, local_id, 1), size_x)); sysval = nir_iadd(b, sysval, nir_channel(b, local_id, 0)); + sysval = nir_u2u(b, sysval, bit_size); break; } case SYSTEM_VALUE_LOCAL_INVOCATION_ID: - /* If lower_cs_local_id_from_index is true, then we derive the local - * index from the local id. - */ - if (b->shader->options->lower_cs_local_id_from_index) - sysval = build_local_invocation_id(b); + sysval = build_local_invocation_id(b, bit_size); break; case SYSTEM_VALUE_LOCAL_GROUP_SIZE: { - sysval = build_local_group_size(b); + sysval = build_local_group_size(b, bit_size); break; } @@ -248,8 +259,8 @@ convert_block(nir_block *block, nir_builder *b) break; case SYSTEM_VALUE_GLOBAL_GROUP_SIZE: { - nir_ssa_def *group_size = build_local_group_size(b); - nir_ssa_def *num_work_groups = nir_load_num_work_groups(b); + nir_ssa_def *group_size = build_local_group_size(b, bit_size); + nir_ssa_def *num_work_groups = nir_u2u(b, nir_load_num_work_groups(b), bit_size); sysval = nir_imul(b, group_size, num_work_groups); break; } |