diff options
Diffstat (limited to 'src/amd/common')
-rw-r--r-- | src/amd/common/ac_nir_to_llvm.c | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c index 84a1e2462b2..5820e62105c 100644 --- a/src/amd/common/ac_nir_to_llvm.c +++ b/src/amd/common/ac_nir_to_llvm.c @@ -275,7 +275,7 @@ static LLVMValueRef create_llvm_function(LLVMContextRef ctx, LLVMModuleRef module, LLVMBuilderRef builder, LLVMTypeRef *return_types, unsigned num_return_elems, LLVMTypeRef *param_types, - unsigned param_count, unsigned array_params, + unsigned param_count, unsigned array_params_mask, unsigned sgpr_params, bool unsafe_math) { LLVMTypeRef main_function_type, ret_type; @@ -298,7 +298,7 @@ create_llvm_function(LLVMContextRef ctx, LLVMModuleRef module, LLVMSetFunctionCallConv(main_function, RADEON_LLVM_AMDGPU_CS); for (unsigned i = 0; i < sgpr_params; ++i) { - if (i < array_params) { + if (array_params_mask & (1 << i)) { LLVMValueRef P = LLVMGetParam(main_function, i); ac_add_function_attr(main_function, i + 1, AC_FUNC_ATTR_BYVAL); ac_add_attr_dereferenceable(P, UINT64_MAX); @@ -455,7 +455,7 @@ static void create_function(struct nir_to_llvm_context *ctx) { LLVMTypeRef arg_types[23]; unsigned arg_idx = 0; - unsigned array_count = 0; + unsigned array_params_mask = 0; unsigned sgpr_count = 0, user_sgpr_count; unsigned i; unsigned num_sets = ctx->options->layout ? ctx->options->layout->num_sets : 0; @@ -472,16 +472,17 @@ static void create_function(struct nir_to_llvm_context *ctx) /* 1 for each descriptor set */ for (unsigned i = 0; i < num_sets; ++i) { if (ctx->options->layout->set[i].layout->shader_stages & (1 << ctx->stage)) { + array_params_mask |= (1 << arg_idx); arg_types[arg_idx++] = const_array(ctx->i8, 1024 * 1024); } } if (need_push_constants) { /* 1 for push constants and dynamic descriptors */ + array_params_mask |= (1 << arg_idx); arg_types[arg_idx++] = const_array(ctx->i8, 1024 * 1024); } - array_count = arg_idx; switch (ctx->stage) { case MESA_SHADER_COMPUTE: arg_types[arg_idx++] = LLVMVectorType(ctx->i32, 3); /* grid size */ @@ -530,7 +531,7 @@ static void create_function(struct nir_to_llvm_context *ctx) ctx->main_function = create_llvm_function( ctx->context, ctx->module, ctx->builder, NULL, 0, arg_types, - arg_idx, array_count, sgpr_count, ctx->options->unsafe_math); + arg_idx, array_params_mask, sgpr_count, ctx->options->unsafe_math); set_llvm_calling_convention(ctx->main_function, ctx->stage); |