diff options
Diffstat (limited to 'src/compiler/nir')
-rw-r--r-- | src/compiler/nir/nir.h | 1 | ||||
-rw-r--r-- | src/compiler/nir/nir_lower_subgroups.c | 36 |
2 files changed, 13 insertions, 24 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index c047ab7512b..6d28a8b3223 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -2484,6 +2484,7 @@ bool nir_lower_samplers_as_deref(nir_shader *shader, const struct gl_shader_program *shader_program); typedef struct nir_lower_subgroups_options { + uint8_t subgroup_size; uint8_t ballot_bit_size; bool lower_to_scalar:1; bool lower_vote_trivial:1; diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index 76e831691ee..a99ffe2ea99 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -96,28 +96,6 @@ lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) } static nir_ssa_def * -high_subgroup_mask(nir_builder *b, - nir_ssa_def *count, - uint64_t base_mask, - unsigned bit_size) -{ - /* group_mask could probably be calculated more efficiently but we want to - * be sure not to shift by 64 if the subgroup size is 64 because the GLSL - * shift operator is undefined in that case. In any case if we were worried - * about efficency this should probably be done further down because the - * subgroup size is likely to be known at compile time. - */ - nir_ssa_def *subgroup_size = nir_load_subgroup_size(b); - nir_ssa_def *all_bits = nir_imm_intN_t(b, ~0ull, bit_size); - nir_ssa_def *shift = nir_isub(b, nir_imm_int(b, 64), subgroup_size); - nir_ssa_def *group_mask = nir_ushr(b, all_bits, shift); - nir_ssa_def *higher_bits = - nir_ishl(b, nir_imm_intN_t(b, base_mask, bit_size), count); - - return nir_iand(b, higher_bits, group_mask); -} - -static nir_ssa_def * lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, const nir_lower_subgroups_options *options) { @@ -133,6 +111,11 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, return nir_imm_int(b, NIR_TRUE); break; + case nir_intrinsic_load_subgroup_size: + if (options->subgroup_size) + return nir_imm_int(b, options->subgroup_size); + break; + case nir_intrinsic_read_invocation: case nir_intrinsic_read_first_invocation: if (options->lower_to_scalar && intrin->num_components > 1) @@ -154,6 +137,9 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, const unsigned bit_size = MAX2(options->ballot_bit_size, intrin->dest.ssa.bit_size); + assert(options->subgroup_size <= 64); + uint64_t group_mask = ~0ull >> (64 - options->subgroup_size); + nir_ssa_def *count = nir_load_subgroup_invocation(b); nir_ssa_def *val; switch (intrin->intrinsic) { @@ -161,10 +147,12 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count); break; case nir_intrinsic_load_subgroup_ge_mask: - val = high_subgroup_mask(b, count, ~0ull, bit_size); + val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count), + nir_imm_intN_t(b, group_mask, bit_size)); break; case nir_intrinsic_load_subgroup_gt_mask: - val = high_subgroup_mask(b, count, ~1ull, bit_size); + val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count), + nir_imm_intN_t(b, group_mask, bit_size)); break; case nir_intrinsic_load_subgroup_le_mask: val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count)); |