summaryrefslogtreecommitdiffstats
path: root/src/compiler/nir/nir_constant_expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler/nir/nir_constant_expressions.py')
-rw-r--r--src/compiler/nir/nir_constant_expressions.py51
1 files changed, 27 insertions, 24 deletions
diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py
index 3da20fd503b..c6745f1e934 100644
--- a/src/compiler/nir/nir_constant_expressions.py
+++ b/src/compiler/nir/nir_constant_expressions.py
@@ -1,16 +1,18 @@
+import re
+
+type_split_re = re.compile(r'(?P<type>[a-z]+)(?P<bits>\d+)')
+
def type_has_size(type_):
return type_[-1:].isdigit()
+def type_size(type_):
+ assert type_has_size(type_)
+ return int(type_split_re.match(type_).group('bits'))
+
def type_sizes(type_):
- if type_.endswith("8"):
- return [8]
- elif type_.endswith("16"):
- return [16]
- elif type_.endswith("32"):
- return [32]
- elif type_.endswith("64"):
- return [64]
+ if type_has_size(type_):
+ return [type_size(type_)]
else:
return [32, 64]
@@ -19,23 +21,23 @@ def type_add_size(type_, size):
return type_
return type_ + str(size)
+def op_bit_sizes(op):
+ sizes = set([8, 16, 32, 64])
+ if not type_has_size(op.output_type):
+ sizes = sizes.intersection(set(type_sizes(op.output_type)))
+ for input_type in op.input_types:
+ if not type_has_size(input_type):
+ sizes = sizes.intersection(set(type_sizes(input_type)))
+ return sorted(list(sizes))
+
def get_const_field(type_):
- if type_ == "int32":
- return "i32"
- if type_ == "uint32":
- return "u32"
- if type_ == "int64":
- return "i64"
- if type_ == "uint64":
- return "u64"
if type_ == "bool32":
return "u32"
- if type_ == "float32":
- return "f32"
- if type_ == "float64":
- return "f64"
- raise Exception(str(type_))
- assert(0)
+ else:
+ m = type_split_re.match(type_)
+ if not m:
+ raise Exception(str(type_))
+ return m.group('type')[0] + m.group('bits')
template = """\
/*
@@ -247,7 +249,7 @@ typedef float float32_t;
typedef double float64_t;
typedef bool bool32_t;
% for type in ["float", "int", "uint"]:
-% for width in [32, 64]:
+% for width in type_sizes(type):
struct ${type}${width}_vec {
${type}${width}_t x;
${type}${width}_t y;
@@ -272,7 +274,7 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size,
nir_const_value _dst_val = { {0, } };
switch (bit_size) {
- % for bit_size in [32, 64]:
+ % for bit_size in op_bit_sizes(op):
case ${bit_size}: {
<%
output_type = type_add_size(op.output_type, bit_size)
@@ -406,4 +408,5 @@ from mako.template import Template
print Template(template).render(opcodes=opcodes, type_sizes=type_sizes,
type_has_size=type_has_size,
type_add_size=type_add_size,
+ op_bit_sizes=op_bit_sizes,
get_const_field=get_const_field)