summaryrefslogtreecommitdiffstats
path: root/src/compiler/spirv/vtn_subgroup.c
diff options
context:
space:
mode:
authorJason Ekstrand <[email protected]>2017-08-22 16:53:05 -0700
committerJason Ekstrand <[email protected]>2018-03-07 12:13:47 -0800
commit9812fce60b6ffbcd136b66bfb609143449ad3f7c (patch)
treefde13b6279321f095490fa3b74fc8564052d3564 /src/compiler/spirv/vtn_subgroup.c
parent974daec495eae05b3c3179cd6c131a65ff2efcc7 (diff)
spirv: Add subgroup ballot support
Reviewed-by: Iago Toral Quiroga <[email protected]>
Diffstat (limited to 'src/compiler/spirv/vtn_subgroup.c')
-rw-r--r--src/compiler/spirv/vtn_subgroup.c143
1 files changed, 135 insertions, 8 deletions
diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c
index 033c43e601c..a86f0cb2832 100644
--- a/src/compiler/spirv/vtn_subgroup.c
+++ b/src/compiler/spirv/vtn_subgroup.c
@@ -23,6 +23,44 @@
#include "vtn_private.h"
+static void
+vtn_build_subgroup_instr(struct vtn_builder *b,
+ nir_intrinsic_op nir_op,
+ struct vtn_ssa_value *dst,
+ struct vtn_ssa_value *src0,
+ nir_ssa_def *index)
+{
+ /* Some of the subgroup operations take an index. SPIR-V allows this to be
+ * any integer type. To make things simpler for drivers, we only support
+ * 32-bit indices.
+ */
+ if (index && index->bit_size != 32)
+ index = nir_u2u32(&b->nb, index);
+
+ vtn_assert(dst->type == src0->type);
+ if (!glsl_type_is_vector_or_scalar(dst->type)) {
+ for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
+ vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
+ src0->elems[i], index);
+ }
+ return;
+ }
+
+ nir_intrinsic_instr *intrin =
+ nir_intrinsic_instr_create(b->nb.shader, nir_op);
+ nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+ dst->type, NULL);
+ intrin->num_components = intrin->dest.ssa.num_components;
+
+ intrin->src[0] = nir_src_for_ssa(src0->def);
+ if (index)
+ intrin->src[1] = nir_src_for_ssa(index);
+
+ nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+ dst->def = &intrin->dest.ssa;
+}
+
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@@ -43,17 +81,106 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
break;
}
- case SpvOpGroupNonUniformAll:
- case SpvOpGroupNonUniformAny:
- case SpvOpGroupNonUniformAllEqual:
- case SpvOpGroupNonUniformBroadcast:
- case SpvOpGroupNonUniformBroadcastFirst:
- case SpvOpGroupNonUniformBallot:
- case SpvOpGroupNonUniformInverseBallot:
+ case SpvOpGroupNonUniformBallot: {
+ vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
+ "OpGroupNonUniformBallot must return a uvec4");
+ nir_intrinsic_instr *ballot =
+ nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
+ ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+ nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
+ ballot->num_components = 4;
+ nir_builder_instr_insert(&b->nb, &ballot->instr);
+ val->ssa->def = &ballot->dest.ssa;
+ break;
+ }
+
+ case SpvOpGroupNonUniformInverseBallot: {
+ /* This one is just a BallotBitfieldExtract with subgroup invocation.
+ * We could add a NIR intrinsic but it's easier to just lower it on the
+ * spot.
+ */
+ nir_intrinsic_instr *intrin =
+ nir_intrinsic_instr_create(b->nb.shader,
+ nir_intrinsic_ballot_bitfield_extract);
+
+ intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+ intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
+
+ nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+ nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+ val->ssa->def = &intrin->dest.ssa;
+ break;
+ }
+
case SpvOpGroupNonUniformBallotBitExtract:
case SpvOpGroupNonUniformBallotBitCount:
case SpvOpGroupNonUniformBallotFindLSB:
- case SpvOpGroupNonUniformBallotFindMSB:
+ case SpvOpGroupNonUniformBallotFindMSB: {
+ nir_ssa_def *src0, *src1 = NULL;
+ nir_intrinsic_op op;
+ switch (opcode) {
+ case SpvOpGroupNonUniformBallotBitExtract:
+ op = nir_intrinsic_ballot_bitfield_extract;
+ src0 = vtn_ssa_value(b, w[4])->def;
+ src1 = vtn_ssa_value(b, w[5])->def;
+ break;
+ case SpvOpGroupNonUniformBallotBitCount:
+ switch ((SpvGroupOperation)w[4]) {
+ case SpvGroupOperationReduce:
+ op = nir_intrinsic_ballot_bit_count_reduce;
+ break;
+ case SpvGroupOperationInclusiveScan:
+ op = nir_intrinsic_ballot_bit_count_inclusive;
+ break;
+ case SpvGroupOperationExclusiveScan:
+ op = nir_intrinsic_ballot_bit_count_exclusive;
+ break;
+ default:
+ unreachable("Invalid group operation");
+ }
+ src0 = vtn_ssa_value(b, w[5])->def;
+ break;
+ case SpvOpGroupNonUniformBallotFindLSB:
+ op = nir_intrinsic_ballot_find_lsb;
+ src0 = vtn_ssa_value(b, w[4])->def;
+ break;
+ case SpvOpGroupNonUniformBallotFindMSB:
+ op = nir_intrinsic_ballot_find_msb;
+ src0 = vtn_ssa_value(b, w[4])->def;
+ break;
+ default:
+ unreachable("Unhandled opcode");
+ }
+
+ nir_intrinsic_instr *intrin =
+ nir_intrinsic_instr_create(b->nb.shader, op);
+
+ intrin->src[0] = nir_src_for_ssa(src0);
+ if (src1)
+ intrin->src[1] = nir_src_for_ssa(src1);
+
+ nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+ nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+ val->ssa->def = &intrin->dest.ssa;
+ break;
+ }
+
+ case SpvOpGroupNonUniformBroadcastFirst:
+ vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
+ val->ssa, vtn_ssa_value(b, w[4]), NULL);
+ break;
+
+ case SpvOpGroupNonUniformBroadcast:
+ vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
+ val->ssa, vtn_ssa_value(b, w[4]),
+ vtn_ssa_value(b, w[5])->def);
+ break;
+
+ case SpvOpGroupNonUniformAll:
+ case SpvOpGroupNonUniformAny:
+ case SpvOpGroupNonUniformAllEqual:
case SpvOpGroupNonUniformShuffle:
case SpvOpGroupNonUniformShuffleXor:
case SpvOpGroupNonUniformShuffleUp: