aboutsummaryrefslogtreecommitdiffstats
path: root/src/compiler/nir/nir_algebraic.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler/nir/nir_algebraic.py')
-rw-r--r--src/compiler/nir/nir_algebraic.py47
1 files changed, 44 insertions, 3 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index b90264b282e..66ee0ad6402 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -33,7 +33,19 @@ import mako.template
import re
import traceback
-from nir_opcodes import opcodes
+from nir_opcodes import opcodes, type_sizes
+
+# These opcodes are only employed by nir_search. This provides a mapping from
+# opcode to destination type.
+conv_opcode_types = {
+ 'i2f' : 'float',
+ 'u2f' : 'float',
+ 'f2f' : 'float',
+ 'f2u' : 'uint',
+ 'f2i' : 'int',
+ 'u2u' : 'uint',
+ 'i2i' : 'int',
+}
if sys.version_info < (3, 0):
integer_types = (int, long)
@@ -98,7 +110,7 @@ static const ${val.c_type} ${val.name} = {
${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
${'true' if val.inexact else 'false'},
- nir_op_${val.opcode},
+ ${val.c_opcode()},
{ ${', '.join(src.c_ptr for src in val.sources)} },
${val.cond if val.cond else 'NULL'},
% endif
@@ -276,6 +288,18 @@ class Expression(Value):
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
for (i, src) in enumerate(expr[1:]) ]
+ if self.opcode in conv_opcode_types:
+ assert self._bit_size is None, \
+ 'Expression cannot use an unsized conversion opcode with ' \
+ 'an explicit size; that\'s silly.'
+
+
+ def c_opcode(self):
+ if self.opcode in conv_opcode_types:
+ return 'nir_search_op_' + self.opcode
+ else:
+ return 'nir_op_' + self.opcode
+
def render(self):
srcs = "\n".join(src.render() for src in self.sources)
return srcs + super(Expression, self).render()
@@ -462,6 +486,17 @@ class BitSizeValidator(object):
if not isinstance(val, Expression):
return
+ # Generic conversion ops are special in that they have a single unsized
+ # source and an unsized destination and the two don't have to match.
+ # This means there's no validation or unioning to do here besides the
+ # len(val.sources) check.
+ if val.opcode in conv_opcode_types:
+ assert len(val.sources) == 1, \
+ "Expression {} has {} sources, expected 1".format(
+ val, len(val.sources))
+ self.validate_value(val.sources[0])
+ return
+
nir_op = opcodes[val.opcode]
assert len(val.sources) == nir_op.num_inputs, \
"Expression {} has {} sources, expected {}".format(
@@ -732,7 +767,13 @@ class AlgebraicPass(object):
continue
self.xforms.append(xform)
- self.opcode_xforms[xform.search.opcode].append(xform)
+ if xform.search.opcode in conv_opcode_types:
+ dst_type = conv_opcode_types[xform.search.opcode]
+ for size in type_sizes(dst_type):
+ sized_opcode = xform.search.opcode + str(size)
+ self.opcode_xforms[sized_opcode].append(xform)
+ else:
+ self.opcode_xforms[xform.search.opcode].append(xform)
if error:
sys.exit(1)