diff options
-rw-r--r-- | src/compiler/nir/nir.h | 1 | ||||
-rw-r--r-- | src/compiler/nir/nir_lower_subgroups.c | 36 | ||||
-rw-r--r-- | src/intel/compiler/brw_fs_nir.cpp | 4 | ||||
-rw-r--r-- | src/intel/compiler/brw_nir.c | 2 |
4 files changed, 15 insertions, 28 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index c047ab7512b..6d28a8b3223 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -2484,6 +2484,7 @@ bool nir_lower_samplers_as_deref(nir_shader *shader, const struct gl_shader_program *shader_program); typedef struct nir_lower_subgroups_options { + uint8_t subgroup_size; uint8_t ballot_bit_size; bool lower_to_scalar:1; bool lower_vote_trivial:1; diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index 76e831691ee..a99ffe2ea99 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -96,28 +96,6 @@ lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) } static nir_ssa_def * -high_subgroup_mask(nir_builder *b, - nir_ssa_def *count, - uint64_t base_mask, - unsigned bit_size) -{ - /* group_mask could probably be calculated more efficiently but we want to - * be sure not to shift by 64 if the subgroup size is 64 because the GLSL - * shift operator is undefined in that case. In any case if we were worried - * about efficency this should probably be done further down because the - * subgroup size is likely to be known at compile time. - */ - nir_ssa_def *subgroup_size = nir_load_subgroup_size(b); - nir_ssa_def *all_bits = nir_imm_intN_t(b, ~0ull, bit_size); - nir_ssa_def *shift = nir_isub(b, nir_imm_int(b, 64), subgroup_size); - nir_ssa_def *group_mask = nir_ushr(b, all_bits, shift); - nir_ssa_def *higher_bits = - nir_ishl(b, nir_imm_intN_t(b, base_mask, bit_size), count); - - return nir_iand(b, higher_bits, group_mask); -} - -static nir_ssa_def * lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, const nir_lower_subgroups_options *options) { @@ -133,6 +111,11 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, return nir_imm_int(b, NIR_TRUE); break; + case nir_intrinsic_load_subgroup_size: + if (options->subgroup_size) + return nir_imm_int(b, options->subgroup_size); + break; + case nir_intrinsic_read_invocation: case nir_intrinsic_read_first_invocation: if (options->lower_to_scalar && intrin->num_components > 1) @@ -154,6 +137,9 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, const unsigned bit_size = MAX2(options->ballot_bit_size, intrin->dest.ssa.bit_size); + assert(options->subgroup_size <= 64); + uint64_t group_mask = ~0ull >> (64 - options->subgroup_size); + nir_ssa_def *count = nir_load_subgroup_invocation(b); nir_ssa_def *val; switch (intrin->intrinsic) { @@ -161,10 +147,12 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count); break; case nir_intrinsic_load_subgroup_ge_mask: - val = high_subgroup_mask(b, count, ~0ull, bit_size); + val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count), + nir_imm_intN_t(b, group_mask, bit_size)); break; case nir_intrinsic_load_subgroup_gt_mask: - val = high_subgroup_mask(b, count, ~1ull, bit_size); + val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count), + nir_imm_intN_t(b, group_mask, bit_size)); break; case nir_intrinsic_load_subgroup_le_mask: val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count)); diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp index f8970997371..2f47b0253b2 100644 --- a/src/intel/compiler/brw_fs_nir.cpp +++ b/src/intel/compiler/brw_fs_nir.cpp @@ -4184,10 +4184,6 @@ fs_visitor::nir_emit_intrinsic(const fs_builder &bld, nir_intrinsic_instr *instr break; } - case nir_intrinsic_load_subgroup_size: - bld.MOV(retype(dest, BRW_REGISTER_TYPE_D), brw_imm_d(dispatch_width)); - break; - case nir_intrinsic_load_subgroup_invocation: bld.MOV(retype(dest, BRW_REGISTER_TYPE_D), nir_system_values[SYSTEM_VALUE_SUBGROUP_INVOCATION]); diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c index 0d59d36ca63..5ed36fe1bf7 100644 --- a/src/intel/compiler/brw_nir.c +++ b/src/intel/compiler/brw_nir.c @@ -637,6 +637,8 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir) OPT(nir_lower_system_values); const nir_lower_subgroups_options subgroups_options = { + .subgroup_size = nir->info.stage == MESA_SHADER_COMPUTE ? 32 : + nir->info.stage == MESA_SHADER_FRAGMENT ? 16 : 8, .ballot_bit_size = 32, .lower_to_scalar = true, .lower_subgroup_masks = true, |