diff options
author | Ian Romanick <[email protected]> | 2019-06-24 16:00:29 -0700 |
---|---|---|
committer | Ian Romanick <[email protected]> | 2019-06-28 18:56:19 -0700 |
commit | 1a43cf9a40e46b27ee5d3a536c2ea2a3073ea7f1 (patch) | |
tree | 3e5985d86d7d50a37ce0c78401f05fbcda90fa5f /src/compiler/nir/nir_algebraic.py | |
parent | cae1af4339e4327f4cce106c534c101f09276382 (diff) |
nir/algebraic: Don't mark expression with duplicate sources as commutative
There is no reason to mark the fmul in the expression
('fmul', ('fadd', a, b), ('fadd', a, b))
as commutative. If a source of an instruction doesn't match one of the
('fadd', a, b) patterns, it won't match the other either.
This change is enough to make this pattern work:
('~fadd@32', ('fmul', ('fadd', 1.0, ('fneg', a)),
('fadd', 1.0, ('fneg', a))),
('fmul', ('flrp', a, 1.0, a), b))
This pattern has 5 commutative expressions (versus a limit of 4), but
the first fmul does not need to be commutative.
No shader-db change on any Intel platform. No shader-db run-time
difference on a certain 36-core / 72-thread system at 95% confidence
(n=20).
There are more subpatterns that could be marked as non-commutative, but
detecting these is more challenging. For example, this fadd:
('fadd', ('fmul', a, b), ('fmul', a, c))
The first fadd:
('fmul', ('fadd', a, b), ('fadd', a, b))
And this fadd:
('flt', ('fadd', a, b), 0.0)
This last case may be easier to detect. If all sources are variables
and they are the only instances of those variables, then the pattern can
be marked as non-commutative. It's probably not worth the effort now,
but if we end up with some patterns that bump up on the limit again, it
may be worth revisiting.
v2: Update the comment about the explicit "len(self.sources)" check to
be more clear about why it is necessary. Requested by Connor. Many
Python fixes style / idom fixes suggested by Dylan. Add missing (!!!)
opcode check in Expression::__eq__ method. This bug is the reason the
expected number of commutative expressions in the bitfield_reverse
pattern changed from 61 to 45 in the first version of this patch.
v3: Use all() in Expression::__eq__ method. Suggested by Connor.
Revert away from using __eq__ overloads. The "equality" implementation
of Constant and Variable needed for commutativity pruning is weaker than
the one needed for propagating and validating bit sizes. Using actual
equality caused the pruning to fail for my ('fmul', ('fadd', 1, a),
('fadd', 1, a)) case. I changed the name to "equivalent" rather than
the previous "same_as" to further differentiate it from __eq__.
Reviewed-by: Connor Abbott <[email protected]>
Reviewed-by: Dylan Baker <[email protected]>
Diffstat (limited to 'src/compiler/nir/nir_algebraic.py')
-rw-r--r-- | src/compiler/nir/nir_algebraic.py | 57 |
1 files changed, 56 insertions, 1 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index d15d4ba3d67..d7f8e48dec8 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -269,6 +269,19 @@ class Constant(Value): elif isinstance(self.value, float): return "nir_type_float" + def equivalent(self, other): + """Check that two constants are equivalent. + + This is check is much weaker than equality. One generally cannot be + used in place of the other. Using this implementation for the __eq__ + will break BitSizeValidator. + + """ + if not isinstance(other, type(self)): + return False + + return self.value == other.value + _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)" r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?" r"(?P<cond>\([^\)]+\))?") @@ -313,6 +326,19 @@ class Variable(Value): elif self.required_type == 'float': return "nir_type_float" + def equivalent(self, other): + """Check that two variables are equivalent. + + This is check is much weaker than equality. One generally cannot be + used in place of the other. Using this implementation for the __eq__ + will break BitSizeValidator. + + """ + if not isinstance(other, type(self)): + return False + + return self.index == other.index + _opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?" r"(?P<cond>\([^\)]+\))?") @@ -352,12 +378,41 @@ class Expression(Value): self.__index_comm_exprs(0) + def equivalent(self, other): + """Check that two variables are equivalent. + + This is check is much weaker than equality. One generally cannot be + used in place of the other. Using this implementation for the __eq__ + will break BitSizeValidator. + + This implementation does not check for equivalence due to commutativity, + but it could. + + """ + if not isinstance(other, type(self)): + return False + + if len(self.sources) != len(other.sources): + return False + + if self.opcode != other.opcode: + return False + + return all(s.equivalent(o) for s, o in zip(self.sources, other.sources)) + def __index_comm_exprs(self, base_idx): """Recursively count and index commutative expressions """ self.comm_exprs = 0 + + # A note about the explicit "len(self.sources)" check. The list of + # sources comes from user input, and that input might be bad. Check + # that the expected second source exists before accessing it. Without + # this check, a unit test that does "('iadd', 'a')" will crash. if self.opcode not in conv_opcode_types and \ - "2src_commutative" in opcodes[self.opcode].algebraic_properties: + "2src_commutative" in opcodes[self.opcode].algebraic_properties and \ + len(self.sources) >= 2 and \ + not self.sources[0].equivalent(self.sources[1]): self.comm_expr_idx = base_idx self.comm_exprs += 1 else: |