diff options
author | Ian Romanick <[email protected]> | 2019-05-09 15:33:11 -0700 |
---|---|---|
committer | Ian Romanick <[email protected]> | 2019-05-14 11:25:02 -0700 |
commit | e049a9c92b3048f2d28d5a36f0dd780b19fe4b2a (patch) | |
tree | 1698453d02403c25074c998e815585669ed2c2c0 /src | |
parent | ede45bf9cfe20578712ae874f7a3d18fd86a1297 (diff) |
nir: Add support for 2src_commutative ops that have 3 sources
v2: Instead of handling 3 sources as a special case, generalize with
loops to N sources. Suggested by Jason.
v3: Further generalize by only checking that number of sources is >= 2.
Suggested by Jason.
Reviewed-by: Jason Ekstrand <[email protected]>
Diffstat (limited to 'src')
-rw-r--r-- | src/compiler/nir/nir_algebraic.py | 4 | ||||
-rw-r--r-- | src/compiler/nir/nir_instr_set.c | 23 | ||||
-rw-r--r-- | src/compiler/nir/nir_search.c | 6 |
3 files changed, 24 insertions, 9 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index d945c1a8075..aa4e9778a43 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -796,12 +796,12 @@ class TreeAutomaton(object):
       self.opcodes = self.IndexMap()
 
       def get_item(opcode, children, pattern=None):
-         commutative = len(children) == 2 \
+         commutative = len(children) >= 2 \
                and "2src_commutative" in opcodes[opcode].algebraic_properties
          item = self.items.setdefault((opcode, children),
                                       self.Item(opcode, children))
          if commutative:
-            self.items[opcode, (children[1], children[0])] = item
+            self.items[opcode, (children[1], children[0]) + children[2:]] = item
          if pattern is not None:
             item.patterns.append(pattern)
          return item
diff --git a/src/compiler/nir/nir_instr_set.c b/src/compiler/nir/nir_instr_set.c
index c6a69d345c9..80c03127810 100644
--- a/src/compiler/nir/nir_instr_set.c
+++ b/src/compiler/nir/nir_instr_set.c
@@ -57,7 +57,8 @@ hash_alu(uint32_t hash, const nir_alu_instr *instr)
    /* We explicitly don't hash instr->dest.dest.exact */
 
    if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) {
-      assert(nir_op_infos[instr->op].num_inputs == 2);
+      assert(nir_op_infos[instr->op].num_inputs >= 2);
+
       uint32_t hash0 = hash_alu_src(hash, &instr->src[0],
                                     nir_ssa_alu_instr_src_components(instr, 0));
       uint32_t hash1 = hash_alu_src(hash, &instr->src[1],
@@ -69,6 +70,11 @@ hash_alu(uint32_t hash, const nir_alu_instr *instr)
        * collision. Either addition or multiplication will also work.
        */
       hash = hash0 * hash1;
+
+      for (unsigned i = 2; i < nir_op_infos[instr->op].num_inputs; i++) {
+         hash = hash_alu_src(hash, &instr->src[i],
+                             nir_ssa_alu_instr_src_components(instr, i));
+      }
    } else {
       for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
          hash = hash_alu_src(hash, &instr->src[i],
@@ -529,11 +535,16 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2)
       /* We explicitly don't hash instr->dest.dest.exact */
 
       if (nir_op_infos[alu1->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) {
-         assert(nir_op_infos[alu1->op].num_inputs == 2);
-         return (nir_alu_srcs_equal(alu1, alu2, 0, 0) &&
-                 nir_alu_srcs_equal(alu1, alu2, 1, 1)) ||
-                (nir_alu_srcs_equal(alu1, alu2, 0, 1) &&
-                 nir_alu_srcs_equal(alu1, alu2, 1, 0));
+         if ((!nir_alu_srcs_equal(alu1, alu2, 0, 0) ||
+              !nir_alu_srcs_equal(alu1, alu2, 1, 1)) &&
+             (!nir_alu_srcs_equal(alu1, alu2, 0, 1) ||
+              !nir_alu_srcs_equal(alu1, alu2, 1, 0)))
+            return false;
+
+         for (unsigned i = 2; i < nir_op_infos[alu1->op].num_inputs; i++) {
+            if (!nir_alu_srcs_equal(alu1, alu2, i, i))
+               return false;
+         }
       } else {
          for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
             if (!nir_alu_srcs_equal(alu1, alu2, i, i))
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 6d3fbf7f7ba..3ddda7ca332 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -408,7 +408,11 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
 
    bool matched = true;
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
-      if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
+      /* 2src_commutative instructions that have 3 sources are only commutative
+       * in the first two sources. Source 2 is always source 2.
+       */
+      if (!match_value(expr->srcs[i], instr,
+                       i < 2 ? i ^ comm_op_flip : i,
                        num_components, swizzle, state)) {
          matched = false;
          break;