diff options
author | Ian Romanick <[email protected]> | 2016-07-11 11:05:13 -0700 |
---|---|---|
committer | Ian Romanick <[email protected]> | 2016-08-30 16:28:02 -0700 |
commit | 74e335c7623165e8c4b15609ec34c6c4e952995e (patch) | |
tree | be79374e15f3b748da22ea4335284640a4e4211f /src/compiler/glsl/ir_expression_operation.py | |
parent | f81b1c7fa7f5392ea0b950b888db6f023769af52 (diff) |
glsl: Generate code for constant binary expressions that combine vector and scalar operands
v2: 'for (a, b) in d' => 'for a, b in d'. Suggested by Dylan.
Signed-off-by: Ian Romanick <[email protected]>
Reviewed-by: Matt Turner <[email protected]>
Acked-by: Dylan Baker <[email protected]>
Diffstat (limited to 'src/compiler/glsl/ir_expression_operation.py')
-rw-r--r-- | src/compiler/glsl/ir_expression_operation.py | 66 |
1 files changed, 51 insertions, 15 deletions
diff --git a/src/compiler/glsl/ir_expression_operation.py b/src/compiler/glsl/ir_expression_operation.py index d086366cbe3..21d3803a332 100644 --- a/src/compiler/glsl/ir_expression_operation.py +++ b/src/compiler/glsl/ir_expression_operation.py @@ -141,9 +141,32 @@ constant_template2 = mako.template.Template("""\ data.${op.dest_type.union_field}[c] = ${op.get_c_expression(op.source_types)}; break;""") +# This template is for binary operations that can operate on some combination +# of scalar and vector operands. +constant_template_vector_scalar = mako.template.Template("""\ + case ${op.get_enum_name()}: + assert(op[0]->type == op[1]->type || op0_scalar || op1_scalar); + for (unsigned c = 0, c0 = 0, c1 = 0; + c < components; + c0 += c0_inc, c1 += c1_inc, c++) { + + switch (op[0]->type->base_type) { + % for dst_type, src_types in op.signatures(): + case ${src_types[0].glsl_type}: + data.${dst_type.union_field}[c] = ${op.get_c_expression(src_types, ("c0", "c1"))}; + break; + % endfor + default: + assert(0); + } + } + break;""") + + +vector_scalar_operation = "vector-scalar" class operation(object): - def __init__(self, name, num_operands, printable_name = None, source_types = None, dest_type = None, c_expression = None): + def __init__(self, name, num_operands, printable_name = None, source_types = None, dest_type = None, c_expression = None, flags = None): self.name = name self.num_operands = num_operands @@ -162,6 +185,13 @@ class operation(object): else: self.c_expression = c_expression + if flags is None: + self.flags = frozenset() + elif isinstance(flags, str): + self.flags = frozenset([flags]) + else: + self.flags = frozenset(flags) + def get_enum_name(self): return "ir_{}op_{}".format(("un", "bin", "tri", "quad")[self.num_operands-1], self.name) @@ -181,19 +211,22 @@ class operation(object): else: return constant_template3.render(op=self) elif self.num_operands == 2: - if len(self.source_types) == 1: + if vector_scalar_operation in self.flags: + return constant_template_vector_scalar.render(op=self) + elif len(self.source_types) == 1: return constant_template0.render(op=self) return None - def get_c_expression(self, types): - src0 = "op[0]->value.{}[c]".format(types[0].union_field) - src1 = "op[1]->value.{}[c]".format(types[1].union_field) if len(types) >= 2 else "ERROR" + def get_c_expression(self, types, indices=("c", "c")): + src0 = "op[0]->value.{}[{}]".format(types[0].union_field, indices[0]) + src1 = "op[1]->value.{}[{}]".format(types[1].union_field, indices[1]) if len(types) >= 2 else "ERROR" expr = self.c_expression[types[0].union_field] if types[0].union_field in self.c_expression else self.c_expression['default'] - return expr.format(src0=src0) + return expr.format(src0=src0, + src1=src1) def signatures(self): @@ -329,12 +362,12 @@ ir_expression_operation = [ operation("vote_all", 1), operation("vote_eq", 1), - operation("add", 2, printable_name="+"), - operation("sub", 2, printable_name="-"), + operation("add", 2, printable_name="+", source_types=numeric_types, c_expression="{src0} + {src1}", flags=vector_scalar_operation), + operation("sub", 2, printable_name="-", source_types=numeric_types, c_expression="{src0} - {src1}", flags=vector_scalar_operation), # "Floating-point or low 32-bit integer multiply." operation("mul", 2, printable_name="*"), operation("imul_high", 2), # Calculates the high 32-bits of a 64-bit multiply. - operation("div", 2, printable_name="/"), + operation("div", 2, printable_name="/", source_types=numeric_types, c_expression={'u': "{src1} == 0 ? 0 : {src0} / {src1}", 'i': "{src1} == 0 ? 0 : {src0} / {src1}", 'default': "{src0} / {src1}"}, flags=vector_scalar_operation), # Returns the carry resulting from the addition of the two arguments. operation("carry", 2), @@ -344,7 +377,10 @@ ir_expression_operation = [ operation("borrow", 2), # Either (vector % vector) or (vector % scalar) - operation("mod", 2, printable_name="%"), + # + # We don't use fmod because it rounds toward zero; GLSL specifies the use + # of floor. + operation("mod", 2, printable_name="%", source_types=numeric_types, c_expression={'u': "{src1} == 0 ? 0 : {src0} % {src1}", 'i': "{src1} == 0 ? 0 : {src0} % {src1}", 'f': "{src0} - {src1} * floorf({src0} / {src1})", 'd': "{src0} - {src1} * floor({src0} / {src1})"}, flags=vector_scalar_operation), # Binary comparison operators which return a boolean vector. # The type of both operands must be equal. @@ -366,17 +402,17 @@ ir_expression_operation = [ # Bit-wise binary operations. operation("lshift", 2, printable_name="<<"), operation("rshift", 2, printable_name=">>"), - operation("bit_and", 2, printable_name="&"), - operation("bit_xor", 2, printable_name="^"), - operation("bit_or", 2, printable_name="|"), + operation("bit_and", 2, printable_name="&", source_types=integer_types, c_expression="{src0} & {src1}", flags=vector_scalar_operation), + operation("bit_xor", 2, printable_name="^", source_types=integer_types, c_expression="{src0} ^ {src1}", flags=vector_scalar_operation), + operation("bit_or", 2, printable_name="|", source_types=integer_types, c_expression="{src0} | {src1}", flags=vector_scalar_operation), operation("logic_and", 2, printable_name="&&", source_types=(bool_type,), c_expression="{src0} && {src1}"), operation("logic_xor", 2, printable_name="^^", source_types=(bool_type,), c_expression="{src0} != {src1}"), operation("logic_or", 2, printable_name="||", source_types=(bool_type,), c_expression="{src0} || {src1}"), operation("dot", 2), - operation("min", 2), - operation("max", 2), + operation("min", 2, source_types=numeric_types, c_expression="MIN2({src0}, {src1})", flags=vector_scalar_operation), + operation("max", 2, source_types=numeric_types, c_expression="MAX2({src0}, {src1})", flags=vector_scalar_operation), operation("pow", 2, source_types=(float_type,), c_expression="powf({src0}, {src1})"), |