summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- src/amd/vulkan/radv_nir_to_llvm.c 14
-rw-r--r-- src/amd/vulkan/radv_shader_info.c 20
2 files changed, 20 insertions, 14 deletions
diff --git a/src/amd/vulkan/radv_nir_to_llvm.c b/src/amd/vulkan/radv_nir_to_llvm.c
index 0ba81322ac0..92fcec9015a 100644
--- a/src/amd/vulkan/radv_nir_to_llvm.c
+++ b/src/amd/vulkan/radv_nir_to_llvm.c
@@ -2845,22 +2845,8 @@ handle_es_outputs_post(struct radv_shader_context *ctx,
struct radv_es_output_info *outinfo)
{
int j;
- uint64_t max_output_written = 0;
LLVMValueRef lds_base = NULL;
- for (unsigned i = 0; i < AC_LLVM_MAX_OUTPUTS; ++i) {
- int param_index;
-
- if (!(ctx->output_mask & (1ull << i)))
- continue;
-
- param_index = shader_io_get_unique_index(i);
-
- max_output_written = MAX2(param_index, max_output_written);
- }
-
- outinfo->esgs_itemsize = (max_output_written + 1) * 16;
-
if (ctx->ac.chip_class >= GFX9) {
unsigned itemsize_dw = outinfo->esgs_itemsize / 4;
LLVMValueRef vertex_idx = ac_get_thread_id(&ctx->ac);
diff --git a/src/amd/vulkan/radv_shader_info.c b/src/amd/vulkan/radv_shader_info.c
index 065cec3e0e7..1ff429e63fd 100644
--- a/src/amd/vulkan/radv_shader_info.c
+++ b/src/amd/vulkan/radv_shader_info.c
@@ -753,4 +753,24 @@ radv_nir_shader_info_pass(const struct nir_shader *nir,
info->gs.max_gsvs_emit_size =
info->gs.gsvs_vertex_size * nir->info.gs.vertices_out;
}
+
+ /* Compute the ESGS item size for VS or TES as ES. */
+ if ((nir->info.stage == MESA_SHADER_VERTEX ||
+ nir->info.stage == MESA_SHADER_TESS_EVAL) &&
+ options->key.vs_common_out.as_es) {
+ struct radv_es_output_info *es_info =
+ nir->info.stage == MESA_SHADER_VERTEX ? &info->vs.es_info : &info->tes.es_info;
+ uint32_t max_output_written = 0;
+
+ uint64_t output_mask = nir->info.outputs_written;
+ while (output_mask) {
+ const int i = u_bit_scan64(&output_mask);
+ unsigned param_index = shader_io_get_unique_index(i);
+
+ max_output_written = MAX2(param_index, max_output_written);
+ }
+
+ es_info->esgs_itemsize = (max_output_written + 1) * 16;
+ }
+
}