diff options
Diffstat (limited to 'src/compiler/spirv')
-rw-r--r-- | src/compiler/spirv/spirv_to_nir.c | 5 | ||||
-rw-r--r-- | src/compiler/spirv/vtn_subgroup.c | 143 | ||||
-rw-r--r-- | src/compiler/spirv/vtn_variables.c | 20 |
3 files changed, 160 insertions, 8 deletions
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 4d2c1533d24..38a1df9fd21 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -3296,6 +3296,11 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, spv_check_supported(subgroup_basic, cap); break; + case SpvCapabilitySubgroupBallotKHR: + case SpvCapabilityGroupNonUniformBallot: + spv_check_supported(subgroup_ballot, cap); + break; + case SpvCapabilityVariablePointersStorageBuffer: case SpvCapabilityVariablePointers: spv_check_supported(variable_pointers, cap); diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c index 033c43e601c..a86f0cb2832 100644 --- a/src/compiler/spirv/vtn_subgroup.c +++ b/src/compiler/spirv/vtn_subgroup.c @@ -23,6 +23,44 @@ #include "vtn_private.h" +static void +vtn_build_subgroup_instr(struct vtn_builder *b, + nir_intrinsic_op nir_op, + struct vtn_ssa_value *dst, + struct vtn_ssa_value *src0, + nir_ssa_def *index) +{ + /* Some of the subgroup operations take an index. SPIR-V allows this to be + * any integer type. To make things simpler for drivers, we only support + * 32-bit indices. + */ + if (index && index->bit_size != 32) + index = nir_u2u32(&b->nb, index); + + vtn_assert(dst->type == src0->type); + if (!glsl_type_is_vector_or_scalar(dst->type)) { + for (unsigned i = 0; i < glsl_get_length(dst->type); i++) { + vtn_build_subgroup_instr(b, nir_op, dst->elems[i], + src0->elems[i], index); + } + return; + } + + nir_intrinsic_instr *intrin = + nir_intrinsic_instr_create(b->nb.shader, nir_op); + nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest, + dst->type, NULL); + intrin->num_components = intrin->dest.ssa.num_components; + + intrin->src[0] = nir_src_for_ssa(src0->def); + if (index) + intrin->src[1] = nir_src_for_ssa(index); + + nir_builder_instr_insert(&b->nb, &intrin->instr); + + dst->def = &intrin->dest.ssa; +} + void vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) @@ -43,17 +81,106 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, break; } - case SpvOpGroupNonUniformAll: - case SpvOpGroupNonUniformAny: - case SpvOpGroupNonUniformAllEqual: - case SpvOpGroupNonUniformBroadcast: - case SpvOpGroupNonUniformBroadcastFirst: - case SpvOpGroupNonUniformBallot: - case SpvOpGroupNonUniformInverseBallot: + case SpvOpGroupNonUniformBallot: { + vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4), + "OpGroupNonUniformBallot must return a uvec4"); + nir_intrinsic_instr *ballot = + nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot); + ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def); + nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL); + ballot->num_components = 4; + nir_builder_instr_insert(&b->nb, &ballot->instr); + val->ssa->def = &ballot->dest.ssa; + break; + } + + case SpvOpGroupNonUniformInverseBallot: { + /* This one is just a BallotBitfieldExtract with subgroup invocation. + * We could add a NIR intrinsic but it's easier to just lower it on the + * spot. + */ + nir_intrinsic_instr *intrin = + nir_intrinsic_instr_create(b->nb.shader, + nir_intrinsic_ballot_bitfield_extract); + + intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def); + intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb)); + + nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL); + nir_builder_instr_insert(&b->nb, &intrin->instr); + + val->ssa->def = &intrin->dest.ssa; + break; + } + case SpvOpGroupNonUniformBallotBitExtract: case SpvOpGroupNonUniformBallotBitCount: case SpvOpGroupNonUniformBallotFindLSB: - case SpvOpGroupNonUniformBallotFindMSB: + case SpvOpGroupNonUniformBallotFindMSB: { + nir_ssa_def *src0, *src1 = NULL; + nir_intrinsic_op op; + switch (opcode) { + case SpvOpGroupNonUniformBallotBitExtract: + op = nir_intrinsic_ballot_bitfield_extract; + src0 = vtn_ssa_value(b, w[4])->def; + src1 = vtn_ssa_value(b, w[5])->def; + break; + case SpvOpGroupNonUniformBallotBitCount: + switch ((SpvGroupOperation)w[4]) { + case SpvGroupOperationReduce: + op = nir_intrinsic_ballot_bit_count_reduce; + break; + case SpvGroupOperationInclusiveScan: + op = nir_intrinsic_ballot_bit_count_inclusive; + break; + case SpvGroupOperationExclusiveScan: + op = nir_intrinsic_ballot_bit_count_exclusive; + break; + default: + unreachable("Invalid group operation"); + } + src0 = vtn_ssa_value(b, w[5])->def; + break; + case SpvOpGroupNonUniformBallotFindLSB: + op = nir_intrinsic_ballot_find_lsb; + src0 = vtn_ssa_value(b, w[4])->def; + break; + case SpvOpGroupNonUniformBallotFindMSB: + op = nir_intrinsic_ballot_find_msb; + src0 = vtn_ssa_value(b, w[4])->def; + break; + default: + unreachable("Unhandled opcode"); + } + + nir_intrinsic_instr *intrin = + nir_intrinsic_instr_create(b->nb.shader, op); + + intrin->src[0] = nir_src_for_ssa(src0); + if (src1) + intrin->src[1] = nir_src_for_ssa(src1); + + nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL); + nir_builder_instr_insert(&b->nb, &intrin->instr); + + val->ssa->def = &intrin->dest.ssa; + break; + } + + case SpvOpGroupNonUniformBroadcastFirst: + vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation, + val->ssa, vtn_ssa_value(b, w[4]), NULL); + break; + + case SpvOpGroupNonUniformBroadcast: + vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation, + val->ssa, vtn_ssa_value(b, w[4]), + vtn_ssa_value(b, w[5])->def); + break; + + case SpvOpGroupNonUniformAll: + case SpvOpGroupNonUniformAny: + case SpvOpGroupNonUniformAllEqual: case SpvOpGroupNonUniformShuffle: case SpvOpGroupNonUniformShuffleXor: case SpvOpGroupNonUniformShuffleUp: diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index 68e1adf8152..61caaafa311 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -1317,6 +1317,26 @@ vtn_get_builtin_location(struct vtn_builder *b, *location = SYSTEM_VALUE_VIEW_INDEX; set_mode_system_value(b, mode); break; + case SpvBuiltInSubgroupEqMask: + *location = SYSTEM_VALUE_SUBGROUP_EQ_MASK, + set_mode_system_value(b, mode); + break; + case SpvBuiltInSubgroupGeMask: + *location = SYSTEM_VALUE_SUBGROUP_GE_MASK, + set_mode_system_value(b, mode); + break; + case SpvBuiltInSubgroupGtMask: + *location = SYSTEM_VALUE_SUBGROUP_GT_MASK, + set_mode_system_value(b, mode); + break; + case SpvBuiltInSubgroupLeMask: + *location = SYSTEM_VALUE_SUBGROUP_LE_MASK, + set_mode_system_value(b, mode); + break; + case SpvBuiltInSubgroupLtMask: + *location = SYSTEM_VALUE_SUBGROUP_LT_MASK, + set_mode_system_value(b, mode); + break; default: vtn_fail("unsupported builtin"); } |