summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorJason Ekstrand <[email protected]>2017-03-08 20:23:05 -0800
committerJason Ekstrand <[email protected]>2017-03-14 07:36:40 -0700
commit9d559ba39dca49c30cdfc81e8fdfbefb06a05f2a (patch)
tree44110ce766a7f477a19125e7045dd02f9624af62 /src
parent762a6333f21fd8606f69db6060027c4522d46678 (diff)
nir/constant_expressions: Refactor helper functions
Apart from avoiding some unneeded size cases, this shouldn't have any actual functional impact. Reviewed-by: Dylan Baker <[email protected]> Reviewed-by: Lionel Landwerlin <[email protected]>
Diffstat (limited to 'src')
-rw-r--r--src/compiler/nir/nir_constant_expressions.py51
1 files changed, 27 insertions, 24 deletions
diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py
index 3da20fd503b..c6745f1e934 100644
--- a/src/compiler/nir/nir_constant_expressions.py
+++ b/src/compiler/nir/nir_constant_expressions.py
@@ -1,16 +1,18 @@
+import re
+
+type_split_re = re.compile(r'(?P<type>[a-z]+)(?P<bits>\d+)')
+
def type_has_size(type_):
return type_[-1:].isdigit()
+def type_size(type_):
+ assert type_has_size(type_)
+ return int(type_split_re.match(type_).group('bits'))
+
def type_sizes(type_):
- if type_.endswith("8"):
- return [8]
- elif type_.endswith("16"):
- return [16]
- elif type_.endswith("32"):
- return [32]
- elif type_.endswith("64"):
- return [64]
+ if type_has_size(type_):
+ return [type_size(type_)]
else:
return [32, 64]
@@ -19,23 +21,23 @@ def type_add_size(type_, size):
return type_
return type_ + str(size)
+def op_bit_sizes(op):
+ sizes = set([8, 16, 32, 64])
+ if not type_has_size(op.output_type):
+ sizes = sizes.intersection(set(type_sizes(op.output_type)))
+ for input_type in op.input_types:
+ if not type_has_size(input_type):
+ sizes = sizes.intersection(set(type_sizes(input_type)))
+ return sorted(list(sizes))
+
def get_const_field(type_):
- if type_ == "int32":
- return "i32"
- if type_ == "uint32":
- return "u32"
- if type_ == "int64":
- return "i64"
- if type_ == "uint64":
- return "u64"
if type_ == "bool32":
return "u32"
- if type_ == "float32":
- return "f32"
- if type_ == "float64":
- return "f64"
- raise Exception(str(type_))
- assert(0)
+ else:
+ m = type_split_re.match(type_)
+ if not m:
+ raise Exception(str(type_))
+ return m.group('type')[0] + m.group('bits')
template = """\
/*
@@ -247,7 +249,7 @@ typedef float float32_t;
typedef double float64_t;
typedef bool bool32_t;
% for type in ["float", "int", "uint"]:
-% for width in [32, 64]:
+% for width in type_sizes(type):
struct ${type}${width}_vec {
${type}${width}_t x;
${type}${width}_t y;
@@ -272,7 +274,7 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size,
nir_const_value _dst_val = { {0, } };
switch (bit_size) {
- % for bit_size in [32, 64]:
+ % for bit_size in op_bit_sizes(op):
case ${bit_size}: {
<%
output_type = type_add_size(op.output_type, bit_size)
@@ -406,4 +408,5 @@ from mako.template import Template
print Template(template).render(opcodes=opcodes, type_sizes=type_sizes,
type_has_size=type_has_size,
type_add_size=type_add_size,
+ op_bit_sizes=op_bit_sizes,
get_const_field=get_const_field)