diff options
Diffstat (limited to 'src/compiler')
-rw-r--r-- | src/compiler/nir/nir_algebraic.py | 57 |
1 files changed, 56 insertions, 1 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index d15d4ba3d67..d7f8e48dec8 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -269,6 +269,19 @@ class Constant(Value): elif isinstance(self.value, float): return "nir_type_float" + def equivalent(self, other): + """Check that two constants are equivalent. + + This is check is much weaker than equality. One generally cannot be + used in place of the other. Using this implementation for the __eq__ + will break BitSizeValidator. + + """ + if not isinstance(other, type(self)): + return False + + return self.value == other.value + _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)" r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?" r"(?P<cond>\([^\)]+\))?") @@ -313,6 +326,19 @@ class Variable(Value): elif self.required_type == 'float': return "nir_type_float" + def equivalent(self, other): + """Check that two variables are equivalent. + + This is check is much weaker than equality. One generally cannot be + used in place of the other. Using this implementation for the __eq__ + will break BitSizeValidator. + + """ + if not isinstance(other, type(self)): + return False + + return self.index == other.index + _opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?" r"(?P<cond>\([^\)]+\))?") @@ -352,12 +378,41 @@ class Expression(Value): self.__index_comm_exprs(0) + def equivalent(self, other): + """Check that two variables are equivalent. + + This is check is much weaker than equality. One generally cannot be + used in place of the other. Using this implementation for the __eq__ + will break BitSizeValidator. + + This implementation does not check for equivalence due to commutativity, + but it could. + + """ + if not isinstance(other, type(self)): + return False + + if len(self.sources) != len(other.sources): + return False + + if self.opcode != other.opcode: + return False + + return all(s.equivalent(o) for s, o in zip(self.sources, other.sources)) + def __index_comm_exprs(self, base_idx): """Recursively count and index commutative expressions """ self.comm_exprs = 0 + + # A note about the explicit "len(self.sources)" check. The list of + # sources comes from user input, and that input might be bad. Check + # that the expected second source exists before accessing it. Without + # this check, a unit test that does "('iadd', 'a')" will crash. if self.opcode not in conv_opcode_types and \ - "2src_commutative" in opcodes[self.opcode].algebraic_properties: + "2src_commutative" in opcodes[self.opcode].algebraic_properties and \ + len(self.sources) >= 2 and \ + not self.sources[0].equivalent(self.sources[1]): self.comm_expr_idx = base_idx self.comm_exprs += 1 else: |