diff options
Diffstat (limited to 'src/compiler/nir/nir_algebraic.py')
-rw-r--r-- | src/compiler/nir/nir_algebraic.py | 47 |
1 files changed, 44 insertions, 3 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index b90264b282e..66ee0ad6402 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -33,7 +33,19 @@ import mako.template import re import traceback -from nir_opcodes import opcodes +from nir_opcodes import opcodes, type_sizes + +# These opcodes are only employed by nir_search. This provides a mapping from +# opcode to destination type. +conv_opcode_types = { + 'i2f' : 'float', + 'u2f' : 'float', + 'f2f' : 'float', + 'f2u' : 'uint', + 'f2i' : 'int', + 'u2u' : 'uint', + 'i2i' : 'int', +} if sys.version_info < (3, 0): integer_types = (int, long) @@ -98,7 +110,7 @@ static const ${val.c_type} ${val.name} = { ${val.cond if val.cond else 'NULL'}, % elif isinstance(val, Expression): ${'true' if val.inexact else 'false'}, - nir_op_${val.opcode}, + ${val.c_opcode()}, { ${', '.join(src.c_ptr for src in val.sources)} }, ${val.cond if val.cond else 'NULL'}, % endif @@ -276,6 +288,18 @@ class Expression(Value): self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset) for (i, src) in enumerate(expr[1:]) ] + if self.opcode in conv_opcode_types: + assert self._bit_size is None, \ + 'Expression cannot use an unsized conversion opcode with ' \ + 'an explicit size; that\'s silly.' + + + def c_opcode(self): + if self.opcode in conv_opcode_types: + return 'nir_search_op_' + self.opcode + else: + return 'nir_op_' + self.opcode + def render(self): srcs = "\n".join(src.render() for src in self.sources) return srcs + super(Expression, self).render() @@ -462,6 +486,17 @@ class BitSizeValidator(object): if not isinstance(val, Expression): return + # Generic conversion ops are special in that they have a single unsized + # source and an unsized destination and the two don't have to match. + # This means there's no validation or unioning to do here besides the + # len(val.sources) check. + if val.opcode in conv_opcode_types: + assert len(val.sources) == 1, \ + "Expression {} has {} sources, expected 1".format( + val, len(val.sources)) + self.validate_value(val.sources[0]) + return + nir_op = opcodes[val.opcode] assert len(val.sources) == nir_op.num_inputs, \ "Expression {} has {} sources, expected {}".format( @@ -732,7 +767,13 @@ class AlgebraicPass(object): continue self.xforms.append(xform) - self.opcode_xforms[xform.search.opcode].append(xform) + if xform.search.opcode in conv_opcode_types: + dst_type = conv_opcode_types[xform.search.opcode] + for size in type_sizes(dst_type): + sized_opcode = xform.search.opcode + str(size) + self.opcode_xforms[sized_opcode].append(xform) + else: + self.opcode_xforms[xform.search.opcode].append(xform) if error: sys.exit(1) |