summaryrefslogtreecommitdiffstats
path: root/src/compiler/nir/nir_algebraic.py
diff options
context:
space:
mode:
authorIan Romanick <[email protected]>2019-06-24 16:00:29 -0700
committerIan Romanick <[email protected]>2019-06-28 18:56:19 -0700
commit1a43cf9a40e46b27ee5d3a536c2ea2a3073ea7f1 (patch)
tree3e5985d86d7d50a37ce0c78401f05fbcda90fa5f /src/compiler/nir/nir_algebraic.py
parentcae1af4339e4327f4cce106c534c101f09276382 (diff)
nir/algebraic: Don't mark expression with duplicate sources as commutative
There is no reason to mark the fmul in the expression ('fmul', ('fadd', a, b), ('fadd', a, b)) as commutative. If a source of an instruction doesn't match one of the ('fadd', a, b) patterns, it won't match the other either. This change is enough to make this pattern work: ('~fadd@32', ('fmul', ('fadd', 1.0, ('fneg', a)), ('fadd', 1.0, ('fneg', a))), ('fmul', ('flrp', a, 1.0, a), b)) This pattern has 5 commutative expressions (versus a limit of 4), but the first fmul does not need to be commutative. No shader-db change on any Intel platform. No shader-db run-time difference on a certain 36-core / 72-thread system at 95% confidence (n=20). There are more subpatterns that could be marked as non-commutative, but detecting these is more challenging. For example, this fadd: ('fadd', ('fmul', a, b), ('fmul', a, c)) The first fadd: ('fmul', ('fadd', a, b), ('fadd', a, b)) And this fadd: ('flt', ('fadd', a, b), 0.0) This last case may be easier to detect. If all sources are variables and they are the only instances of those variables, then the pattern can be marked as non-commutative. It's probably not worth the effort now, but if we end up with some patterns that bump up on the limit again, it may be worth revisiting. v2: Update the comment about the explicit "len(self.sources)" check to be more clear about why it is necessary. Requested by Connor. Many Python fixes style / idom fixes suggested by Dylan. Add missing (!!!) opcode check in Expression::__eq__ method. This bug is the reason the expected number of commutative expressions in the bitfield_reverse pattern changed from 61 to 45 in the first version of this patch. v3: Use all() in Expression::__eq__ method. Suggested by Connor. Revert away from using __eq__ overloads. The "equality" implementation of Constant and Variable needed for commutativity pruning is weaker than the one needed for propagating and validating bit sizes. Using actual equality caused the pruning to fail for my ('fmul', ('fadd', 1, a), ('fadd', 1, a)) case. I changed the name to "equivalent" rather than the previous "same_as" to further differentiate it from __eq__. Reviewed-by: Connor Abbott <[email protected]> Reviewed-by: Dylan Baker <[email protected]>
Diffstat (limited to 'src/compiler/nir/nir_algebraic.py')
-rw-r--r--src/compiler/nir/nir_algebraic.py57
1 files changed, 56 insertions, 1 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index d15d4ba3d67..d7f8e48dec8 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -269,6 +269,19 @@ class Constant(Value):
elif isinstance(self.value, float):
return "nir_type_float"
+ def equivalent(self, other):
+ """Check that two constants are equivalent.
+
+ This is check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ return self.value == other.value
+
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
r"(?P<cond>\([^\)]+\))?")
@@ -313,6 +326,19 @@ class Variable(Value):
elif self.required_type == 'float':
return "nir_type_float"
+ def equivalent(self, other):
+ """Check that two variables are equivalent.
+
+ This is check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ return self.index == other.index
+
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
r"(?P<cond>\([^\)]+\))?")
@@ -352,12 +378,41 @@ class Expression(Value):
self.__index_comm_exprs(0)
+ def equivalent(self, other):
+ """Check that two variables are equivalent.
+
+ This is check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ This implementation does not check for equivalence due to commutativity,
+ but it could.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ if len(self.sources) != len(other.sources):
+ return False
+
+ if self.opcode != other.opcode:
+ return False
+
+ return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
+
def __index_comm_exprs(self, base_idx):
"""Recursively count and index commutative expressions
"""
self.comm_exprs = 0
+
+ # A note about the explicit "len(self.sources)" check. The list of
+ # sources comes from user input, and that input might be bad. Check
+ # that the expected second source exists before accessing it. Without
+ # this check, a unit test that does "('iadd', 'a')" will crash.
if self.opcode not in conv_opcode_types and \
- "2src_commutative" in opcodes[self.opcode].algebraic_properties:
+ "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
+ len(self.sources) >= 2 and \
+ not self.sources[0].equivalent(self.sources[1]):
self.comm_expr_idx = base_idx
self.comm_exprs += 1
else: