summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/compiler/nir/nir.h1
-rw-r--r--src/compiler/nir/nir_lower_subgroups.c36
-rw-r--r--src/intel/compiler/brw_fs_nir.cpp4
-rw-r--r--src/intel/compiler/brw_nir.c2
4 files changed, 15 insertions, 28 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index c047ab7512b..6d28a8b3223 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -2484,6 +2484,7 @@ bool nir_lower_samplers_as_deref(nir_shader *shader,
const struct gl_shader_program *shader_program);
typedef struct nir_lower_subgroups_options {
+ uint8_t subgroup_size;
uint8_t ballot_bit_size;
bool lower_to_scalar:1;
bool lower_vote_trivial:1;
diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index 76e831691ee..a99ffe2ea99 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -96,28 +96,6 @@ lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
}
static nir_ssa_def *
-high_subgroup_mask(nir_builder *b,
- nir_ssa_def *count,
- uint64_t base_mask,
- unsigned bit_size)
-{
- /* group_mask could probably be calculated more efficiently but we want to
- * be sure not to shift by 64 if the subgroup size is 64 because the GLSL
- * shift operator is undefined in that case. In any case if we were worried
- * about efficency this should probably be done further down because the
- * subgroup size is likely to be known at compile time.
- */
- nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);
- nir_ssa_def *all_bits = nir_imm_intN_t(b, ~0ull, bit_size);
- nir_ssa_def *shift = nir_isub(b, nir_imm_int(b, 64), subgroup_size);
- nir_ssa_def *group_mask = nir_ushr(b, all_bits, shift);
- nir_ssa_def *higher_bits =
- nir_ishl(b, nir_imm_intN_t(b, base_mask, bit_size), count);
-
- return nir_iand(b, higher_bits, group_mask);
-}
-
-static nir_ssa_def *
lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
const nir_lower_subgroups_options *options)
{
@@ -133,6 +111,11 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
return nir_imm_int(b, NIR_TRUE);
break;
+ case nir_intrinsic_load_subgroup_size:
+ if (options->subgroup_size)
+ return nir_imm_int(b, options->subgroup_size);
+ break;
+
case nir_intrinsic_read_invocation:
case nir_intrinsic_read_first_invocation:
if (options->lower_to_scalar && intrin->num_components > 1)
@@ -154,6 +137,9 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
const unsigned bit_size = MAX2(options->ballot_bit_size,
intrin->dest.ssa.bit_size);
+ assert(options->subgroup_size <= 64);
+ uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
+
nir_ssa_def *count = nir_load_subgroup_invocation(b);
nir_ssa_def *val;
switch (intrin->intrinsic) {
@@ -161,10 +147,12 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
break;
case nir_intrinsic_load_subgroup_ge_mask:
- val = high_subgroup_mask(b, count, ~0ull, bit_size);
+ val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
+ nir_imm_intN_t(b, group_mask, bit_size));
break;
case nir_intrinsic_load_subgroup_gt_mask:
- val = high_subgroup_mask(b, count, ~1ull, bit_size);
+ val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
+ nir_imm_intN_t(b, group_mask, bit_size));
break;
case nir_intrinsic_load_subgroup_le_mask:
val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp
index f8970997371..2f47b0253b2 100644
--- a/src/intel/compiler/brw_fs_nir.cpp
+++ b/src/intel/compiler/brw_fs_nir.cpp
@@ -4184,10 +4184,6 @@ fs_visitor::nir_emit_intrinsic(const fs_builder &bld, nir_intrinsic_instr *instr
break;
}
- case nir_intrinsic_load_subgroup_size:
- bld.MOV(retype(dest, BRW_REGISTER_TYPE_D), brw_imm_d(dispatch_width));
- break;
-
case nir_intrinsic_load_subgroup_invocation:
bld.MOV(retype(dest, BRW_REGISTER_TYPE_D),
nir_system_values[SYSTEM_VALUE_SUBGROUP_INVOCATION]);
diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c
index 0d59d36ca63..5ed36fe1bf7 100644
--- a/src/intel/compiler/brw_nir.c
+++ b/src/intel/compiler/brw_nir.c
@@ -637,6 +637,8 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir)
OPT(nir_lower_system_values);
const nir_lower_subgroups_options subgroups_options = {
+ .subgroup_size = nir->info.stage == MESA_SHADER_COMPUTE ? 32 :
+ nir->info.stage == MESA_SHADER_FRAGMENT ? 16 : 8,
.ballot_bit_size = 32,
.lower_to_scalar = true,
.lower_subgroup_masks = true,