diff options
Diffstat (limited to 'src/compiler/nir')
-rw-r--r-- | src/compiler/nir/nir_algebraic.py | 520 | ||||
-rw-r--r-- | src/compiler/nir/nir_search.c | 146 | ||||
-rw-r--r-- | src/compiler/nir/nir_search.h | 17 |
3 files changed, 317 insertions, 366 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index 728196136ab..efd6e52cdb9 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -88,7 +88,7 @@ class Value(object): __template = mako.template.Template(""" static const ${val.c_type} ${val.name} = { - { ${val.type_enum}, ${val.bit_size} }, + { ${val.type_enum}, ${val.c_bit_size} }, % if isinstance(val, Constant): ${val.type()}, { ${val.hex()} /* ${val.value} */ }, % elif isinstance(val, Variable): @@ -112,6 +112,40 @@ static const ${val.c_type} ${val.name} = { def __str__(self): return self.in_val + def get_bit_size(self): + """Get the physical bit-size that has been chosen for this value, or if + there is none, the canonical value which currently represents this + bit-size class. Variables will be preferred, i.e. if there are any + variables in the equivalence class, the canonical value will be a + variable. We do this since we'll need to know which variable each value + is equivalent to when constructing the replacement expression. This is + the "find" part of the union-find algorithm. + """ + bit_size = self + + while isinstance(bit_size, Value): + if bit_size._bit_size is None: + break + bit_size = bit_size._bit_size + + if bit_size is not self: + self._bit_size = bit_size + return bit_size + + def set_bit_size(self, other): + """Make self.get_bit_size() return what other.get_bit_size() return + before calling this, or just "other" if it's a concrete bit-size. This is + the "union" part of the union-find algorithm. + """ + + self_bit_size = self.get_bit_size() + other_bit_size = other if isinstance(other, int) else other.get_bit_size() + + if self_bit_size == other_bit_size: + return + + self_bit_size._bit_size = other_bit_size + @property def type_enum(self): return "nir_search_value_" + self.type_str @@ -124,6 +158,21 @@ static const ${val.c_type} ${val.name} = { def c_ptr(self): return "&{0}.value".format(self.name) + @property + def c_bit_size(self): + bit_size = self.get_bit_size() + if isinstance(bit_size, int): + return bit_size + elif isinstance(bit_size, Variable): + return -bit_size.index - 1 + else: + # If the bit-size class is neither a variable, nor an actual bit-size, then + # - If it's in the search expression, we don't need to check anything + # - If it's in the replace expression, either it's ambiguous (in which + # case we'd reject it), or it equals the bit-size of the search value + # We represent these cases with a 0 bit-size. + return 0 + def render(self): return self.__template.render(val=self, Constant=Constant, @@ -140,14 +189,14 @@ class Constant(Value): if isinstance(val, (str)): m = _constant_re.match(val) self.value = ast.literal_eval(m.group('value')) - self.bit_size = int(m.group('bits')) if m.group('bits') else 0 + self._bit_size = int(m.group('bits')) if m.group('bits') else None else: self.value = val - self.bit_size = 0 + self._bit_size = None if isinstance(self.value, bool): - assert self.bit_size == 0 or self.bit_size == 32 - self.bit_size = 32 + assert self._bit_size is None or self._bit_size == 32 + self._bit_size = 32 def hex(self): if isinstance(self.value, (bool)): @@ -191,11 +240,11 @@ class Variable(Value): self.is_constant = m.group('const') is not None self.cond = m.group('cond') self.required_type = m.group('type') - self.bit_size = int(m.group('bits')) if m.group('bits') else 0 + self._bit_size = int(m.group('bits')) if m.group('bits') else None if self.required_type == 'bool': - assert self.bit_size == 0 or self.bit_size == 32 - self.bit_size = 32 + assert self._bit_size is None or self._bit_size == 32 + self._bit_size = 32 if self.required_type is not None: assert self.required_type in ('float', 'bool', 'int', 'uint') @@ -225,7 +274,7 @@ class Expression(Value): assert m and m.group('opcode') is not None self.opcode = m.group('opcode') - self.bit_size = int(m.group('bits')) if m.group('bits') else 0 + self._bit_size = int(m.group('bits')) if m.group('bits') else None self.inexact = m.group('inexact') is not None self.cond = m.group('cond') self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset) @@ -235,40 +284,6 @@ class Expression(Value): srcs = "\n".join(src.render() for src in self.sources) return srcs + super(Expression, self).render() -class IntEquivalenceRelation(object): - """A class representing an equivalence relation on integers. - - Each integer has a canonical form which is the maximum integer to which it - is equivalent. Two integers are equivalent precisely when they have the - same canonical form. - - The convention of maximum is explicitly chosen to make using it in - BitSizeValidator easier because it means that an actual bit_size (if any) - will always be the canonical form. - """ - def __init__(self): - self._remap = {} - - def get_canonical(self, x): - """Get the canonical integer corresponding to x.""" - if x in self._remap: - return self.get_canonical(self._remap[x]) - else: - return x - - def add_equiv(self, a, b): - """Add an equivalence and return the canonical form.""" - c = max(self.get_canonical(a), self.get_canonical(b)) - if a != c: - assert a < c - self._remap[a] = c - - if b != c: - assert b < c - self._remap[b] = c - - return c - class BitSizeValidator(object): """A class for validating bit sizes of expressions. @@ -296,7 +311,7 @@ class BitSizeValidator(object): inference can be ambiguous or contradictory. Consider, for instance, the following transformation: - (('usub_borrow', a, b), ('b2i', ('ult', a, b))) + (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) This transformation can potentially cause a problem because usub_borrow is well-defined for any bit-size of integer. However, b2i always generates a @@ -315,217 +330,250 @@ class BitSizeValidator(object): generate any code. This ensures that bugs are caught at compile time rather than at run time. - The basic operation of the validator is very similar to the bitsize_tree in - nir_search only a little more subtle. Instead of simply tracking bit - sizes, it tracks "bit classes" where each class is represented by an - integer. A value of 0 means we don't know anything yet, positive values - are actual bit-sizes, and negative values are used to track equivalence - classes of sizes that must be the same but have yet to receive an actual - size. The first stage uses the bitsize_tree algorithm to assign bit - classes to each variable. If it ever comes across an inconsistency, it - assert-fails. Then the second stage uses that information to prove that - the resulting expression can always validly be constructed. - """ - - def __init__(self, varset): - self._num_classes = 0 - self._var_classes = [0] * len(varset.names) - self._class_relation = IntEquivalenceRelation() - - def validate(self, search, replace): - search_dst_class = self._propagate_bit_size_up(search) - if search_dst_class == 0: - search_dst_class = self._new_class() - self._propagate_bit_class_down(search, search_dst_class) - - replace_dst_class = self._validate_bit_class_up(replace) - if replace_dst_class != 0: - assert search_dst_class != 0, \ - 'Search expression matches any bit size but replace ' \ - 'expression can only generate {0}-bit values' \ - .format(replace_dst_class) - - assert search_dst_class == replace_dst_class, \ - 'Search expression matches any {0}-bit values but replace ' \ - 'expression can only generates {1}-bit values' \ - .format(search_dst_class, replace_dst_class) - - self._validate_bit_class_down(replace, search_dst_class) - - def _new_class(self): - self._num_classes += 1 - return -self._num_classes - - def _set_var_bit_class(self, var, bit_class): - assert bit_class != 0 - var_class = self._var_classes[var.index] - if var_class == 0: - self._var_classes[var.index] = bit_class - else: - canon_var_class = self._class_relation.get_canonical(var_class) - canon_bit_class = self._class_relation.get_canonical(bit_class) - assert canon_var_class < 0 or canon_bit_class < 0 or \ - canon_var_class == canon_bit_class, \ - 'Variable {0} cannot be both {1}-bit and {2}-bit' \ - .format(str(var), bit_class, var_class) - var_class = self._class_relation.add_equiv(var_class, bit_class) - self._var_classes[var.index] = var_class - - def _get_var_bit_class(self, var): - return self._class_relation.get_canonical(self._var_classes[var.index]) - - def _propagate_bit_size_up(self, val): - if isinstance(val, (Constant, Variable)): - return val.bit_size - - elif isinstance(val, Expression): - nir_op = opcodes[val.opcode] - val.common_size = 0 - for i in range(nir_op.num_inputs): - src_bits = self._propagate_bit_size_up(val.sources[i]) - if src_bits == 0: - continue - - src_type_bits = type_bits(nir_op.input_types[i]) - if src_type_bits != 0: - assert src_bits == src_type_bits, \ - 'Source {0} of nir_op_{1} must be a {2}-bit value but ' \ - 'the only possible matched values are {3}-bit: {4}' \ - .format(i, val.opcode, src_type_bits, src_bits, str(val)) - else: - assert val.common_size == 0 or src_bits == val.common_size, \ - 'Expression cannot have both {0}-bit and {1}-bit ' \ - 'variable-width sources: {2}' \ - .format(src_bits, val.common_size, str(val)) - val.common_size = src_bits - - dst_type_bits = type_bits(nir_op.output_type) - if dst_type_bits != 0: - assert val.bit_size == 0 or val.bit_size == dst_type_bits, \ - 'nir_op_{0} produces a {1}-bit result but a {2}-bit ' \ - 'result was requested' \ - .format(val.opcode, dst_type_bits, val.bit_size) - return dst_type_bits - else: - if val.common_size != 0: - assert val.bit_size == 0 or val.bit_size == val.common_size, \ - 'Variable width expression musr be {0}-bit based on ' \ - 'the sources but a {1}-bit result was requested: {2}' \ - .format(val.common_size, val.bit_size, str(val)) - else: - val.common_size = val.bit_size - return val.common_size + Each value maintains a "bit-size class", which is either an actual bit size + or an equivalence class with other values that must have the same bit size. + The validator works by combining bit-size classes with each other according + to the NIR rules outlined above, checking that there are no inconsistencies. + When doing this for the replacement expression, we make sure to never change + the equivalence class of any of the search values. We could make the example + transforms above work by doing some extra run-time checking of the search + expression, but we make the user specify those constraints themselves, to + avoid any surprises. Since the replacement bitsizes can only be connected to + the source bitsize via variables (variables must have the same bitsize in + the source and replacment expressions) or the roots of the expression (the + replacement expression must produce the same bit size as the search + expression), we prevent merging a variable with anything when processing the + replacement expression, or specializing the search bitsize + with anything. The former prevents - def _propagate_bit_class_down(self, val, bit_class): - if isinstance(val, Constant): - assert val.bit_size == 0 or val.bit_size == bit_class, \ - 'Constant is {0}-bit but a {1}-bit value is required: {2}' \ - .format(val.bit_size, bit_class, str(val)) + (('bcsel', a, b, 0), ('iand', a, b)) - elif isinstance(val, Variable): - assert val.bit_size == 0 or val.bit_size == bit_class, \ - 'Variable is {0}-bit but a {1}-bit value is required: {2}' \ - .format(val.bit_size, bit_class, str(val)) - self._set_var_bit_class(val, bit_class) + from being allowed, since we'd have to merge the bitsizes for a and b due to + the 'iand', while the latter prevents - elif isinstance(val, Expression): - nir_op = opcodes[val.opcode] - dst_type_bits = type_bits(nir_op.output_type) - if dst_type_bits != 0: - assert bit_class == 0 or bit_class == dst_type_bits, \ - 'nir_op_{0} produces a {1}-bit result but the parent ' \ - 'expression wants a {2}-bit value' \ - .format(val.opcode, dst_type_bits, bit_class) - else: - assert val.common_size == 0 or val.common_size == bit_class, \ - 'Variable-width expression produces a {0}-bit result ' \ - 'based on the source widths but the parent expression ' \ - 'wants a {1}-bit value: {2}' \ - .format(val.common_size, bit_class, str(val)) - val.common_size = bit_class - - if val.common_size: - common_class = val.common_size - elif nir_op.num_inputs: - # If we got here then we have no idea what the actual size is. - # Instead, we use a generic class - common_class = self._new_class() - - for i in range(nir_op.num_inputs): - src_type_bits = type_bits(nir_op.input_types[i]) - if src_type_bits != 0: - self._propagate_bit_class_down(val.sources[i], src_type_bits) - else: - self._propagate_bit_class_down(val.sources[i], common_class) + (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) - def _validate_bit_class_up(self, val): - if isinstance(val, Constant): - return val.bit_size + from being allowed, since the search expression has the bit size of a and b, + which can't be specialized to 32 which is the bitsize of the replace + expression. It also prevents something like: - elif isinstance(val, Variable): - var_class = self._get_var_bit_class(val) - # By the time we get to validation, every variable should have a class - assert var_class != 0 + (('b2i', ('i2b', a)), ('ineq', a, 0)) - # If we have an explicit size provided by the user, the variable - # *must* exactly match the search. It cannot be implicitly sized - # because otherwise we could end up with a conflict at runtime. - assert val.bit_size == 0 or val.bit_size == var_class + since the bitsize of 'b2i', which can be anything, can't be specialized to + the bitsize of a. - return var_class + After doing all this, we check that every subexpression of the replacement + was assigned a constant bitsize, the bitsize of a variable, or the bitsize + of the search expresssion, since those are the things that are known when + constructing the replacement expresssion. Finally, we record the bitsize + needed in nir_search_value so that we know what to do when building the + replacement expression. + """ + def __init__(self, varset): + self._var_classes = [None] * len(varset.names) + + def compare_bitsizes(self, a, b): + """Determines which bitsize class is a specialization of the other, or + whether neither is. When we merge two different bitsizes, the + less-specialized bitsize always points to the more-specialized one, so + that calling get_bit_size() always gets you the most specialized bitsize. + The specialization partial order is given by: + - Physical bitsizes are always the most specialized, and a different + bitsize can never specialize another. + - In the search expression, variables can always be specialized to each + other and to physical bitsizes. In the replace expression, we disallow + this to avoid adding extra constraints to the search expression that + the user didn't specify. + - Expressions and constants without a bitsize can always be specialized to + each other and variables, but not the other way around. + + We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b, + and None if they are not comparable (neither a <= b nor b <= a). + """ + if isinstance(a, int): + if isinstance(b, int): + return 0 if a == b else None + elif isinstance(b, Variable): + return -1 if self.is_search else None + else: + return -1 + elif isinstance(a, Variable): + if isinstance(b, int): + return 1 if self.is_search else None + elif isinstance(b, Variable): + return 0 if self.is_search or a.index == b.index else None + else: + return -1 + else: + if isinstance(b, int): + return 1 + elif isinstance(b, Variable): + return 1 + else: + return 0 + + def unify_bit_size(self, a, b, error_msg): + """Record that a must have the same bit-size as b. If both + have been assigned conflicting physical bit-sizes, call "error_msg" with + the bit-sizes of self and other to get a message and raise an error. + In the replace expression, disallow merging variables with other + variables and physical bit-sizes as well. + """ + a_bit_size = a.get_bit_size() + b_bit_size = b if isinstance(b, int) else b.get_bit_size() + + cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size) + + assert cmp_result is not None, \ + error_msg(a_bit_size, b_bit_size) + + if cmp_result < 0: + b_bit_size.set_bit_size(a) + elif not isinstance(a_bit_size, int): + a_bit_size.set_bit_size(b) + + def merge_variables(self, val): + """Perform the first part of type inference by merging all the different + uses of the same variable. We always do this as if we're in the search + expression, even if we're actually not, since otherwise we'd get errors + if the search expression specified some constraint but the replace + expression didn't, because we'd be merging a variable and a constant. + """ + if isinstance(val, Variable): + if self._var_classes[val.index] is None: + self._var_classes[val.index] = val + else: + other = self._var_classes[val.index] + self.unify_bit_size(other, val, + lambda other_bit_size, bit_size: + 'Variable {} has conflicting bit size requirements: ' \ + 'it must have bit size {} and {}'.format( + val.var_name, other_bit_size, bit_size)) elif isinstance(val, Expression): - nir_op = opcodes[val.opcode] - val.common_class = 0 - for i in range(nir_op.num_inputs): - src_class = self._validate_bit_class_up(val.sources[i]) - if src_class == 0: + for src in val.sources: + self.merge_variables(src) + + def validate_value(self, val): + """Validate the an expression by performing classic Hindley-Milner + type inference on bitsizes. This will detect if there are any conflicting + requirements, and unify variables so that we know which variables must + have the same bitsize. If we're operating on the replace expression, we + will refuse to merge different variables together or merge a variable + with a constant, in order to prevent surprises due to rules unexpectedly + not matching at runtime. + """ + if not isinstance(val, Expression): + return + + nir_op = opcodes[val.opcode] + assert len(val.sources) == nir_op.num_inputs, \ + "Expression {} has {} sources, expected {}".format( + val, len(val.sources), nir_op.num_inputs) + + for src in val.sources: + self.validate_value(src) + + dst_type_bits = type_bits(nir_op.output_type) + + # First, unify all the sources. That way, an error coming up because two + # sources have an incompatible bit-size won't produce an error message + # involving the destination. + first_unsized_src = None + for src_type, src in zip(nir_op.input_types, val.sources): + src_type_bits = type_bits(src_type) + if src_type_bits == 0: + if first_unsized_src is None: + first_unsized_src = src continue - src_type_bits = type_bits(nir_op.input_types[i]) - if src_type_bits != 0: - assert src_class == src_type_bits + if self.is_search: + self.unify_bit_size(first_unsized_src, src, + lambda first_unsized_src_bit_size, src_bit_size: + 'Source {} of {} must have bit size {}, while source {} ' \ + 'must have incompatible bit size {}'.format( + first_unsized_src, val, first_unsized_src_bit_size, + src, src_bit_size)) else: - assert val.common_class == 0 or src_class == val.common_class - val.common_class = src_class - - dst_type_bits = type_bits(nir_op.output_type) - if dst_type_bits != 0: - assert val.bit_size == 0 or val.bit_size == dst_type_bits - return dst_type_bits + self.unify_bit_size(first_unsized_src, src, + lambda first_unsized_src_bit_size, src_bit_size: + 'Sources {} (bit size of {}) and {} (bit size of {}) ' \ + 'of {} may not have the same bit size when building the ' \ + 'replacement expression.'.format( + first_unsized_src, first_unsized_src_bit_size, src, + src_bit_size, val)) else: - if val.common_class != 0: - assert val.bit_size == 0 or val.bit_size == val.common_class + if self.is_search: + self.unify_bit_size(src, src_type_bits, + lambda src_bit_size, unused: + '{} must have {} bits, but as a source of nir_op_{} '\ + 'it must have {} bits'.format( + src, src_bit_size, nir_op.name, src_type_bits)) + else: + self.unify_bit_size(src, src_type_bits, + lambda src_bit_size, unused: + '{} has the bit size of {}, but as a source of ' \ + 'nir_op_{} it must have {} bits, which may not be the ' \ + 'same'.format( + src, src_bit_size, nir_op.name, src_type_bits)) + + if dst_type_bits == 0: + if first_unsized_src is not None: + if self.is_search: + self.unify_bit_size(val, first_unsized_src, + lambda val_bit_size, src_bit_size: + '{} must have the bit size of {}, while its source {} ' \ + 'must have incompatible bit size {}'.format( + val, val_bit_size, first_unsized_src, src_bit_size)) else: - val.common_class = val.bit_size - return val.common_class + self.unify_bit_size(val, first_unsized_src, + lambda val_bit_size, src_bit_size: + '{} must have {} bits, but its source {} ' \ + '(bit size of {}) may not have that bit size ' \ + 'when building the replacement.'.format( + val, val_bit_size, first_unsized_src, src_bit_size)) + else: + self.unify_bit_size(val, dst_type_bits, + lambda dst_bit_size, unused: + '{} must have {} bits, but as a destination of nir_op_{} ' \ + 'it must have {} bits'.format( + val, dst_bit_size, nir_op.name, dst_type_bits)) + + def validate_replace(self, val, search): + bit_size = val.get_bit_size() + assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \ + bit_size == search.get_bit_size(), \ + 'Ambiguous bit size for replacement value {}: ' \ + 'it cannot be deduced from a variable, a fixed bit size ' \ + 'somewhere, or the search expression.'.format(val) + + if isinstance(val, Expression): + for src in val.sources: + self.validate_replace(src, search) - def _validate_bit_class_down(self, val, bit_class): - # At this point, everything *must* have a bit class. Otherwise, we have - # a value we don't know how to define. - assert bit_class != 0 + def validate(self, search, replace): + self.is_search = True + self.merge_variables(search) + self.merge_variables(replace) + self.validate_value(search) - if isinstance(val, Constant): - assert val.bit_size == 0 or val.bit_size == bit_class + self.is_search = False + self.validate_value(replace) - elif isinstance(val, Variable): - assert val.bit_size == 0 or val.bit_size == bit_class + # Check that search is always more specialized than replace. Note that + # we're doing this in replace mode, disallowing merging variables. + search_bit_size = search.get_bit_size() + replace_bit_size = replace.get_bit_size() + cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size) - elif isinstance(val, Expression): - nir_op = opcodes[val.opcode] - dst_type_bits = type_bits(nir_op.output_type) - if dst_type_bits != 0: - assert bit_class == dst_type_bits - else: - assert val.common_class == 0 or val.common_class == bit_class - val.common_class = bit_class + assert cmp_result is not None and cmp_result <= 0, \ + 'The search expression bit size {} and replace expression ' \ + 'bit size {} may not be the same'.format( + search_bit_size, replace_bit_size) - for i in range(nir_op.num_inputs): - src_type_bits = type_bits(nir_op.input_types[i]) - if src_type_bits != 0: - self._validate_bit_class_down(val.sources[i], src_type_bits) - else: - self._validate_bit_class_down(val.sources[i], val.common_class) + replace.set_bit_size(search) + + self.validate_replace(replace, search) _optimization_ids = itertools.count() diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index 0270302fd3d..a41fca876d5 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -118,7 +118,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, new_swizzle[i] = instr->src[src].swizzle[swizzle[i]]; /* If the value has a specific bit size and it doesn't match, bail */ - if (value->bit_size && + if (value->bit_size > 0 && nir_src_bit_size(instr->src[src].src) != value->bit_size) return false; @@ -228,7 +228,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, assert(instr->dest.dest.is_ssa); - if (expr->value.bit_size && + if (expr->value.bit_size > 0 && instr->dest.dest.ssa.bit_size != expr->value.bit_size) return false; @@ -290,128 +290,21 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, } } -typedef struct bitsize_tree { - unsigned num_srcs; - struct bitsize_tree *srcs[4]; - - unsigned common_size; - bool is_src_sized[4]; - bool is_dest_sized; - - unsigned dest_size; - unsigned src_size[4]; -} bitsize_tree; - -static bitsize_tree * -build_bitsize_tree(void *mem_ctx, struct match_state *state, - const nir_search_value *value) -{ - bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree); - - switch (value->type) { - case nir_search_value_expression: { - nir_search_expression *expr = nir_search_value_as_expression(value); - nir_op_info info = nir_op_infos[expr->opcode]; - tree->num_srcs = info.num_inputs; - tree->common_size = 0; - for (unsigned i = 0; i < info.num_inputs; i++) { - tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]); - if (tree->is_src_sized[i]) - tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]); - tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]); - } - tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type); - if (tree->is_dest_sized) - tree->dest_size = nir_alu_type_get_type_size(info.output_type); - break; - } - - case nir_search_value_variable: { - nir_search_variable *var = nir_search_value_as_variable(value); - tree->num_srcs = 0; - tree->is_dest_sized = true; - tree->dest_size = nir_src_bit_size(state->variables[var->variable].src); - break; - } - - case nir_search_value_constant: { - tree->num_srcs = 0; - tree->is_dest_sized = false; - tree->common_size = 0; - break; - } - } - - if (value->bit_size) { - assert(!tree->is_dest_sized || tree->dest_size == value->bit_size); - tree->common_size = value->bit_size; - } - - return tree; -} - static unsigned -bitsize_tree_filter_up(bitsize_tree *tree) +replace_bitsize(const nir_search_value *value, unsigned search_bitsize, + struct match_state *state) { - for (unsigned i = 0; i < tree->num_srcs; i++) { - unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]); - if (src_size == 0) - continue; - - if (tree->is_src_sized[i]) { - assert(src_size == tree->src_size[i]); - } else if (tree->common_size != 0) { - assert(src_size == tree->common_size); - tree->src_size[i] = src_size; - } else { - tree->common_size = src_size; - tree->src_size[i] = src_size; - } - } - - if (tree->num_srcs && tree->common_size) { - if (tree->dest_size == 0) - tree->dest_size = tree->common_size; - else if (!tree->is_dest_sized) - assert(tree->dest_size == tree->common_size); - - for (unsigned i = 0; i < tree->num_srcs; i++) { - if (!tree->src_size[i]) - tree->src_size[i] = tree->common_size; - } - } - - return tree->dest_size; -} - -static void -bitsize_tree_filter_down(bitsize_tree *tree, unsigned size) -{ - if (tree->dest_size) - assert(tree->dest_size == size); - else - tree->dest_size = size; - - if (!tree->is_dest_sized) { - if (tree->common_size) - assert(tree->common_size == size); - else - tree->common_size = size; - } - - for (unsigned i = 0; i < tree->num_srcs; i++) { - if (!tree->src_size[i]) { - assert(tree->common_size); - tree->src_size[i] = tree->common_size; - } - bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]); - } + if (value->bit_size > 0) + return value->bit_size; + if (value->bit_size < 0) + return nir_src_bit_size(state->variables[-value->bit_size - 1].src); + return search_bitsize; } static nir_alu_src construct_value(nir_builder *build, const nir_search_value *value, - unsigned num_components, bitsize_tree *bitsize, + unsigned num_components, unsigned search_bitsize, struct match_state *state, nir_instr *instr) { @@ -424,7 +317,7 @@ construct_value(nir_builder *build, nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode); nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, - bitsize->dest_size, NULL); + replace_bitsize(value, search_bitsize, state), NULL); alu->dest.write_mask = (1 << num_components) - 1; alu->dest.saturate = false; @@ -443,7 +336,7 @@ construct_value(nir_builder *build, num_components = nir_op_infos[alu->op].input_sizes[i]; alu->src[i] = construct_value(build, expr->srcs[i], - num_components, bitsize->srcs[i], + num_components, search_bitsize, state, instr); } @@ -472,16 +365,17 @@ construct_value(nir_builder *build, case nir_search_value_constant: { const nir_search_constant *c = nir_search_value_as_constant(value); + unsigned bit_size = replace_bitsize(value, search_bitsize, state); nir_ssa_def *cval; switch (c->type) { case nir_type_float: - cval = nir_imm_floatN_t(build, c->data.d, bitsize->dest_size); + cval = nir_imm_floatN_t(build, c->data.d, bit_size); break; case nir_type_int: case nir_type_uint: - cval = nir_imm_intN_t(build, c->data.i, bitsize->dest_size); + cval = nir_imm_intN_t(build, c->data.i, bit_size); break; case nir_type_bool: @@ -526,16 +420,12 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, swizzle, &state)) return NULL; - void *bitsize_ctx = ralloc_context(NULL); - bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace); - bitsize_tree_filter_up(tree); - bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size); - build->cursor = nir_before_instr(&instr->instr); nir_alu_src val = construct_value(build, replace, instr->dest.dest.ssa.num_components, - tree, &state, &instr->instr); + instr->dest.dest.ssa.bit_size, + &state, &instr->instr); /* Inserting a mov may be unnecessary. However, it's much easier to * simply let copy propagation clean this up than to try to go through @@ -551,7 +441,5 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, */ nir_instr_remove(&instr->instr); - ralloc_free(bitsize_ctx); - return ssa_val; } diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h index df4189ede74..a76f39e0f40 100644 --- a/src/compiler/nir/nir_search.h +++ b/src/compiler/nir/nir_search.h @@ -43,7 +43,22 @@ typedef enum { typedef struct { nir_search_value_type type; - unsigned bit_size; + /** + * Bit size of the value. It is interpreted as follows: + * + * For a search expression: + * - If bit_size > 0, then the value only matches an SSA value with the + * given bit size. + * - If bit_size <= 0, then the value matches any size SSA value. + * + * For a replace expression: + * - If bit_size > 0, then the value is constructed with the given bit size. + * - If bit_size == 0, then the value is constructed with the same bit size + * as the search value. + * - If bit_size < 0, then the value is constructed with the same bit size + * as variable (-bit_size - 1). + */ + int bit_size; } nir_search_value; typedef struct { |