diff options
-rw-r--r-- | src/compiler/nir/nir.h | 12 | ||||
-rw-r--r-- | src/compiler/nir/nir_lower_io.c | 405 |
2 files changed, 417 insertions, 0 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 29b4da22330..553410b92d1 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -2896,6 +2896,18 @@ bool nir_lower_io(nir_shader *shader, nir_variable_mode modes, int (*type_size)(const struct glsl_type *), nir_lower_io_options); + +typedef enum { + /** + * An address format which is comprised of a vec2 where the first + * component is a vulkan descriptor index and the second is an offset. + */ + nir_address_format_vk_index_offset, +} nir_address_format; +bool nir_lower_explicit_io(nir_shader *shader, + nir_variable_mode modes, + nir_address_format); + nir_src *nir_get_io_offset_src(nir_intrinsic_instr *instr); nir_src *nir_get_io_vertex_index_src(nir_intrinsic_instr *instr); diff --git a/src/compiler/nir/nir_lower_io.c b/src/compiler/nir/nir_lower_io.c index 2ccba8c032b..bcbfebdfa3b 100644 --- a/src/compiler/nir/nir_lower_io.c +++ b/src/compiler/nir/nir_lower_io.c @@ -528,6 +528,411 @@ nir_lower_io(nir_shader *shader, nir_variable_mode modes, return progress; } +static unsigned +type_scalar_size_bytes(const struct glsl_type *type) +{ + assert(glsl_type_is_vector_or_scalar(type) || + glsl_type_is_matrix(type)); + return glsl_type_is_boolean(type) ? 4 : glsl_get_bit_size(type) / 8; +} + +static nir_ssa_def * +build_addr_iadd(nir_builder *b, nir_ssa_def *addr, + nir_address_format addr_format, nir_ssa_def *offset) +{ + assert(offset->num_components == 1); + assert(addr->bit_size == offset->bit_size); + + switch (addr_format) { + case nir_address_format_vk_index_offset: + assert(addr->num_components == 2); + return nir_vec2(b, nir_channel(b, addr, 0), + nir_iadd(b, nir_channel(b, addr, 1), offset)); + } + unreachable("Invalid address format"); +} + +static nir_ssa_def * +build_addr_iadd_imm(nir_builder *b, nir_ssa_def *addr, + nir_address_format addr_format, int64_t offset) +{ + return build_addr_iadd(b, addr, addr_format, + nir_imm_intN_t(b, offset, addr->bit_size)); +} + +static nir_ssa_def * +addr_to_index(nir_builder *b, nir_ssa_def *addr, + nir_address_format addr_format) +{ + assert(addr_format == nir_address_format_vk_index_offset); + assert(addr->num_components == 2); + return nir_channel(b, addr, 0); +} + +static nir_ssa_def * +addr_to_offset(nir_builder *b, nir_ssa_def *addr, + nir_address_format addr_format) +{ + assert(addr_format == nir_address_format_vk_index_offset); + assert(addr->num_components == 2); + return nir_channel(b, addr, 1); +} + +static nir_ssa_def * +build_explicit_io_load(nir_builder *b, nir_intrinsic_instr *intrin, + nir_ssa_def *addr, nir_address_format addr_format, + unsigned num_components) +{ + nir_variable_mode mode = nir_src_as_deref(intrin->src[0])->mode; + + nir_intrinsic_op op; + switch (mode) { + case nir_var_ubo: + op = nir_intrinsic_load_ubo; + break; + case nir_var_ssbo: + op = nir_intrinsic_load_ssbo; + break; + default: + unreachable("Unsupported explicit IO variable mode"); + } + + nir_intrinsic_instr *load = nir_intrinsic_instr_create(b->shader, op); + + load->src[0] = nir_src_for_ssa(addr_to_index(b, addr, addr_format)); + load->src[1] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format)); + + if (mode != nir_var_ubo) + nir_intrinsic_set_access(load, nir_intrinsic_access(intrin)); + + /* TODO: We should try and provide a better alignment. For OpenCL, we need + * to plumb the alignment through from SPIR-V when we have one. + */ + nir_intrinsic_set_align(load, intrin->dest.ssa.bit_size / 8, 0); + + assert(intrin->dest.is_ssa); + load->num_components = num_components; + nir_ssa_dest_init(&load->instr, &load->dest, num_components, + intrin->dest.ssa.bit_size, intrin->dest.ssa.name); + nir_builder_instr_insert(b, &load->instr); + + return &load->dest.ssa; +} + +static void +build_explicit_io_store(nir_builder *b, nir_intrinsic_instr *intrin, + nir_ssa_def *addr, nir_address_format addr_format, + nir_ssa_def *value, nir_component_mask_t write_mask) +{ + nir_variable_mode mode = nir_src_as_deref(intrin->src[0])->mode; + + nir_intrinsic_op op; + switch (mode) { + case nir_var_ssbo: + op = nir_intrinsic_store_ssbo; + break; + default: + unreachable("Unsupported explicit IO variable mode"); + } + + nir_intrinsic_instr *store = nir_intrinsic_instr_create(b->shader, op); + + store->src[0] = nir_src_for_ssa(value); + store->src[1] = nir_src_for_ssa(addr_to_index(b, addr, addr_format)); + store->src[2] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format)); + + nir_intrinsic_set_write_mask(store, write_mask); + + nir_intrinsic_set_access(store, nir_intrinsic_access(intrin)); + + /* TODO: We should try and provide a better alignment. For OpenCL, we need + * to plumb the alignment through from SPIR-V when we have one. + */ + nir_intrinsic_set_align(store, value->bit_size / 8, 0); + + assert(value->num_components == 1 || + value->num_components == intrin->num_components); + store->num_components = value->num_components; + nir_builder_instr_insert(b, &store->instr); +} + +static nir_ssa_def * +build_explicit_io_atomic(nir_builder *b, nir_intrinsic_instr *intrin, + nir_ssa_def *addr, nir_address_format addr_format) +{ + nir_variable_mode mode = nir_src_as_deref(intrin->src[0])->mode; + const unsigned num_data_srcs = + nir_intrinsic_infos[intrin->intrinsic].num_srcs - 1; + + nir_intrinsic_op op; + switch (mode) { + case nir_var_ssbo: + switch (intrin->intrinsic) { +#define OP(O) case nir_intrinsic_deref_##O: op = nir_intrinsic_ssbo_##O; break; + OP(atomic_exchange) + OP(atomic_comp_swap) + OP(atomic_add) + OP(atomic_imin) + OP(atomic_umin) + OP(atomic_imax) + OP(atomic_umax) + OP(atomic_and) + OP(atomic_or) + OP(atomic_xor) + OP(atomic_fadd) + OP(atomic_fmin) + OP(atomic_fmax) + OP(atomic_fcomp_swap) +#undef OP + default: + unreachable("Invalid SSBO atomic"); + } + break; + default: + unreachable("Unsupported explicit IO variable mode"); + } + + nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b->shader, op); + + atomic->src[0] = nir_src_for_ssa(addr_to_index(b, addr, addr_format)); + atomic->src[1] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format)); + for (unsigned i = 0; i < num_data_srcs; i++) { + assert(intrin->src[1 + i].is_ssa); + atomic->src[2 + i] = nir_src_for_ssa(intrin->src[1 + i].ssa); + } + + assert(intrin->dest.ssa.num_components == 1); + nir_ssa_dest_init(&atomic->instr, &atomic->dest, + 1, intrin->dest.ssa.bit_size, intrin->dest.ssa.name); + nir_builder_instr_insert(b, &atomic->instr); + + return &atomic->dest.ssa; +} + +static void +lower_explicit_io_deref(nir_builder *b, nir_deref_instr *deref, + nir_address_format addr_format) +{ + /* Just delete the deref if it's not used. We can't use + * nir_deref_instr_remove_if_unused here because it may remove more than + * one deref which could break our list walking since we walk the list + * backwards. + */ + assert(list_empty(&deref->dest.ssa.if_uses)); + if (list_empty(&deref->dest.ssa.uses)) { + nir_instr_remove(&deref->instr); + return; + } + + b->cursor = nir_after_instr(&deref->instr); + + /* Var derefs must be lowered away by the driver */ + assert(deref->deref_type != nir_deref_type_var); + + assert(deref->parent.is_ssa); + nir_ssa_def *parent_addr = deref->parent.ssa; + + nir_ssa_def *addr; + assert(deref->dest.is_ssa); + switch (deref->deref_type) { + case nir_deref_type_var: + unreachable("Must be lowered by the driver"); + break; + + case nir_deref_type_array: { + nir_deref_instr *parent = nir_deref_instr_parent(deref); + + unsigned stride = glsl_get_explicit_stride(parent->type); + if ((glsl_type_is_matrix(parent->type) && + glsl_matrix_type_is_row_major(parent->type)) || + (glsl_type_is_vector(parent->type) && stride == 0)) + stride = type_scalar_size_bytes(parent->type); + + assert(stride > 0); + + nir_ssa_def *index = nir_ssa_for_src(b, deref->arr.index, 1); + index = nir_i2i(b, index, parent_addr->bit_size); + addr = build_addr_iadd(b, parent_addr, addr_format, + nir_imul_imm(b, index, stride)); + break; + } + + case nir_deref_type_ptr_as_array: { + nir_ssa_def *index = nir_ssa_for_src(b, deref->arr.index, 1); + index = nir_i2i(b, index, parent_addr->bit_size); + unsigned stride = nir_deref_instr_ptr_as_array_stride(deref); + addr = build_addr_iadd(b, parent_addr, addr_format, + nir_imul_imm(b, index, stride)); + break; + } + + case nir_deref_type_array_wildcard: + unreachable("Wildcards should be lowered by now"); + break; + + case nir_deref_type_struct: { + nir_deref_instr *parent = nir_deref_instr_parent(deref); + int offset = glsl_get_struct_field_offset(parent->type, + deref->strct.index); + assert(offset >= 0); + addr = build_addr_iadd_imm(b, parent_addr, addr_format, offset); + break; + } + + case nir_deref_type_cast: + /* Nothing to do here */ + addr = parent_addr; + break; + } + + nir_instr_remove(&deref->instr); + nir_ssa_def_rewrite_uses(&deref->dest.ssa, nir_src_for_ssa(addr)); +} + +static void +lower_explicit_io_access(nir_builder *b, nir_intrinsic_instr *intrin, + nir_address_format addr_format) +{ + b->cursor = nir_after_instr(&intrin->instr); + + nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]); + unsigned vec_stride = glsl_get_explicit_stride(deref->type); + unsigned scalar_size = type_scalar_size_bytes(deref->type); + assert(vec_stride == 0 || glsl_type_is_vector(deref->type)); + assert(vec_stride == 0 || vec_stride >= scalar_size); + + nir_ssa_def *addr = &deref->dest.ssa; + if (intrin->intrinsic == nir_intrinsic_load_deref) { + nir_ssa_def *value; + if (vec_stride > scalar_size) { + nir_ssa_def *comps[4] = { NULL, }; + for (unsigned i = 0; i < intrin->num_components; i++) { + nir_ssa_def *comp_addr = build_addr_iadd_imm(b, addr, addr_format, + vec_stride * i); + comps[i] = build_explicit_io_load(b, intrin, comp_addr, + addr_format, 1); + } + value = nir_vec(b, comps, intrin->num_components); + } else { + value = build_explicit_io_load(b, intrin, addr, addr_format, + intrin->num_components); + } + nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(value)); + } else if (intrin->intrinsic == nir_intrinsic_store_deref) { + assert(intrin->src[1].is_ssa); + nir_ssa_def *value = intrin->src[1].ssa; + nir_component_mask_t write_mask = nir_intrinsic_write_mask(intrin); + if (vec_stride > scalar_size) { + for (unsigned i = 0; i < intrin->num_components; i++) { + if (!(write_mask & (1 << i))) + continue; + + nir_ssa_def *comp_addr = build_addr_iadd_imm(b, addr, addr_format, + vec_stride * i); + build_explicit_io_store(b, intrin, comp_addr, addr_format, + nir_channel(b, value, i), 1); + } + } else { + build_explicit_io_store(b, intrin, addr, addr_format, + value, write_mask); + } + } else { + nir_ssa_def *value = + build_explicit_io_atomic(b, intrin, addr, addr_format); + nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(value)); + } + + nir_instr_remove(&intrin->instr); +} + +static bool +nir_lower_explicit_io_impl(nir_function_impl *impl, nir_variable_mode modes, + nir_address_format addr_format) +{ + bool progress = false; + + nir_builder b; + nir_builder_init(&b, impl); + + /* Walk in reverse order so that we can see the full deref chain when we + * lower the access operations. We lower them assuming that the derefs + * will be turned into address calculations later. + */ + nir_foreach_block_reverse(block, impl) { + nir_foreach_instr_reverse_safe(instr, block) { + switch (instr->type) { + case nir_instr_type_deref: { + nir_deref_instr *deref = nir_instr_as_deref(instr); + if (deref->mode & modes) { + lower_explicit_io_deref(&b, deref, addr_format); + progress = true; + } + break; + } + + case nir_instr_type_intrinsic: { + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + switch (intrin->intrinsic) { + case nir_intrinsic_load_deref: + case nir_intrinsic_store_deref: + case nir_intrinsic_deref_atomic_add: + case nir_intrinsic_deref_atomic_imin: + case nir_intrinsic_deref_atomic_umin: + case nir_intrinsic_deref_atomic_imax: + case nir_intrinsic_deref_atomic_umax: + case nir_intrinsic_deref_atomic_and: + case nir_intrinsic_deref_atomic_or: + case nir_intrinsic_deref_atomic_xor: + case nir_intrinsic_deref_atomic_exchange: + case nir_intrinsic_deref_atomic_comp_swap: + case nir_intrinsic_deref_atomic_fadd: + case nir_intrinsic_deref_atomic_fmin: + case nir_intrinsic_deref_atomic_fmax: + case nir_intrinsic_deref_atomic_fcomp_swap: { + nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]); + if (deref->mode & modes) { + lower_explicit_io_access(&b, intrin, addr_format); + progress = true; + } + break; + } + + default: + break; + } + break; + } + + default: + /* Nothing to do */ + break; + } + } + } + + if (progress) { + nir_metadata_preserve(impl, nir_metadata_block_index | + nir_metadata_dominance); + } + + return progress; +} + +bool +nir_lower_explicit_io(nir_shader *shader, nir_variable_mode modes, + nir_address_format addr_format) +{ + bool progress = false; + + nir_foreach_function(function, shader) { + if (function->impl && + nir_lower_explicit_io_impl(function->impl, modes, addr_format)) + progress = true; + } + + return progress; +} + /** * Return the offset source for a load/store intrinsic. */ |