summaryrefslogtreecommitdiffstats
path: root/src/compiler/nir
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler/nir')
-rw-r--r--src/compiler/nir/nir_algebraic.py520
-rw-r--r--src/compiler/nir/nir_search.c146
-rw-r--r--src/compiler/nir/nir_search.h17
3 files changed, 317 insertions, 366 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index 728196136ab..efd6e52cdb9 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -88,7 +88,7 @@ class Value(object):
__template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
- { ${val.type_enum}, ${val.bit_size} },
+ { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
@@ -112,6 +112,40 @@ static const ${val.c_type} ${val.name} = {
def __str__(self):
return self.in_val
+ def get_bit_size(self):
+ """Get the physical bit-size that has been chosen for this value, or if
+ there is none, the canonical value which currently represents this
+ bit-size class. Variables will be preferred, i.e. if there are any
+ variables in the equivalence class, the canonical value will be a
+ variable. We do this since we'll need to know which variable each value
+ is equivalent to when constructing the replacement expression. This is
+ the "find" part of the union-find algorithm.
+ """
+ bit_size = self
+
+ while isinstance(bit_size, Value):
+ if bit_size._bit_size is None:
+ break
+ bit_size = bit_size._bit_size
+
+ if bit_size is not self:
+ self._bit_size = bit_size
+ return bit_size
+
+ def set_bit_size(self, other):
+ """Make self.get_bit_size() return what other.get_bit_size() return
+ before calling this, or just "other" if it's a concrete bit-size. This is
+ the "union" part of the union-find algorithm.
+ """
+
+ self_bit_size = self.get_bit_size()
+ other_bit_size = other if isinstance(other, int) else other.get_bit_size()
+
+ if self_bit_size == other_bit_size:
+ return
+
+ self_bit_size._bit_size = other_bit_size
+
@property
def type_enum(self):
return "nir_search_value_" + self.type_str
@@ -124,6 +158,21 @@ static const ${val.c_type} ${val.name} = {
def c_ptr(self):
return "&{0}.value".format(self.name)
+ @property
+ def c_bit_size(self):
+ bit_size = self.get_bit_size()
+ if isinstance(bit_size, int):
+ return bit_size
+ elif isinstance(bit_size, Variable):
+ return -bit_size.index - 1
+ else:
+ # If the bit-size class is neither a variable, nor an actual bit-size, then
+ # - If it's in the search expression, we don't need to check anything
+ # - If it's in the replace expression, either it's ambiguous (in which
+ # case we'd reject it), or it equals the bit-size of the search value
+ # We represent these cases with a 0 bit-size.
+ return 0
+
def render(self):
return self.__template.render(val=self,
Constant=Constant,
@@ -140,14 +189,14 @@ class Constant(Value):
if isinstance(val, (str)):
m = _constant_re.match(val)
self.value = ast.literal_eval(m.group('value'))
- self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+ self._bit_size = int(m.group('bits')) if m.group('bits') else None
else:
self.value = val
- self.bit_size = 0
+ self._bit_size = None
if isinstance(self.value, bool):
- assert self.bit_size == 0 or self.bit_size == 32
- self.bit_size = 32
+ assert self._bit_size is None or self._bit_size == 32
+ self._bit_size = 32
def hex(self):
if isinstance(self.value, (bool)):
@@ -191,11 +240,11 @@ class Variable(Value):
self.is_constant = m.group('const') is not None
self.cond = m.group('cond')
self.required_type = m.group('type')
- self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+ self._bit_size = int(m.group('bits')) if m.group('bits') else None
if self.required_type == 'bool':
- assert self.bit_size == 0 or self.bit_size == 32
- self.bit_size = 32
+ assert self._bit_size is None or self._bit_size == 32
+ self._bit_size = 32
if self.required_type is not None:
assert self.required_type in ('float', 'bool', 'int', 'uint')
@@ -225,7 +274,7 @@ class Expression(Value):
assert m and m.group('opcode') is not None
self.opcode = m.group('opcode')
- self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+ self._bit_size = int(m.group('bits')) if m.group('bits') else None
self.inexact = m.group('inexact') is not None
self.cond = m.group('cond')
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
@@ -235,40 +284,6 @@ class Expression(Value):
srcs = "\n".join(src.render() for src in self.sources)
return srcs + super(Expression, self).render()
-class IntEquivalenceRelation(object):
- """A class representing an equivalence relation on integers.
-
- Each integer has a canonical form which is the maximum integer to which it
- is equivalent. Two integers are equivalent precisely when they have the
- same canonical form.
-
- The convention of maximum is explicitly chosen to make using it in
- BitSizeValidator easier because it means that an actual bit_size (if any)
- will always be the canonical form.
- """
- def __init__(self):
- self._remap = {}
-
- def get_canonical(self, x):
- """Get the canonical integer corresponding to x."""
- if x in self._remap:
- return self.get_canonical(self._remap[x])
- else:
- return x
-
- def add_equiv(self, a, b):
- """Add an equivalence and return the canonical form."""
- c = max(self.get_canonical(a), self.get_canonical(b))
- if a != c:
- assert a < c
- self._remap[a] = c
-
- if b != c:
- assert b < c
- self._remap[b] = c
-
- return c
-
class BitSizeValidator(object):
"""A class for validating bit sizes of expressions.
@@ -296,7 +311,7 @@ class BitSizeValidator(object):
inference can be ambiguous or contradictory. Consider, for instance, the
following transformation:
- (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
+ (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
This transformation can potentially cause a problem because usub_borrow is
well-defined for any bit-size of integer. However, b2i always generates a
@@ -315,217 +330,250 @@ class BitSizeValidator(object):
generate any code. This ensures that bugs are caught at compile time
rather than at run time.
- The basic operation of the validator is very similar to the bitsize_tree in
- nir_search only a little more subtle. Instead of simply tracking bit
- sizes, it tracks "bit classes" where each class is represented by an
- integer. A value of 0 means we don't know anything yet, positive values
- are actual bit-sizes, and negative values are used to track equivalence
- classes of sizes that must be the same but have yet to receive an actual
- size. The first stage uses the bitsize_tree algorithm to assign bit
- classes to each variable. If it ever comes across an inconsistency, it
- assert-fails. Then the second stage uses that information to prove that
- the resulting expression can always validly be constructed.
- """
-
- def __init__(self, varset):
- self._num_classes = 0
- self._var_classes = [0] * len(varset.names)
- self._class_relation = IntEquivalenceRelation()
-
- def validate(self, search, replace):
- search_dst_class = self._propagate_bit_size_up(search)
- if search_dst_class == 0:
- search_dst_class = self._new_class()
- self._propagate_bit_class_down(search, search_dst_class)
-
- replace_dst_class = self._validate_bit_class_up(replace)
- if replace_dst_class != 0:
- assert search_dst_class != 0, \
- 'Search expression matches any bit size but replace ' \
- 'expression can only generate {0}-bit values' \
- .format(replace_dst_class)
-
- assert search_dst_class == replace_dst_class, \
- 'Search expression matches any {0}-bit values but replace ' \
- 'expression can only generates {1}-bit values' \
- .format(search_dst_class, replace_dst_class)
-
- self._validate_bit_class_down(replace, search_dst_class)
-
- def _new_class(self):
- self._num_classes += 1
- return -self._num_classes
-
- def _set_var_bit_class(self, var, bit_class):
- assert bit_class != 0
- var_class = self._var_classes[var.index]
- if var_class == 0:
- self._var_classes[var.index] = bit_class
- else:
- canon_var_class = self._class_relation.get_canonical(var_class)
- canon_bit_class = self._class_relation.get_canonical(bit_class)
- assert canon_var_class < 0 or canon_bit_class < 0 or \
- canon_var_class == canon_bit_class, \
- 'Variable {0} cannot be both {1}-bit and {2}-bit' \
- .format(str(var), bit_class, var_class)
- var_class = self._class_relation.add_equiv(var_class, bit_class)
- self._var_classes[var.index] = var_class
-
- def _get_var_bit_class(self, var):
- return self._class_relation.get_canonical(self._var_classes[var.index])
-
- def _propagate_bit_size_up(self, val):
- if isinstance(val, (Constant, Variable)):
- return val.bit_size
-
- elif isinstance(val, Expression):
- nir_op = opcodes[val.opcode]
- val.common_size = 0
- for i in range(nir_op.num_inputs):
- src_bits = self._propagate_bit_size_up(val.sources[i])
- if src_bits == 0:
- continue
-
- src_type_bits = type_bits(nir_op.input_types[i])
- if src_type_bits != 0:
- assert src_bits == src_type_bits, \
- 'Source {0} of nir_op_{1} must be a {2}-bit value but ' \
- 'the only possible matched values are {3}-bit: {4}' \
- .format(i, val.opcode, src_type_bits, src_bits, str(val))
- else:
- assert val.common_size == 0 or src_bits == val.common_size, \
- 'Expression cannot have both {0}-bit and {1}-bit ' \
- 'variable-width sources: {2}' \
- .format(src_bits, val.common_size, str(val))
- val.common_size = src_bits
-
- dst_type_bits = type_bits(nir_op.output_type)
- if dst_type_bits != 0:
- assert val.bit_size == 0 or val.bit_size == dst_type_bits, \
- 'nir_op_{0} produces a {1}-bit result but a {2}-bit ' \
- 'result was requested' \
- .format(val.opcode, dst_type_bits, val.bit_size)
- return dst_type_bits
- else:
- if val.common_size != 0:
- assert val.bit_size == 0 or val.bit_size == val.common_size, \
- 'Variable width expression musr be {0}-bit based on ' \
- 'the sources but a {1}-bit result was requested: {2}' \
- .format(val.common_size, val.bit_size, str(val))
- else:
- val.common_size = val.bit_size
- return val.common_size
+ Each value maintains a "bit-size class", which is either an actual bit size
+ or an equivalence class with other values that must have the same bit size.
+ The validator works by combining bit-size classes with each other according
+ to the NIR rules outlined above, checking that there are no inconsistencies.
+ When doing this for the replacement expression, we make sure to never change
+ the equivalence class of any of the search values. We could make the example
+ transforms above work by doing some extra run-time checking of the search
+ expression, but we make the user specify those constraints themselves, to
+ avoid any surprises. Since the replacement bitsizes can only be connected to
+ the source bitsize via variables (variables must have the same bitsize in
+ the source and replacment expressions) or the roots of the expression (the
+ replacement expression must produce the same bit size as the search
+ expression), we prevent merging a variable with anything when processing the
+ replacement expression, or specializing the search bitsize
+ with anything. The former prevents
- def _propagate_bit_class_down(self, val, bit_class):
- if isinstance(val, Constant):
- assert val.bit_size == 0 or val.bit_size == bit_class, \
- 'Constant is {0}-bit but a {1}-bit value is required: {2}' \
- .format(val.bit_size, bit_class, str(val))
+ (('bcsel', a, b, 0), ('iand', a, b))
- elif isinstance(val, Variable):
- assert val.bit_size == 0 or val.bit_size == bit_class, \
- 'Variable is {0}-bit but a {1}-bit value is required: {2}' \
- .format(val.bit_size, bit_class, str(val))
- self._set_var_bit_class(val, bit_class)
+ from being allowed, since we'd have to merge the bitsizes for a and b due to
+ the 'iand', while the latter prevents
- elif isinstance(val, Expression):
- nir_op = opcodes[val.opcode]
- dst_type_bits = type_bits(nir_op.output_type)
- if dst_type_bits != 0:
- assert bit_class == 0 or bit_class == dst_type_bits, \
- 'nir_op_{0} produces a {1}-bit result but the parent ' \
- 'expression wants a {2}-bit value' \
- .format(val.opcode, dst_type_bits, bit_class)
- else:
- assert val.common_size == 0 or val.common_size == bit_class, \
- 'Variable-width expression produces a {0}-bit result ' \
- 'based on the source widths but the parent expression ' \
- 'wants a {1}-bit value: {2}' \
- .format(val.common_size, bit_class, str(val))
- val.common_size = bit_class
-
- if val.common_size:
- common_class = val.common_size
- elif nir_op.num_inputs:
- # If we got here then we have no idea what the actual size is.
- # Instead, we use a generic class
- common_class = self._new_class()
-
- for i in range(nir_op.num_inputs):
- src_type_bits = type_bits(nir_op.input_types[i])
- if src_type_bits != 0:
- self._propagate_bit_class_down(val.sources[i], src_type_bits)
- else:
- self._propagate_bit_class_down(val.sources[i], common_class)
+ (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
- def _validate_bit_class_up(self, val):
- if isinstance(val, Constant):
- return val.bit_size
+ from being allowed, since the search expression has the bit size of a and b,
+ which can't be specialized to 32 which is the bitsize of the replace
+ expression. It also prevents something like:
- elif isinstance(val, Variable):
- var_class = self._get_var_bit_class(val)
- # By the time we get to validation, every variable should have a class
- assert var_class != 0
+ (('b2i', ('i2b', a)), ('ineq', a, 0))
- # If we have an explicit size provided by the user, the variable
- # *must* exactly match the search. It cannot be implicitly sized
- # because otherwise we could end up with a conflict at runtime.
- assert val.bit_size == 0 or val.bit_size == var_class
+ since the bitsize of 'b2i', which can be anything, can't be specialized to
+ the bitsize of a.
- return var_class
+ After doing all this, we check that every subexpression of the replacement
+ was assigned a constant bitsize, the bitsize of a variable, or the bitsize
+ of the search expresssion, since those are the things that are known when
+ constructing the replacement expresssion. Finally, we record the bitsize
+ needed in nir_search_value so that we know what to do when building the
+ replacement expression.
+ """
+ def __init__(self, varset):
+ self._var_classes = [None] * len(varset.names)
+
+ def compare_bitsizes(self, a, b):
+ """Determines which bitsize class is a specialization of the other, or
+ whether neither is. When we merge two different bitsizes, the
+ less-specialized bitsize always points to the more-specialized one, so
+ that calling get_bit_size() always gets you the most specialized bitsize.
+ The specialization partial order is given by:
+ - Physical bitsizes are always the most specialized, and a different
+ bitsize can never specialize another.
+ - In the search expression, variables can always be specialized to each
+ other and to physical bitsizes. In the replace expression, we disallow
+ this to avoid adding extra constraints to the search expression that
+ the user didn't specify.
+ - Expressions and constants without a bitsize can always be specialized to
+ each other and variables, but not the other way around.
+
+ We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
+ and None if they are not comparable (neither a <= b nor b <= a).
+ """
+ if isinstance(a, int):
+ if isinstance(b, int):
+ return 0 if a == b else None
+ elif isinstance(b, Variable):
+ return -1 if self.is_search else None
+ else:
+ return -1
+ elif isinstance(a, Variable):
+ if isinstance(b, int):
+ return 1 if self.is_search else None
+ elif isinstance(b, Variable):
+ return 0 if self.is_search or a.index == b.index else None
+ else:
+ return -1
+ else:
+ if isinstance(b, int):
+ return 1
+ elif isinstance(b, Variable):
+ return 1
+ else:
+ return 0
+
+ def unify_bit_size(self, a, b, error_msg):
+ """Record that a must have the same bit-size as b. If both
+ have been assigned conflicting physical bit-sizes, call "error_msg" with
+ the bit-sizes of self and other to get a message and raise an error.
+ In the replace expression, disallow merging variables with other
+ variables and physical bit-sizes as well.
+ """
+ a_bit_size = a.get_bit_size()
+ b_bit_size = b if isinstance(b, int) else b.get_bit_size()
+
+ cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
+
+ assert cmp_result is not None, \
+ error_msg(a_bit_size, b_bit_size)
+
+ if cmp_result < 0:
+ b_bit_size.set_bit_size(a)
+ elif not isinstance(a_bit_size, int):
+ a_bit_size.set_bit_size(b)
+
+ def merge_variables(self, val):
+ """Perform the first part of type inference by merging all the different
+ uses of the same variable. We always do this as if we're in the search
+ expression, even if we're actually not, since otherwise we'd get errors
+ if the search expression specified some constraint but the replace
+ expression didn't, because we'd be merging a variable and a constant.
+ """
+ if isinstance(val, Variable):
+ if self._var_classes[val.index] is None:
+ self._var_classes[val.index] = val
+ else:
+ other = self._var_classes[val.index]
+ self.unify_bit_size(other, val,
+ lambda other_bit_size, bit_size:
+ 'Variable {} has conflicting bit size requirements: ' \
+ 'it must have bit size {} and {}'.format(
+ val.var_name, other_bit_size, bit_size))
elif isinstance(val, Expression):
- nir_op = opcodes[val.opcode]
- val.common_class = 0
- for i in range(nir_op.num_inputs):
- src_class = self._validate_bit_class_up(val.sources[i])
- if src_class == 0:
+ for src in val.sources:
+ self.merge_variables(src)
+
+ def validate_value(self, val):
+ """Validate the an expression by performing classic Hindley-Milner
+ type inference on bitsizes. This will detect if there are any conflicting
+ requirements, and unify variables so that we know which variables must
+ have the same bitsize. If we're operating on the replace expression, we
+ will refuse to merge different variables together or merge a variable
+ with a constant, in order to prevent surprises due to rules unexpectedly
+ not matching at runtime.
+ """
+ if not isinstance(val, Expression):
+ return
+
+ nir_op = opcodes[val.opcode]
+ assert len(val.sources) == nir_op.num_inputs, \
+ "Expression {} has {} sources, expected {}".format(
+ val, len(val.sources), nir_op.num_inputs)
+
+ for src in val.sources:
+ self.validate_value(src)
+
+ dst_type_bits = type_bits(nir_op.output_type)
+
+ # First, unify all the sources. That way, an error coming up because two
+ # sources have an incompatible bit-size won't produce an error message
+ # involving the destination.
+ first_unsized_src = None
+ for src_type, src in zip(nir_op.input_types, val.sources):
+ src_type_bits = type_bits(src_type)
+ if src_type_bits == 0:
+ if first_unsized_src is None:
+ first_unsized_src = src
continue
- src_type_bits = type_bits(nir_op.input_types[i])
- if src_type_bits != 0:
- assert src_class == src_type_bits
+ if self.is_search:
+ self.unify_bit_size(first_unsized_src, src,
+ lambda first_unsized_src_bit_size, src_bit_size:
+ 'Source {} of {} must have bit size {}, while source {} ' \
+ 'must have incompatible bit size {}'.format(
+ first_unsized_src, val, first_unsized_src_bit_size,
+ src, src_bit_size))
else:
- assert val.common_class == 0 or src_class == val.common_class
- val.common_class = src_class
-
- dst_type_bits = type_bits(nir_op.output_type)
- if dst_type_bits != 0:
- assert val.bit_size == 0 or val.bit_size == dst_type_bits
- return dst_type_bits
+ self.unify_bit_size(first_unsized_src, src,
+ lambda first_unsized_src_bit_size, src_bit_size:
+ 'Sources {} (bit size of {}) and {} (bit size of {}) ' \
+ 'of {} may not have the same bit size when building the ' \
+ 'replacement expression.'.format(
+ first_unsized_src, first_unsized_src_bit_size, src,
+ src_bit_size, val))
else:
- if val.common_class != 0:
- assert val.bit_size == 0 or val.bit_size == val.common_class
+ if self.is_search:
+ self.unify_bit_size(src, src_type_bits,
+ lambda src_bit_size, unused:
+ '{} must have {} bits, but as a source of nir_op_{} '\
+ 'it must have {} bits'.format(
+ src, src_bit_size, nir_op.name, src_type_bits))
+ else:
+ self.unify_bit_size(src, src_type_bits,
+ lambda src_bit_size, unused:
+ '{} has the bit size of {}, but as a source of ' \
+ 'nir_op_{} it must have {} bits, which may not be the ' \
+ 'same'.format(
+ src, src_bit_size, nir_op.name, src_type_bits))
+
+ if dst_type_bits == 0:
+ if first_unsized_src is not None:
+ if self.is_search:
+ self.unify_bit_size(val, first_unsized_src,
+ lambda val_bit_size, src_bit_size:
+ '{} must have the bit size of {}, while its source {} ' \
+ 'must have incompatible bit size {}'.format(
+ val, val_bit_size, first_unsized_src, src_bit_size))
else:
- val.common_class = val.bit_size
- return val.common_class
+ self.unify_bit_size(val, first_unsized_src,
+ lambda val_bit_size, src_bit_size:
+ '{} must have {} bits, but its source {} ' \
+ '(bit size of {}) may not have that bit size ' \
+ 'when building the replacement.'.format(
+ val, val_bit_size, first_unsized_src, src_bit_size))
+ else:
+ self.unify_bit_size(val, dst_type_bits,
+ lambda dst_bit_size, unused:
+ '{} must have {} bits, but as a destination of nir_op_{} ' \
+ 'it must have {} bits'.format(
+ val, dst_bit_size, nir_op.name, dst_type_bits))
+
+ def validate_replace(self, val, search):
+ bit_size = val.get_bit_size()
+ assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
+ bit_size == search.get_bit_size(), \
+ 'Ambiguous bit size for replacement value {}: ' \
+ 'it cannot be deduced from a variable, a fixed bit size ' \
+ 'somewhere, or the search expression.'.format(val)
+
+ if isinstance(val, Expression):
+ for src in val.sources:
+ self.validate_replace(src, search)
- def _validate_bit_class_down(self, val, bit_class):
- # At this point, everything *must* have a bit class. Otherwise, we have
- # a value we don't know how to define.
- assert bit_class != 0
+ def validate(self, search, replace):
+ self.is_search = True
+ self.merge_variables(search)
+ self.merge_variables(replace)
+ self.validate_value(search)
- if isinstance(val, Constant):
- assert val.bit_size == 0 or val.bit_size == bit_class
+ self.is_search = False
+ self.validate_value(replace)
- elif isinstance(val, Variable):
- assert val.bit_size == 0 or val.bit_size == bit_class
+ # Check that search is always more specialized than replace. Note that
+ # we're doing this in replace mode, disallowing merging variables.
+ search_bit_size = search.get_bit_size()
+ replace_bit_size = replace.get_bit_size()
+ cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
- elif isinstance(val, Expression):
- nir_op = opcodes[val.opcode]
- dst_type_bits = type_bits(nir_op.output_type)
- if dst_type_bits != 0:
- assert bit_class == dst_type_bits
- else:
- assert val.common_class == 0 or val.common_class == bit_class
- val.common_class = bit_class
+ assert cmp_result is not None and cmp_result <= 0, \
+ 'The search expression bit size {} and replace expression ' \
+ 'bit size {} may not be the same'.format(
+ search_bit_size, replace_bit_size)
- for i in range(nir_op.num_inputs):
- src_type_bits = type_bits(nir_op.input_types[i])
- if src_type_bits != 0:
- self._validate_bit_class_down(val.sources[i], src_type_bits)
- else:
- self._validate_bit_class_down(val.sources[i], val.common_class)
+ replace.set_bit_size(search)
+
+ self.validate_replace(replace, search)
_optimization_ids = itertools.count()
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 0270302fd3d..a41fca876d5 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -118,7 +118,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
/* If the value has a specific bit size and it doesn't match, bail */
- if (value->bit_size &&
+ if (value->bit_size > 0 &&
nir_src_bit_size(instr->src[src].src) != value->bit_size)
return false;
@@ -228,7 +228,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
assert(instr->dest.dest.is_ssa);
- if (expr->value.bit_size &&
+ if (expr->value.bit_size > 0 &&
instr->dest.dest.ssa.bit_size != expr->value.bit_size)
return false;
@@ -290,128 +290,21 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
}
}
-typedef struct bitsize_tree {
- unsigned num_srcs;
- struct bitsize_tree *srcs[4];
-
- unsigned common_size;
- bool is_src_sized[4];
- bool is_dest_sized;
-
- unsigned dest_size;
- unsigned src_size[4];
-} bitsize_tree;
-
-static bitsize_tree *
-build_bitsize_tree(void *mem_ctx, struct match_state *state,
- const nir_search_value *value)
-{
- bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
-
- switch (value->type) {
- case nir_search_value_expression: {
- nir_search_expression *expr = nir_search_value_as_expression(value);
- nir_op_info info = nir_op_infos[expr->opcode];
- tree->num_srcs = info.num_inputs;
- tree->common_size = 0;
- for (unsigned i = 0; i < info.num_inputs; i++) {
- tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
- if (tree->is_src_sized[i])
- tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
- tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
- }
- tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
- if (tree->is_dest_sized)
- tree->dest_size = nir_alu_type_get_type_size(info.output_type);
- break;
- }
-
- case nir_search_value_variable: {
- nir_search_variable *var = nir_search_value_as_variable(value);
- tree->num_srcs = 0;
- tree->is_dest_sized = true;
- tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
- break;
- }
-
- case nir_search_value_constant: {
- tree->num_srcs = 0;
- tree->is_dest_sized = false;
- tree->common_size = 0;
- break;
- }
- }
-
- if (value->bit_size) {
- assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
- tree->common_size = value->bit_size;
- }
-
- return tree;
-}
-
static unsigned
-bitsize_tree_filter_up(bitsize_tree *tree)
+replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
+ struct match_state *state)
{
- for (unsigned i = 0; i < tree->num_srcs; i++) {
- unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
- if (src_size == 0)
- continue;
-
- if (tree->is_src_sized[i]) {
- assert(src_size == tree->src_size[i]);
- } else if (tree->common_size != 0) {
- assert(src_size == tree->common_size);
- tree->src_size[i] = src_size;
- } else {
- tree->common_size = src_size;
- tree->src_size[i] = src_size;
- }
- }
-
- if (tree->num_srcs && tree->common_size) {
- if (tree->dest_size == 0)
- tree->dest_size = tree->common_size;
- else if (!tree->is_dest_sized)
- assert(tree->dest_size == tree->common_size);
-
- for (unsigned i = 0; i < tree->num_srcs; i++) {
- if (!tree->src_size[i])
- tree->src_size[i] = tree->common_size;
- }
- }
-
- return tree->dest_size;
-}
-
-static void
-bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
-{
- if (tree->dest_size)
- assert(tree->dest_size == size);
- else
- tree->dest_size = size;
-
- if (!tree->is_dest_sized) {
- if (tree->common_size)
- assert(tree->common_size == size);
- else
- tree->common_size = size;
- }
-
- for (unsigned i = 0; i < tree->num_srcs; i++) {
- if (!tree->src_size[i]) {
- assert(tree->common_size);
- tree->src_size[i] = tree->common_size;
- }
- bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
- }
+ if (value->bit_size > 0)
+ return value->bit_size;
+ if (value->bit_size < 0)
+ return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
+ return search_bitsize;
}
static nir_alu_src
construct_value(nir_builder *build,
const nir_search_value *value,
- unsigned num_components, bitsize_tree *bitsize,
+ unsigned num_components, unsigned search_bitsize,
struct match_state *state,
nir_instr *instr)
{
@@ -424,7 +317,7 @@ construct_value(nir_builder *build,
nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode);
nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
- bitsize->dest_size, NULL);
+ replace_bitsize(value, search_bitsize, state), NULL);
alu->dest.write_mask = (1 << num_components) - 1;
alu->dest.saturate = false;
@@ -443,7 +336,7 @@ construct_value(nir_builder *build,
num_components = nir_op_infos[alu->op].input_sizes[i];
alu->src[i] = construct_value(build, expr->srcs[i],
- num_components, bitsize->srcs[i],
+ num_components, search_bitsize,
state, instr);
}
@@ -472,16 +365,17 @@ construct_value(nir_builder *build,
case nir_search_value_constant: {
const nir_search_constant *c = nir_search_value_as_constant(value);
+ unsigned bit_size = replace_bitsize(value, search_bitsize, state);
nir_ssa_def *cval;
switch (c->type) {
case nir_type_float:
- cval = nir_imm_floatN_t(build, c->data.d, bitsize->dest_size);
+ cval = nir_imm_floatN_t(build, c->data.d, bit_size);
break;
case nir_type_int:
case nir_type_uint:
- cval = nir_imm_intN_t(build, c->data.i, bitsize->dest_size);
+ cval = nir_imm_intN_t(build, c->data.i, bit_size);
break;
case nir_type_bool:
@@ -526,16 +420,12 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
swizzle, &state))
return NULL;
- void *bitsize_ctx = ralloc_context(NULL);
- bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
- bitsize_tree_filter_up(tree);
- bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
-
build->cursor = nir_before_instr(&instr->instr);
nir_alu_src val = construct_value(build, replace,
instr->dest.dest.ssa.num_components,
- tree, &state, &instr->instr);
+ instr->dest.dest.ssa.bit_size,
+ &state, &instr->instr);
/* Inserting a mov may be unnecessary. However, it's much easier to
* simply let copy propagation clean this up than to try to go through
@@ -551,7 +441,5 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
*/
nir_instr_remove(&instr->instr);
- ralloc_free(bitsize_ctx);
-
return ssa_val;
}
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
index df4189ede74..a76f39e0f40 100644
--- a/src/compiler/nir/nir_search.h
+++ b/src/compiler/nir/nir_search.h
@@ -43,7 +43,22 @@ typedef enum {
typedef struct {
nir_search_value_type type;
- unsigned bit_size;
+ /**
+ * Bit size of the value. It is interpreted as follows:
+ *
+ * For a search expression:
+ * - If bit_size > 0, then the value only matches an SSA value with the
+ * given bit size.
+ * - If bit_size <= 0, then the value matches any size SSA value.
+ *
+ * For a replace expression:
+ * - If bit_size > 0, then the value is constructed with the given bit size.
+ * - If bit_size == 0, then the value is constructed with the same bit size
+ * as the search value.
+ * - If bit_size < 0, then the value is constructed with the same bit size
+ * as variable (-bit_size - 1).
+ */
+ int bit_size;
} nir_search_value;
typedef struct {