summaryrefslogtreecommitdiffstats
path: root/src/compiler/nir/nir_search.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler/nir/nir_search.c')
-rw-r--r--src/compiler/nir/nir_search.c115
1 files changed, 94 insertions, 21 deletions
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 6bb2f35aae8..c1b179525ab 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -28,6 +28,7 @@
#include <inttypes.h>
#include "nir_search.h"
#include "nir_builder.h"
+#include "nir_worklist.h"
#include "util/half_float.h"
/* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */
@@ -51,7 +52,7 @@ static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
unsigned num_components, const uint8_t *swizzle,
struct match_state *state);
-static void
+static bool
nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
const struct per_op_table *pass_op_table);
@@ -640,13 +641,50 @@ UNUSED static void dump_value(const nir_search_value *val)
fprintf(stderr, "@%d", val->bit_size);
}
+static void
+add_uses_to_worklist(nir_instr *instr, nir_instr_worklist *worklist)
+{
+ nir_ssa_def *def = nir_instr_ssa_def(instr);
+
+ nir_foreach_use_safe(use_src, def) {
+ nir_instr_worklist_push_tail(worklist, use_src->parent_instr);
+ }
+}
+
+static void
+nir_algebraic_update_automaton(nir_instr *new_instr,
+ nir_instr_worklist *algebraic_worklist,
+ struct util_dynarray *states,
+ const struct per_op_table *pass_op_table)
+{
+
+ nir_instr_worklist *automaton_worklist = nir_instr_worklist_create();
+
+ /* Walk through the tree of uses of our new instruction's SSA value,
+ * recursively updating the automaton state until it stabilizes.
+ */
+ add_uses_to_worklist(new_instr, automaton_worklist);
+
+ nir_instr *instr;
+ while ((instr = nir_instr_worklist_pop_head(automaton_worklist))) {
+ if (nir_algebraic_automaton(instr, states, pass_op_table)) {
+ nir_instr_worklist_push_tail(algebraic_worklist, instr);
+
+ add_uses_to_worklist(instr, automaton_worklist);
+ }
+ }
+
+ nir_instr_worklist_destroy(automaton_worklist);
+}
+
nir_ssa_def *
nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
struct hash_table *range_ht,
struct util_dynarray *states,
const struct per_op_table *pass_op_table,
const nir_search_expression *search,
- const nir_search_value *replace)
+ const nir_search_value *replace,
+ nir_instr_worklist *algebraic_worklist)
{
uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
@@ -711,18 +749,23 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table);
}
+ /* Rewrite the uses of the old SSA value to the new one, and recurse
+ * through the uses updating the automaton's state.
+ */
nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
+ nir_algebraic_update_automaton(ssa_val->parent_instr, algebraic_worklist,
+ states, pass_op_table);
- /* We know this one has no more uses because we just rewrote them all,
- * so we can remove it. The rest of the matched expression, however, we
- * don't know so much about. We'll just let dead code clean them up.
+ /* Nothing uses the instr any more, so drop it out of the program. Note
+ * that the instr may be in the worklist still, so we can't free it
+ * directly.
*/
nir_instr_remove(&instr->instr);
return ssa_val;
}
-static void
+static bool
nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
const struct per_op_table *pass_op_table)
{
@@ -733,7 +776,7 @@ nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
uint16_t search_op = nir_search_op_for_nir_op(op);
const struct per_op_table *tbl = &pass_op_table[search_op];
if (tbl->num_filtered_states == 0)
- return;
+ return false;
/* Calculate the index into the transition table. Note the index
* calculated must match the iteration order of Python's
@@ -746,20 +789,29 @@ nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
index += tbl->filter[*util_dynarray_element(states, uint16_t,
alu->src[i].src.ssa->index)];
}
- *util_dynarray_element(states, uint16_t, alu->dest.dest.ssa.index) =
- tbl->table[index];
- break;
+
+ uint16_t *state = util_dynarray_element(states, uint16_t,
+ alu->dest.dest.ssa.index);
+ if (*state != tbl->table[index]) {
+ *state = tbl->table[index];
+ return true;
+ }
+ return false;
}
case nir_instr_type_load_const: {
nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
- *util_dynarray_element(states, uint16_t, load_const->def.index) =
- CONST_STATE;
- break;
+ uint16_t *state = util_dynarray_element(states, uint16_t,
+ load_const->def.index);
+ if (*state != CONST_STATE) {
+ *state = CONST_STATE;
+ return true;
+ }
+ return false;
}
default:
- break;
+ return false;
}
}
@@ -770,7 +822,8 @@ nir_algebraic_instr(nir_builder *build, nir_instr *instr,
const struct transform **transforms,
const uint16_t *transform_counts,
struct util_dynarray *states,
- const struct per_op_table *pass_op_table)
+ const struct per_op_table *pass_op_table,
+ nir_instr_worklist *worklist)
{
if (instr->type != nir_instr_type_alu)
@@ -794,7 +847,7 @@ nir_algebraic_instr(nir_builder *build, nir_instr *instr,
if (condition_flags[xform->condition_offset] &&
!(xform->search->inexact && ignore_inexact) &&
nir_replace_instr(build, alu, range_ht, states, pass_op_table,
- xform->search, xform->replace)) {
+ xform->search, xform->replace, worklist)) {
_mesa_hash_table_clear(range_ht, NULL);
return true;
}
@@ -826,21 +879,41 @@ nir_algebraic_impl(nir_function_impl *impl,
struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL);
+ nir_instr_worklist *worklist = nir_instr_worklist_create();
+
+ /* Walk top-to-bottom setting up the automaton state. */
nir_foreach_block(block, impl) {
nir_foreach_instr(instr, block) {
nir_algebraic_automaton(instr, &states, pass_op_table);
}
}
+ /* Put our instrs in the worklist such that we're popping the last instr
+ * first. This will encourage us to match the biggest source patterns when
+ * possible.
+ */
nir_foreach_block_reverse(block, impl) {
- nir_foreach_instr_reverse_safe(instr, block) {
- progress |= nir_algebraic_instr(&build, instr,
- range_ht, condition_flags,
- transforms, transform_counts, &states,
- pass_op_table);
+ nir_foreach_instr_reverse(instr, block) {
+ nir_instr_worklist_push_tail(worklist, instr);
}
}
+ nir_instr *instr;
+ while ((instr = nir_instr_worklist_pop_head(worklist))) {
+ /* The worklist can have an instr pushed to it multiple times if it was
+ * the src of multiple instrs that also got optimized, so make sure that
+ * we don't try to re-optimize an instr we already handled.
+ */
+ if (exec_node_is_tail_sentinel(&instr->node))
+ continue;
+
+ progress |= nir_algebraic_instr(&build, instr,
+ range_ht, condition_flags,
+ transforms, transform_counts, &states,
+ pass_op_table, worklist);
+ }
+
+ nir_instr_worklist_destroy(worklist);
ralloc_free(range_ht);
util_dynarray_fini(&states);