summaryrefslogtreecommitdiffstats
path: root/src/compiler/nir/nir_lower_subgroups.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler/nir/nir_lower_subgroups.c')
-rw-r--r--src/compiler/nir/nir_lower_subgroups.c92
1 files changed, 82 insertions, 10 deletions
diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index 0d11dc9c23a..76e831691ee 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -28,6 +28,42 @@
* \file nir_opt_intrinsics.c
*/
+/* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
+static nir_ssa_def *
+uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
+ unsigned num_components, unsigned bit_size)
+{
+ assert(value->num_components == 1);
+ assert(value->bit_size == 32 || value->bit_size == 64);
+
+ nir_ssa_def *zero = nir_imm_int(b, 0);
+ if (num_components > 1) {
+ /* SPIR-V uses a uvec4 for ballot values */
+ assert(num_components == 4);
+ assert(bit_size == 32);
+
+ if (value->bit_size == 32) {
+ return nir_vec4(b, value, zero, zero, zero);
+ } else {
+ assert(value->bit_size == 64);
+ return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
+ nir_unpack_64_2x32_split_y(b, value),
+ zero, zero);
+ }
+ } else {
+ /* GLSL uses a uint64_t for ballot values */
+ assert(num_components == 1);
+ assert(bit_size == 64);
+
+ if (value->bit_size == 32) {
+ return nir_pack_64_2x32_split(b, value, zero);
+ } else {
+ assert(value->bit_size == 64);
+ return value;
+ }
+ }
+}
+
static nir_ssa_def *
lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
{
@@ -62,7 +98,8 @@ lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
static nir_ssa_def *
high_subgroup_mask(nir_builder *b,
nir_ssa_def *count,
- uint64_t base_mask)
+ uint64_t base_mask,
+ unsigned bit_size)
{
/* group_mask could probably be calculated more efficiently but we want to
* be sure not to shift by 64 if the subgroup size is 64 because the GLSL
@@ -71,10 +108,11 @@ high_subgroup_mask(nir_builder *b,
* subgroup size is likely to be known at compile time.
*/
nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);
- nir_ssa_def *all_bits = nir_imm_int64(b, ~0ull);
+ nir_ssa_def *all_bits = nir_imm_intN_t(b, ~0ull, bit_size);
nir_ssa_def *shift = nir_isub(b, nir_imm_int(b, 64), subgroup_size);
nir_ssa_def *group_mask = nir_ushr(b, all_bits, shift);
- nir_ssa_def *higher_bits = nir_ishl(b, nir_imm_int64(b, base_mask), count);
+ nir_ssa_def *higher_bits =
+ nir_ishl(b, nir_imm_intN_t(b, base_mask, bit_size), count);
return nir_iand(b, higher_bits, group_mask);
}
@@ -109,24 +147,58 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
if (!options->lower_subgroup_masks)
return NULL;
- nir_ssa_def *count = nir_load_subgroup_invocation(b);
+ /* If either the result or the requested bit size is 64-bits then we
+ * know that we have 64-bit types and using them will probably be more
+ * efficient than messing around with 32-bit shifts and packing.
+ */
+ const unsigned bit_size = MAX2(options->ballot_bit_size,
+ intrin->dest.ssa.bit_size);
+ nir_ssa_def *count = nir_load_subgroup_invocation(b);
+ nir_ssa_def *val;
switch (intrin->intrinsic) {
case nir_intrinsic_load_subgroup_eq_mask:
- return nir_ishl(b, nir_imm_int64(b, 1ull), count);
+ val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
+ break;
case nir_intrinsic_load_subgroup_ge_mask:
- return high_subgroup_mask(b, count, ~0ull);
+ val = high_subgroup_mask(b, count, ~0ull, bit_size);
+ break;
case nir_intrinsic_load_subgroup_gt_mask:
- return high_subgroup_mask(b, count, ~1ull);
+ val = high_subgroup_mask(b, count, ~1ull, bit_size);
+ break;
case nir_intrinsic_load_subgroup_le_mask:
- return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~1ull), count));
+ val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
+ break;
case nir_intrinsic_load_subgroup_lt_mask:
- return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~0ull), count));
+ val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
+ break;
default:
unreachable("you seriously can't tell this is unreachable?");
}
- break;
+
+ return uint_to_ballot_type(b, val,
+ intrin->dest.ssa.num_components,
+ intrin->dest.ssa.bit_size);
+ }
+
+ case nir_intrinsic_ballot: {
+ if (intrin->dest.ssa.num_components == 1 &&
+ intrin->dest.ssa.bit_size == options->ballot_bit_size)
+ return NULL;
+
+ nir_intrinsic_instr *ballot =
+ nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
+ ballot->num_components = 1;
+ nir_ssa_dest_init(&ballot->instr, &ballot->dest,
+ 1, options->ballot_bit_size, NULL);
+ nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
+ nir_builder_instr_insert(b, &ballot->instr);
+
+ return uint_to_ballot_type(b, &ballot->dest.ssa,
+ intrin->dest.ssa.num_components,
+ intrin->dest.ssa.bit_size);
}
+
default:
break;
}