diff options
-rw-r--r-- | src/compiler/nir/nir.h | 1 | ||||
-rw-r--r-- | src/compiler/nir/nir_lower_subgroups.c | 68 |
2 files changed, 56 insertions, 13 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index f33049d7134..f3326e6df94 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -2609,6 +2609,7 @@ typedef struct nir_lower_subgroups_options { bool lower_vote_eq_to_ballot:1; bool lower_subgroup_masks:1; bool lower_shuffle:1; + bool lower_shuffle_to_32bit:1; bool lower_quad:1; } nir_lower_subgroups_options; diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index e0e1063fc43..ee5e8bd644b 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -28,6 +28,38 @@ * \file nir_opt_intrinsics.c */ +static nir_intrinsic_instr * +lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin, + unsigned int component) +{ + nir_ssa_def *comp; + if (component == 0) + comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa); + else + comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa); + + nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic); + nir_ssa_dest_init(&intr->instr, &intr->dest, 1, 32, NULL); + intr->const_index[0] = intrin->const_index[0]; + intr->const_index[1] = intrin->const_index[1]; + intr->src[0] = nir_src_for_ssa(comp); + if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2) + nir_src_copy(&intr->src[1], &intrin->src[1], intr); + + intr->num_components = 1; + nir_builder_instr_insert(b, &intr->instr); + return intr; +} + +static nir_ssa_def * +lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin) +{ + assert(intrin->src[0].ssa->bit_size == 64); + nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0); + nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1); + return nir_pack_64_2x32_split(b, &intr_x->dest.ssa, &intr_y->dest.ssa); +} + static nir_ssa_def * ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size) { @@ -80,7 +112,8 @@ uint_to_ballot_type(nir_builder *b, nir_ssa_def *value, } static nir_ssa_def * -lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) +lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin, + bool lower_to_32bit) { /* This is safe to call on scalar things but it would be silly */ assert(intrin->dest.ssa.num_components > 1); @@ -107,9 +140,12 @@ lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) chan_intrin->const_index[0] = intrin->const_index[0]; chan_intrin->const_index[1] = intrin->const_index[1]; - nir_builder_instr_insert(b, &chan_intrin->instr); - - reads[i] = &chan_intrin->dest.ssa; + if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) { + reads[i] = lower_subgroup_op_to_32bit(b, chan_intrin); + } else { + nir_builder_instr_insert(b, &chan_intrin->instr); + reads[i] = &chan_intrin->dest.ssa; + } } return nir_vec(b, reads, intrin->num_components); @@ -188,7 +224,7 @@ lower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin, static nir_ssa_def * lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin, - bool lower_to_scalar) + bool lower_to_scalar, bool lower_to_32bit) { nir_ssa_def *index = nir_load_subgroup_invocation(b); switch (intrin->intrinsic) { @@ -241,7 +277,9 @@ lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin, intrin->dest.ssa.bit_size, NULL); if (lower_to_scalar && shuffle->num_components > 1) { - return lower_subgroup_op_to_scalar(b, shuffle); + return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit); + } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) { + return lower_subgroup_op_to_32bit(b, shuffle); } else { nir_builder_instr_insert(b, &shuffle->instr); return &shuffle->dest.ssa; @@ -279,7 +317,7 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, case nir_intrinsic_read_invocation: case nir_intrinsic_read_first_invocation: if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, false); break; case nir_intrinsic_load_subgroup_eq_mask: @@ -401,16 +439,20 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, case nir_intrinsic_shuffle: if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit); + else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) + return lower_subgroup_op_to_32bit(b, intrin); break; case nir_intrinsic_shuffle_xor: case nir_intrinsic_shuffle_up: case nir_intrinsic_shuffle_down: if (options->lower_shuffle) - return lower_shuffle(b, intrin, options->lower_to_scalar); + return lower_shuffle(b, intrin, options->lower_to_scalar, options->lower_shuffle_to_32bit); else if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit); + else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) + return lower_subgroup_op_to_32bit(b, intrin); break; case nir_intrinsic_quad_broadcast: @@ -418,16 +460,16 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, case nir_intrinsic_quad_swap_vertical: case nir_intrinsic_quad_swap_diagonal: if (options->lower_quad) - return lower_shuffle(b, intrin, options->lower_to_scalar); + return lower_shuffle(b, intrin, options->lower_to_scalar, false); else if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, false); break; case nir_intrinsic_reduce: case nir_intrinsic_inclusive_scan: case nir_intrinsic_exclusive_scan: if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, false); break; default: |