summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/compiler/nir/nir_algebraic.py20
-rw-r--r--src/compiler/nir/nir_search.c61
-rw-r--r--src/compiler/nir/nir_search.h12
3 files changed, 64 insertions, 29 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index fe9d1051e67..d4b3bb5957f 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -114,6 +114,7 @@ static const ${val.c_type} ${val.name} = {
${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
${'true' if val.inexact else 'false'},
+ ${val.comm_expr_idx}, ${val.comm_exprs},
${val.c_opcode()},
{ ${', '.join(src.c_ptr for src in val.sources)} },
${val.cond if val.cond else 'NULL'},
@@ -307,6 +308,25 @@ class Expression(Value):
'Expression cannot use an unsized conversion opcode with ' \
'an explicit size; that\'s silly.'
+ self.__index_comm_exprs(0)
+
+ def __index_comm_exprs(self, base_idx):
+ """Recursively count and index commutative expressions
+ """
+ self.comm_exprs = 0
+ if self.opcode not in conv_opcode_types and \
+ "commutative" in opcodes[self.opcode].algebraic_properties:
+ self.comm_expr_idx = base_idx
+ self.comm_exprs += 1
+ else:
+ self.comm_expr_idx = -1
+
+ for s in self.sources:
+ if isinstance(s, Expression):
+ s.__index_comm_exprs(base_idx + self.comm_exprs)
+ self.comm_exprs += s.comm_exprs
+
+ return self.comm_exprs
def c_opcode(self):
if self.opcode in conv_opcode_types:
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index d257b639189..df27a2473ee 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -30,9 +30,12 @@
#include "nir_builder.h"
#include "util/half_float.h"
+#define NIR_SEARCH_MAX_COMM_OPS 4
+
struct match_state {
bool inexact_match;
bool has_exact_alu;
+ uint8_t comm_op_direction;
unsigned variables_seen;
nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
};
@@ -349,41 +352,25 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
}
}
- /* Stash off the current variables_seen bitmask. This way we can
- * restore it prior to matching in the commutative case below.
+ /* If this is a commutative expression and it's one of the first few, look
+ * up its direction for the current search operation. We'll use that value
+ * to possibly flip the sources for the match.
*/
- unsigned variables_seen_stash = state->variables_seen;
+ unsigned comm_op_flip =
+ (expr->comm_expr_idx >= 0 &&
+ expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
+ ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
bool matched = true;
for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
- if (!match_value(expr->srcs[i], instr, i, num_components,
- swizzle, state)) {
+ if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
+ num_components, swizzle, state)) {
matched = false;
break;
}
}
- if (matched)
- return true;
-
- if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
- assert(nir_op_infos[instr->op].num_inputs == 2);
-
- /* Restore the variables_seen bitmask. If we don't do this, then we
- * could end up with an erroneous failure due to variables found in the
- * first match attempt above not matching those in the second.
- */
- state->variables_seen = variables_seen_stash;
-
- if (!match_value(expr->srcs[0], instr, 1, num_components,
- swizzle, state))
- return false;
-
- return match_value(expr->srcs[1], instr, 0, num_components,
- swizzle, state);
- } else {
- return false;
- }
+ return matched;
}
static unsigned
@@ -513,10 +500,26 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
struct match_state state;
state.inexact_match = false;
state.has_exact_alu = false;
- state.variables_seen = 0;
- if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
- swizzle, &state))
+ unsigned comm_expr_combinations =
+ 1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
+
+ bool found = false;
+ for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
+ /* The bitfield of directions is just the current iteration. Hooray for
+ * binary.
+ */
+ state.comm_op_direction = comb;
+ state.variables_seen = 0;
+
+ if (match_expression(search, instr,
+ instr->dest.dest.ssa.num_components,
+ swizzle, &state)) {
+ found = true;
+ break;
+ }
+ }
+ if (!found)
return NULL;
build->cursor = nir_before_instr(&instr->instr);
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
index 1c78d0a3201..9dc09d2361c 100644
--- a/src/compiler/nir/nir_search.h
+++ b/src/compiler/nir/nir_search.h
@@ -132,6 +132,18 @@ typedef struct {
*/
bool inexact;
+ /* Commutative expression index. This is assigned by opt_algebraic.py when
+ * search structures are constructed and is a unique (to this structure)
+ * index within the commutative operation bitfield used for searching for
+ * all combinations of expressions containing commutative operations.
+ */
+ int8_t comm_expr_idx;
+
+ /* Number of commutative expressions in this expression including this one
+ * (if it is commutative).
+ */
+ uint8_t comm_exprs;
+
/* One of nir_op or nir_search_op */
uint16_t opcode;
const nir_search_value *srcs[4];