summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/compiler/nir/nir_algebraic.py4
-rw-r--r--src/compiler/nir/nir_instr_set.c23
-rw-r--r--src/compiler/nir/nir_search.c6
3 files changed, 24 insertions, 9 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index d945c1a8075..aa4e9778a43 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -796,12 +796,12 @@ class TreeAutomaton(object):
self.opcodes = self.IndexMap()
def get_item(opcode, children, pattern=None):
- commutative = len(children) == 2 \
+ commutative = len(children) >= 2 \
and "2src_commutative" in opcodes[opcode].algebraic_properties
item = self.items.setdefault((opcode, children),
self.Item(opcode, children))
if commutative:
- self.items[opcode, (children[1], children[0])] = item
+ self.items[opcode, (children[1], children[0]) + children[2:]] = item
if pattern is not None:
item.patterns.append(pattern)
return item
diff --git a/src/compiler/nir/nir_instr_set.c b/src/compiler/nir/nir_instr_set.c
index c6a69d345c9..80c03127810 100644
--- a/src/compiler/nir/nir_instr_set.c
+++ b/src/compiler/nir/nir_instr_set.c
@@ -57,7 +57,8 @@ hash_alu(uint32_t hash, const nir_alu_instr *instr)
/* We explicitly don't hash instr->dest.dest.exact */
if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) {
- assert(nir_op_infos[instr->op].num_inputs == 2);
+ assert(nir_op_infos[instr->op].num_inputs >= 2);
+
uint32_t hash0 = hash_alu_src(hash, &instr->src[0],
nir_ssa_alu_instr_src_components(instr, 0));
uint32_t hash1 = hash_alu_src(hash, &instr->src[1],
@@ -69,6 +70,11 @@ hash_alu(uint32_t hash, const nir_alu_instr *instr)
* collision. Either addition or multiplication will also work.
*/
hash = hash0 * hash1;
+
+ for (unsigned i = 2; i < nir_op_infos[instr->op].num_inputs; i++) {
+ hash = hash_alu_src(hash, &instr->src[i],
+ nir_ssa_alu_instr_src_components(instr, i));
+ }
} else {
for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
hash = hash_alu_src(hash, &instr->src[i],
@@ -529,11 +535,16 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2)
/* We explicitly don't hash instr->dest.dest.exact */
if (nir_op_infos[alu1->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) {
- assert(nir_op_infos[alu1->op].num_inputs == 2);
- return (nir_alu_srcs_equal(alu1, alu2, 0, 0) &&
- nir_alu_srcs_equal(alu1, alu2, 1, 1)) ||
- (nir_alu_srcs_equal(alu1, alu2, 0, 1) &&
- nir_alu_srcs_equal(alu1, alu2, 1, 0));
+ if ((!nir_alu_srcs_equal(alu1, alu2, 0, 0) ||
+ !nir_alu_srcs_equal(alu1, alu2, 1, 1)) &&
+ (!nir_alu_srcs_equal(alu1, alu2, 0, 1) ||
+ !nir_alu_srcs_equal(alu1, alu2, 1, 0)))
+ return false;
+
+ for (unsigned i = 2; i < nir_op_infos[alu1->op].num_inputs; i++) {
+ if (!nir_alu_srcs_equal(alu1, alu2, i, i))
+ return false;
+ }
} else {
for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
if (!nir_alu_srcs_equal(alu1, alu2, i, i))
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 6d3fbf7f7ba..3ddda7ca332 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -408,7 +408,11 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
bool matched = true;
for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
- if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
+ /* 2src_commutative instructions that have 3 sources are only commutative
+ * in the first two sources. Source 2 is always source 2.
+ */
+ if (!match_value(expr->srcs[i], instr,
+ i < 2 ? i ^ comm_op_flip : i,
num_components, swizzle, state)) {
matched = false;
break;