summary | refs | log | tree | commit | diff | stats
path: root/src/amd/common/ac_llvm_build.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/amd/common/ac_llvm_build.c')
-rw-r--r-- src/amd/common/ac_llvm_build.c 24
1 files changed, 15 insertions, 9 deletions
diff --git a/src/amd/common/ac_llvm_build.c b/src/amd/common/ac_llvm_build.c
index 041b6cd797e..1551df07959 100644
--- a/src/amd/common/ac_llvm_build.c
+++ b/src/amd/common/ac_llvm_build.c
@@ -58,7 +58,8 @@ struct ac_llvm_flow {
*/
void
ac_llvm_context_init(struct ac_llvm_context *ctx,
- enum chip_class chip_class, enum radeon_family family)
+ enum chip_class chip_class, enum radeon_family family,
+ unsigned wave_size)
{
LLVMValueRef args[1];
@@ -66,6 +67,7 @@ ac_llvm_context_init(struct ac_llvm_context *ctx,
ctx->chip_class = chip_class;
ctx->family = family;
+ ctx->wave_size = wave_size;
ctx->module = NULL;
ctx->builder = NULL;
@@ -2225,10 +2227,14 @@ ac_get_thread_id(struct ac_llvm_context *ctx)
"llvm.amdgcn.mbcnt.lo", ctx->i32,
tid_args, 2, AC_FUNC_ATTR_READNONE);
- tid = ac_build_intrinsic(ctx, "llvm.amdgcn.mbcnt.hi",
- ctx->i32, tid_args,
- 2, AC_FUNC_ATTR_READNONE);
- set_range_metadata(ctx, tid, 0, 64);
+ if (ctx->wave_size == 32) {
+ tid = tid_args[1];
+ } else {
+ tid = ac_build_intrinsic(ctx, "llvm.amdgcn.mbcnt.hi",
+ ctx->i32, tid_args,
+ 2, AC_FUNC_ATTR_READNONE);
+ }
+ set_range_metadata(ctx, tid, 0, ctx->wave_size);
return tid;
}
@@ -4260,7 +4266,7 @@ ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
LLVMTypeOf(identity), "");
- result = ac_build_scan(ctx, op, result, identity, 64, true);
+ result = ac_build_scan(ctx, op, result, identity, ctx->wave_size, true);
return ac_build_wwm(ctx, result);
}
@@ -4284,7 +4290,7 @@ ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
LLVMTypeOf(identity), "");
- result = ac_build_scan(ctx, op, result, identity, 64, false);
+ result = ac_build_scan(ctx, op, result, identity, ctx->wave_size, false);
return ac_build_wwm(ctx, result);
}
@@ -4360,12 +4366,12 @@ ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
if (ws->maxwaves <= 1)
return;
- const LLVMValueRef i32_63 = LLVMConstInt(ctx->i32, 63, false);
+ const LLVMValueRef last_lane = LLVMConstInt(ctx->i32, ctx->wave_size - 1, false);
LLVMBuilderRef builder = ctx->builder;
LLVMValueRef tid = ac_get_thread_id(ctx);
LLVMValueRef tmp;
- tmp = LLVMBuildICmp(builder, LLVMIntEQ, tid, i32_63, "");
+ tmp = LLVMBuildICmp(builder, LLVMIntEQ, tid, last_lane, "");
ac_build_ifcc(ctx, tmp, 1000);
LLVMBuildStore(builder, ws->src, LLVMBuildGEP(builder, ws->scratch, &ws->waveidx, 1, ""));
ac_build_endif(ctx, 1000);