diff options
Diffstat (limited to 'src/amd/common/ac_llvm_build.c')
-rw-r--r-- | src/amd/common/ac_llvm_build.c | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/src/amd/common/ac_llvm_build.c b/src/amd/common/ac_llvm_build.c index 041b6cd797e..1551df07959 100644 --- a/src/amd/common/ac_llvm_build.c +++ b/src/amd/common/ac_llvm_build.c @@ -58,7 +58,8 @@ struct ac_llvm_flow { */ void ac_llvm_context_init(struct ac_llvm_context *ctx, - enum chip_class chip_class, enum radeon_family family) + enum chip_class chip_class, enum radeon_family family, + unsigned wave_size) { LLVMValueRef args[1]; @@ -66,6 +67,7 @@ ac_llvm_context_init(struct ac_llvm_context *ctx, ctx->chip_class = chip_class; ctx->family = family; + ctx->wave_size = wave_size; ctx->module = NULL; ctx->builder = NULL; @@ -2225,10 +2227,14 @@ ac_get_thread_id(struct ac_llvm_context *ctx) "llvm.amdgcn.mbcnt.lo", ctx->i32, tid_args, 2, AC_FUNC_ATTR_READNONE); - tid = ac_build_intrinsic(ctx, "llvm.amdgcn.mbcnt.hi", - ctx->i32, tid_args, - 2, AC_FUNC_ATTR_READNONE); - set_range_metadata(ctx, tid, 0, 64); + if (ctx->wave_size == 32) { + tid = tid_args[1]; + } else { + tid = ac_build_intrinsic(ctx, "llvm.amdgcn.mbcnt.hi", + ctx->i32, tid_args, + 2, AC_FUNC_ATTR_READNONE); + } + set_range_metadata(ctx, tid, 0, ctx->wave_size); return tid; } @@ -4260,7 +4266,7 @@ ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src))); result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity), LLVMTypeOf(identity), ""); - result = ac_build_scan(ctx, op, result, identity, 64, true); + result = ac_build_scan(ctx, op, result, identity, ctx->wave_size, true); return ac_build_wwm(ctx, result); } @@ -4284,7 +4290,7 @@ ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src))); result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity), LLVMTypeOf(identity), ""); - result = ac_build_scan(ctx, op, result, identity, 64, false); + result = ac_build_scan(ctx, op, result, identity, ctx->wave_size, false); return ac_build_wwm(ctx, result); } @@ -4360,12 +4366,12 @@ ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws) if (ws->maxwaves <= 1) return; - const LLVMValueRef i32_63 = LLVMConstInt(ctx->i32, 63, false); + const LLVMValueRef last_lane = LLVMConstInt(ctx->i32, ctx->wave_size - 1, false); LLVMBuilderRef builder = ctx->builder; LLVMValueRef tid = ac_get_thread_id(ctx); LLVMValueRef tmp; - tmp = LLVMBuildICmp(builder, LLVMIntEQ, tid, i32_63, ""); + tmp = LLVMBuildICmp(builder, LLVMIntEQ, tid, last_lane, ""); ac_build_ifcc(ctx, tmp, 1000); LLVMBuildStore(builder, ws->src, LLVMBuildGEP(builder, ws->scratch, &ws->waveidx, 1, "")); ac_build_endif(ctx, 1000); |