aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorBas Nieuwenhuizen <[email protected]>2019-11-27 23:33:59 +0100
committerBas Nieuwenhuizen <[email protected]>2019-11-28 11:35:11 +0100
commite09426ad6bea4156a02958c59804263dae8dbf82 (patch)
treecfdaee06e21b7d5111e446f732c5ade2fb231291 /src
parentd347f2805d8d8c37eb3e50483346bff9583c8e48 (diff)
amd/llvm: Refactor ac_build_scan.
Split out the logic for exclusive scans into a separate function that makes clear what it does instead of having this opaque 60 line if. Reviewed-by: Samuel Pitoiset <[email protected]>
Diffstat (limited to 'src')
-rw-r--r--src/amd/llvm/ac_llvm_build.c91
1 files changed, 51 insertions, 40 deletions
diff --git a/src/amd/llvm/ac_llvm_build.c b/src/amd/llvm/ac_llvm_build.c
index cf6eda30e2c..47c27893fe5 100644
--- a/src/amd/llvm/ac_llvm_build.c
+++ b/src/amd/llvm/ac_llvm_build.c
@@ -4045,18 +4045,17 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
}
/**
+ * \param src The value to shift.
+ * \param identity The value to use the first lane.
* \param maxprefix specifies that the result only needs to be correct for a
* prefix of this many threads
+ * \return src, shifted 1 lane up, and identity shifted into lane 0.
*/
static LLVMValueRef
-ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
- unsigned maxprefix, bool inclusive)
+ac_wavefront_shift_right_1(struct ac_llvm_context *ctx, LLVMValueRef src,
+ LLVMValueRef identity, unsigned maxprefix)
{
- LLVMValueRef result, tmp;
-
- if (inclusive) {
- result = src;
- } else if (ctx->chip_class >= GFX10) {
+ if (ctx->chip_class >= GFX10) {
/* wavefront shift_right by 1 on GFX10 (emulate dpp_wf_sr1) */
LLVMValueRef active, tmp1, tmp2;
LLVMValueRef tid = ac_get_thread_id(ctx);
@@ -4079,45 +4078,57 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
LLVMBuildAnd(ctx->builder, tid,
LLVMConstInt(ctx->i32, 0x1f, false), ""),
LLVMConstInt(ctx->i32, 0x10, false), ""), "");
- src = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+ return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
} else if (maxprefix > 16) {
active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid,
LLVMConstInt(ctx->i32, 16, false), "");
- src = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+ return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
}
-
- result = src;
} else if (ctx->chip_class >= GFX8) {
- src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
- result = src;
- } else {
- /* wavefront shift_right by 1 on SI/CI */
- LLVMValueRef active, tmp1, tmp2;
- LLVMValueRef tid = ac_get_thread_id(ctx);
- tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2));
- tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00));
- active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
- LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""),
- LLVMConstInt(ctx->i32, 0x4, 0), "");
- tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
- tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00));
- active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
- LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""),
- LLVMConstInt(ctx->i32, 0x8, 0), "");
- tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
- tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00));
- active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
- LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""),
- LLVMConstInt(ctx->i32, 0x10, 0), "");
- tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
- tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0));
- active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), "");
- tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
- active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), "");
- src = LLVMBuildSelect(ctx->builder, active, identity, tmp1, "");
- result = src;
- }
+ return ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
+ }
+
+ /* wavefront shift_right by 1 on SI/CI */
+ LLVMValueRef active, tmp1, tmp2;
+ LLVMValueRef tid = ac_get_thread_id(ctx);
+ tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2));
+ tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00));
+ active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+ LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""),
+ LLVMConstInt(ctx->i32, 0x4, 0), "");
+ tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+ tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00));
+ active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+ LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""),
+ LLVMConstInt(ctx->i32, 0x8, 0), "");
+ tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+ tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00));
+ active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+ LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""),
+ LLVMConstInt(ctx->i32, 0x10, 0), "");
+ tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+ tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0));
+ active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), "");
+ tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+ active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), "");
+ return LLVMBuildSelect(ctx->builder, active, identity, tmp1, "");
+}
+
+/**
+ * \param maxprefix specifies that the result only needs to be correct for a
+ * prefix of this many threads
+ */
+static LLVMValueRef
+ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
+ unsigned maxprefix, bool inclusive)
+{
+ LLVMValueRef result, tmp;
+
+ if (!inclusive)
+ src = ac_wavefront_shift_right_1(ctx, src, identity, maxprefix);
+
+ result = src;
if (ctx->chip_class <= GFX7) {
assert(maxprefix == 64);