diff options
Diffstat (limited to 'src/compiler')
-rw-r--r-- | src/compiler/nir/nir_algebraic.py | 47 | ||||
-rw-r--r-- | src/compiler/nir/nir_search.c | 90 | ||||
-rw-r--r-- | src/compiler/nir/nir_search.h | 13 |
3 files changed, 140 insertions, 10 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index b90264b282e..66ee0ad6402 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -33,7 +33,19 @@ import mako.template import re import traceback -from nir_opcodes import opcodes +from nir_opcodes import opcodes, type_sizes + +# These opcodes are only employed by nir_search. This provides a mapping from +# opcode to destination type. +conv_opcode_types = { + 'i2f' : 'float', + 'u2f' : 'float', + 'f2f' : 'float', + 'f2u' : 'uint', + 'f2i' : 'int', + 'u2u' : 'uint', + 'i2i' : 'int', +} if sys.version_info < (3, 0): integer_types = (int, long) @@ -98,7 +110,7 @@ static const ${val.c_type} ${val.name} = { ${val.cond if val.cond else 'NULL'}, % elif isinstance(val, Expression): ${'true' if val.inexact else 'false'}, - nir_op_${val.opcode}, + ${val.c_opcode()}, { ${', '.join(src.c_ptr for src in val.sources)} }, ${val.cond if val.cond else 'NULL'}, % endif @@ -276,6 +288,18 @@ class Expression(Value): self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset) for (i, src) in enumerate(expr[1:]) ] + if self.opcode in conv_opcode_types: + assert self._bit_size is None, \ + 'Expression cannot use an unsized conversion opcode with ' \ + 'an explicit size; that\'s silly.' + + + def c_opcode(self): + if self.opcode in conv_opcode_types: + return 'nir_search_op_' + self.opcode + else: + return 'nir_op_' + self.opcode + def render(self): srcs = "\n".join(src.render() for src in self.sources) return srcs + super(Expression, self).render() @@ -462,6 +486,17 @@ class BitSizeValidator(object): if not isinstance(val, Expression): return + # Generic conversion ops are special in that they have a single unsized + # source and an unsized destination and the two don't have to match. + # This means there's no validation or unioning to do here besides the + # len(val.sources) check. + if val.opcode in conv_opcode_types: + assert len(val.sources) == 1, \ + "Expression {} has {} sources, expected 1".format( + val, len(val.sources)) + self.validate_value(val.sources[0]) + return + nir_op = opcodes[val.opcode] assert len(val.sources) == nir_op.num_inputs, \ "Expression {} has {} sources, expected {}".format( @@ -732,7 +767,13 @@ class AlgebraicPass(object): continue self.xforms.append(xform) - self.opcode_xforms[xform.search.opcode].append(xform) + if xform.search.opcode in conv_opcode_types: + dst_type = conv_opcode_types[xform.search.opcode] + for size in type_sizes(dst_type): + sized_opcode = xform.search.opcode + str(size) + self.opcode_xforms[sized_opcode].append(xform) + else: + self.opcode_xforms[xform.search.opcode].append(xform) if error: sys.exit(1) diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index a41fca876d5..f5fc92ec33c 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -90,6 +90,82 @@ src_is_type(nir_src src, nir_alu_type type) } static bool +nir_op_matches_search_op(nir_op nop, uint16_t sop) +{ + if (sop <= nir_last_opcode) + return nop == sop; + +#define MATCH_FCONV_CASE(op) \ + case nir_search_op_##op: \ + return nop == nir_op_##op##16 || \ + nop == nir_op_##op##32 || \ + nop == nir_op_##op##64; + +#define MATCH_ICONV_CASE(op) \ + case nir_search_op_##op: \ + return nop == nir_op_##op##8 || \ + nop == nir_op_##op##16 || \ + nop == nir_op_##op##32 || \ + nop == nir_op_##op##64; + + switch (sop) { + MATCH_FCONV_CASE(i2f) + MATCH_FCONV_CASE(u2f) + MATCH_FCONV_CASE(f2f) + MATCH_ICONV_CASE(f2u) + MATCH_ICONV_CASE(f2i) + MATCH_ICONV_CASE(u2u) + MATCH_ICONV_CASE(i2i) + default: + unreachable("Invalid nir_search_op"); + } + +#undef MATCH_FCONV_CASE +#undef MATCH_ICONV_CASE +} + +static nir_op +nir_op_for_search_op(uint16_t sop, unsigned bit_size) +{ + if (sop <= nir_last_opcode) + return sop; + +#define RET_FCONV_CASE(op) \ + case nir_search_op_##op: \ + switch (bit_size) { \ + case 16: return nir_op_##op##16; \ + case 32: return nir_op_##op##32; \ + case 64: return nir_op_##op##64; \ + default: unreachable("Invalid bit size"); \ + } + +#define RET_ICONV_CASE(op) \ + case nir_search_op_##op: \ + switch (bit_size) { \ + case 8: return nir_op_##op##8; \ + case 16: return nir_op_##op##16; \ + case 32: return nir_op_##op##32; \ + case 64: return nir_op_##op##64; \ + default: unreachable("Invalid bit size"); \ + } + + switch (sop) { + RET_FCONV_CASE(i2f) + RET_FCONV_CASE(u2f) + RET_FCONV_CASE(f2f) + RET_ICONV_CASE(f2u) + RET_ICONV_CASE(f2i) + RET_ICONV_CASE(u2u) + RET_ICONV_CASE(i2i) + default: + unreachable("Invalid nir_search_op"); + } + +#undef RET_FCONV_CASE +#undef RET_ICONV_CASE +} + +static bool match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, unsigned num_components, const uint8_t *swizzle, struct match_state *state) @@ -223,7 +299,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, if (expr->cond && !expr->cond(instr)) return false; - if (instr->op != expr->opcode) + if (!nir_op_matches_search_op(instr->op, expr->opcode)) return false; assert(instr->dest.dest.is_ssa); @@ -311,13 +387,15 @@ construct_value(nir_builder *build, switch (value->type) { case nir_search_value_expression: { const nir_search_expression *expr = nir_search_value_as_expression(value); + unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state); + nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size); - if (nir_op_infos[expr->opcode].output_size != 0) - num_components = nir_op_infos[expr->opcode].output_size; + if (nir_op_infos[op].output_size != 0) + num_components = nir_op_infos[op].output_size; - nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode); + nir_alu_instr *alu = nir_alu_instr_create(build->shader, op); nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, - replace_bitsize(value, search_bitsize, state), NULL); + dst_bit_size, NULL); alu->dest.write_mask = (1 << num_components) - 1; alu->dest.saturate = false; @@ -328,7 +406,7 @@ construct_value(nir_builder *build, */ alu->exact = state->has_exact_alu; - for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) { + for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) { /* If the source is an explicitly sized source, then we need to reset * the number of components to match. */ diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h index a76f39e0f40..cd55bbd0173 100644 --- a/src/compiler/nir/nir_search.h +++ b/src/compiler/nir/nir_search.h @@ -109,6 +109,16 @@ typedef struct { } data; } nir_search_constant; +enum nir_search_op { + nir_search_op_i2f = nir_last_opcode + 1, + nir_search_op_u2f, + nir_search_op_f2f, + nir_search_op_f2u, + nir_search_op_f2i, + nir_search_op_u2u, + nir_search_op_i2i, +}; + typedef struct { nir_search_value value; @@ -118,7 +128,8 @@ typedef struct { */ bool inexact; - nir_op opcode; + /* One of nir_op or nir_search_op */ + uint16_t opcode; const nir_search_value *srcs[4]; /** Optional condition fxn ptr |