diff options
-rw-r--r-- | src/compiler/nir/nir_algebraic.py | 20 | ||||
-rw-r--r-- | src/compiler/nir/nir_search.c | 61 | ||||
-rw-r--r-- | src/compiler/nir/nir_search.h | 12 |
3 files changed, 64 insertions, 29 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index fe9d1051e67..d4b3bb5957f 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -114,6 +114,7 @@ static const ${val.c_type} ${val.name} = { ${val.cond if val.cond else 'NULL'}, % elif isinstance(val, Expression): ${'true' if val.inexact else 'false'}, + ${val.comm_expr_idx}, ${val.comm_exprs}, ${val.c_opcode()}, { ${', '.join(src.c_ptr for src in val.sources)} }, ${val.cond if val.cond else 'NULL'}, @@ -307,6 +308,25 @@ class Expression(Value): 'Expression cannot use an unsized conversion opcode with ' \ 'an explicit size; that\'s silly.' + self.__index_comm_exprs(0) + + def __index_comm_exprs(self, base_idx): + """Recursively count and index commutative expressions + """ + self.comm_exprs = 0 + if self.opcode not in conv_opcode_types and \ + "commutative" in opcodes[self.opcode].algebraic_properties: + self.comm_expr_idx = base_idx + self.comm_exprs += 1 + else: + self.comm_expr_idx = -1 + + for s in self.sources: + if isinstance(s, Expression): + s.__index_comm_exprs(base_idx + self.comm_exprs) + self.comm_exprs += s.comm_exprs + + return self.comm_exprs def c_opcode(self): if self.opcode in conv_opcode_types: diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index d257b639189..df27a2473ee 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -30,9 +30,12 @@ #include "nir_builder.h" #include "util/half_float.h" +#define NIR_SEARCH_MAX_COMM_OPS 4 + struct match_state { bool inexact_match; bool has_exact_alu; + uint8_t comm_op_direction; unsigned variables_seen; nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES]; }; @@ -349,41 +352,25 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, } } - /* Stash off the current variables_seen bitmask. This way we can - * restore it prior to matching in the commutative case below. + /* If this is a commutative expression and it's one of the first few, look + * up its direction for the current search operation. We'll use that value + * to possibly flip the sources for the match. */ - unsigned variables_seen_stash = state->variables_seen; + unsigned comm_op_flip = + (expr->comm_expr_idx >= 0 && + expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ? + ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0; bool matched = true; for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) { - if (!match_value(expr->srcs[i], instr, i, num_components, - swizzle, state)) { + if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip, + num_components, swizzle, state)) { matched = false; break; } } - if (matched) - return true; - - if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) { - assert(nir_op_infos[instr->op].num_inputs == 2); - - /* Restore the variables_seen bitmask. If we don't do this, then we - * could end up with an erroneous failure due to variables found in the - * first match attempt above not matching those in the second. - */ - state->variables_seen = variables_seen_stash; - - if (!match_value(expr->srcs[0], instr, 1, num_components, - swizzle, state)) - return false; - - return match_value(expr->srcs[1], instr, 0, num_components, - swizzle, state); - } else { - return false; - } + return matched; } static unsigned @@ -513,10 +500,26 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, struct match_state state; state.inexact_match = false; state.has_exact_alu = false; - state.variables_seen = 0; - if (!match_expression(search, instr, instr->dest.dest.ssa.num_components, - swizzle, &state)) + unsigned comm_expr_combinations = + 1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS); + + bool found = false; + for (unsigned comb = 0; comb < comm_expr_combinations; comb++) { + /* The bitfield of directions is just the current iteration. Hooray for + * binary. + */ + state.comm_op_direction = comb; + state.variables_seen = 0; + + if (match_expression(search, instr, + instr->dest.dest.ssa.num_components, + swizzle, &state)) { + found = true; + break; + } + } + if (!found) return NULL; build->cursor = nir_before_instr(&instr->instr); diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h index 1c78d0a3201..9dc09d2361c 100644 --- a/src/compiler/nir/nir_search.h +++ b/src/compiler/nir/nir_search.h @@ -132,6 +132,18 @@ typedef struct { */ bool inexact; + /* Commutative expression index. This is assigned by opt_algebraic.py when + * search structures are constructed and is a unique (to this structure) + * index within the commutative operation bitfield used for searching for + * all combinations of expressions containing commutative operations. + */ + int8_t comm_expr_idx; + + /* Number of commutative expressions in this expression including this one + * (if it is commutative). + */ + uint8_t comm_exprs; + /* One of nir_op or nir_search_op */ uint16_t opcode; const nir_search_value *srcs[4]; |