diff options
Diffstat (limited to 'src/compiler/spirv')
-rw-r--r-- | src/compiler/spirv/spirv_to_nir.c | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 5becd3418da..df5bba2c2a0 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -4453,6 +4453,68 @@ vtn_create_builder(const uint32_t *words, size_t word_count, return NULL; } +static nir_function * +vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b, + const nir_function *entry_point) +{ + vtn_assert(entry_point == b->entry_point->func->impl->function); + vtn_fail_if(!entry_point->name, "entry points are required to have a name"); + const char *func_name = + ralloc_asprintf(b->shader, "__wrapped_%s", entry_point->name); + + /* we shouldn't have any inputs yet */ + vtn_assert(!entry_point->shader->num_inputs); + vtn_assert(b->shader->info.stage == MESA_SHADER_KERNEL); + + nir_function *main_entry_point = nir_function_create(b->shader, func_name); + main_entry_point->impl = nir_function_impl_create(main_entry_point); + nir_builder_init(&b->nb, main_entry_point->impl); + b->nb.cursor = nir_after_cf_list(&main_entry_point->impl->body); + b->func_param_idx = 0; + + nir_call_instr *call = nir_call_instr_create(b->nb.shader, entry_point); + + for (unsigned i = 0; i < entry_point->num_params; ++i) { + struct vtn_type *param_type = b->entry_point->func->type->params[i]; + + /* consider all pointers to function memory to be parameters passed + * by value + */ + bool is_by_val = param_type->base_type == vtn_base_type_pointer && + param_type->storage_class == SpvStorageClassFunction; + + /* input variable */ + nir_variable *in_var = rzalloc(b->nb.shader, nir_variable); + in_var->data.mode = nir_var_shader_in; + in_var->data.read_only = true; + in_var->data.location = i; + + if (is_by_val) + in_var->type = param_type->deref->type; + else + in_var->type = param_type->type; + + nir_shader_add_variable(b->nb.shader, in_var); + b->nb.shader->num_inputs++; + + /* we have to copy the entire variable into function memory */ + if (is_by_val) { + nir_variable *copy_var = + nir_local_variable_create(main_entry_point->impl, in_var->type, + "copy_in"); + nir_copy_var(&b->nb, copy_var, in_var); + call->params[i] = + nir_src_for_ssa(&nir_build_deref_var(&b->nb, copy_var)->dest.ssa); + } else { + call->params[i] = nir_src_for_ssa(nir_load_var(&b->nb, in_var)); + } + } + + nir_builder_instr_insert(&b->nb, &call->instr); + + return main_entry_point; +} + nir_function * spirv_to_nir(const uint32_t *words, size_t word_count, struct nir_spirv_specialization *spec, unsigned num_spec, @@ -4542,6 +4604,10 @@ spirv_to_nir(const uint32_t *words, size_t word_count, nir_function *entry_point = b->entry_point->func->impl->function; vtn_assert(entry_point); + /* post process entry_points with input params */ + if (entry_point->num_params && b->shader->info.stage == MESA_SHADER_KERNEL) + entry_point = vtn_emit_kernel_entry_point_wrapper(b, entry_point); + entry_point->is_entrypoint = true; /* When multiple shader stages exist in the same SPIR-V module, we |