diff options
-rw-r--r-- | src/compiler/nir/nir_search.c | 113 | ||||
-rw-r--r-- | src/compiler/nir/nir_search.h | 3 |
2 files changed, 78 insertions, 38 deletions
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index b78d3046a7b..e6f36493fe2 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -38,6 +38,11 @@ struct match_state { bool has_exact_alu; uint8_t comm_op_direction; unsigned variables_seen; + + /* Used for running the automaton on newly-constructed instructions. */ + struct util_dynarray *states; + const struct per_op_table *pass_op_table; + nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES]; struct hash_table *range_ht; }; @@ -46,6 +51,9 @@ static bool match_expression(const nir_search_expression *expr, nir_alu_instr *instr, unsigned num_components, const uint8_t *swizzle, struct match_state *state); +static void +nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states, + const struct per_op_table *pass_op_table); static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 }; @@ -490,6 +498,11 @@ construct_value(nir_builder *build, nir_builder_instr_insert(build, &alu->instr); + assert(alu->dest.dest.ssa.index == + util_dynarray_num_elements(state->states, uint16_t)); + util_dynarray_append(state->states, uint16_t, 0); + nir_algebraic_automaton(&alu->instr, state->states, state->pass_op_table); + nir_alu_src val; val.src = nir_src_for_ssa(&alu->dest.dest.ssa); val.negate = false; @@ -537,6 +550,12 @@ construct_value(nir_builder *build, unreachable("Invalid alu source type"); } + assert(cval->index == + util_dynarray_num_elements(state->states, uint16_t)); + util_dynarray_append(state->states, uint16_t, 0); + nir_algebraic_automaton(cval->parent_instr, state->states, + state->pass_op_table); + nir_alu_src val; val.src = nir_src_for_ssa(cval); val.negate = false; @@ -624,6 +643,8 @@ UNUSED static void dump_value(const nir_search_value *val) nir_ssa_def * nir_replace_instr(nir_builder *build, nir_alu_instr *instr, struct hash_table *range_ht, + struct util_dynarray *states, + const struct per_op_table *pass_op_table, const nir_search_expression *search, const nir_search_value *replace) { @@ -638,6 +659,7 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, state.inexact_match = false; state.has_exact_alu = false; state.range_ht = range_ht; + state.pass_op_table = pass_op_table; STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS); @@ -672,6 +694,8 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, build->cursor = nir_before_instr(&instr->instr); + state.states = states; + nir_alu_src val = construct_value(build, replace, instr->dest.dest.ssa.num_components, instr->dest.dest.ssa.bit_size, @@ -682,6 +706,11 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, */ nir_ssa_def *ssa_val = nir_mov_alu(build, val, instr->dest.dest.ssa.num_components); + if (ssa_val->index == util_dynarray_num_elements(states, uint16_t)) { + util_dynarray_append(states, uint16_t, 0); + nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table); + } + nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val)); /* We know this one has no more uses because we just rewrote them all, @@ -694,42 +723,43 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, } static void -nir_algebraic_automaton(nir_block *block, uint16_t *states, +nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states, const struct per_op_table *pass_op_table) { - nir_foreach_instr(instr, block) { - switch (instr->type) { - case nir_instr_type_alu: { - nir_alu_instr *alu = nir_instr_as_alu(instr); - nir_op op = alu->op; - uint16_t search_op = nir_search_op_for_nir_op(op); - const struct per_op_table *tbl = &pass_op_table[search_op]; - if (tbl->num_filtered_states == 0) - continue; - - /* Calculate the index into the transition table. Note the index - * calculated must match the iteration order of Python's - * itertools.product(), which was used to emit the transition - * table. - */ - uint16_t index = 0; - for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) { - index *= tbl->num_filtered_states; - index += tbl->filter[states[alu->src[i].src.ssa->index]]; - } - states[alu->dest.dest.ssa.index] = tbl->table[index]; - break; + switch (instr->type) { + case nir_instr_type_alu: { + nir_alu_instr *alu = nir_instr_as_alu(instr); + nir_op op = alu->op; + uint16_t search_op = nir_search_op_for_nir_op(op); + const struct per_op_table *tbl = &pass_op_table[search_op]; + if (tbl->num_filtered_states == 0) + return; + + /* Calculate the index into the transition table. Note the index + * calculated must match the iteration order of Python's + * itertools.product(), which was used to emit the transition + * table. + */ + uint16_t index = 0; + for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) { + index *= tbl->num_filtered_states; + index += tbl->filter[*util_dynarray_element(states, uint16_t, + alu->src[i].src.ssa->index)]; } + *util_dynarray_element(states, uint16_t, alu->dest.dest.ssa.index) = + tbl->table[index]; + break; + } - case nir_instr_type_load_const: { - nir_load_const_instr *load_const = nir_instr_as_load_const(instr); - states[load_const->def.index] = CONST_STATE; - break; - } + case nir_instr_type_load_const: { + nir_load_const_instr *load_const = nir_instr_as_load_const(instr); + *util_dynarray_element(states, uint16_t, load_const->def.index) = + CONST_STATE; + break; + } - default: - break; - } + default: + break; } } @@ -739,7 +769,8 @@ nir_algebraic_block(nir_builder *build, nir_block *block, const bool *condition_flags, const struct transform **transforms, const uint16_t *transform_counts, - const uint16_t *states) + struct util_dynarray *states, + const struct per_op_table *pass_op_table) { bool progress = false; const unsigned execution_mode = build->shader->info.float_controls_execution_mode; @@ -757,12 +788,13 @@ nir_algebraic_block(nir_builder *build, nir_block *block, nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) || nir_is_denorm_flush_to_zero(execution_mode, bit_size); - int xform_idx = states[alu->dest.dest.ssa.index]; + int xform_idx = *util_dynarray_element(states, uint16_t, + alu->dest.dest.ssa.index); for (uint16_t i = 0; i < transform_counts[xform_idx]; i++) { const struct transform *xform = &transforms[xform_idx][i]; if (condition_flags[xform->condition_offset] && !(xform->search->inexact && ignore_inexact) && - nir_replace_instr(build, alu, range_ht, + nir_replace_instr(build, alu, range_ht, states, pass_op_table, xform->search, xform->replace)) { _mesa_hash_table_clear(range_ht, NULL); progress = true; @@ -790,22 +822,27 @@ nir_algebraic_impl(nir_function_impl *impl, * state 0 is the default state, which means we don't have to visit * anything other than constants and ALU instructions. */ - uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states)); + struct util_dynarray states = {0}; + if (!util_dynarray_resize(&states, uint16_t, impl->ssa_alloc)) + return false; + memset(states.data, 0, states.size); struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL); nir_foreach_block(block, impl) { - nir_algebraic_automaton(block, states, pass_op_table); + nir_foreach_instr(instr, block) { + nir_algebraic_automaton(instr, &states, pass_op_table); + } } nir_foreach_block_reverse(block, impl) { progress |= nir_algebraic_block(&build, block, range_ht, condition_flags, transforms, transform_counts, - states); + &states, pass_op_table); } ralloc_free(range_ht); - free(states); + util_dynarray_fini(&states); if (progress) { nir_metadata_preserve(impl, nir_metadata_block_index | diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h index 80d153916c8..9d567f88165 100644 --- a/src/compiler/nir/nir_search.h +++ b/src/compiler/nir/nir_search.h @@ -29,6 +29,7 @@ #define _NIR_SEARCH_ #include "nir.h" +#include "util/u_dynarray.h" #define NIR_SEARCH_MAX_VARIABLES 16 @@ -198,6 +199,8 @@ NIR_DEFINE_CAST(nir_search_value_as_expression, nir_search_value, nir_ssa_def * nir_replace_instr(struct nir_builder *b, nir_alu_instr *instr, struct hash_table *range_ht, + struct util_dynarray *states, + const struct per_op_table *pass_op_table, const nir_search_expression *search, const nir_search_value *replace); bool |