diff options
author | Connor Abbott <[email protected]> | 2018-11-23 17:34:19 +0100 |
---|---|---|
committer | Connor Abbott <[email protected]> | 2018-12-05 17:57:40 +0100 |
commit | 29a1450e288c57727a5cfe22fa4463a53f9cc8bf (patch) | |
tree | b04dfdcd888f0f6ee7d73422049135db818e08d0 /src/compiler/nir/nir_search.c | |
parent | 49ef89073337ad3c3aefd47592148e6bef0b5ae3 (diff) |
nir/algebraic: Rewrite bit-size inference
Before this commit, there were two copies of the algorithm: one in C,
which we used to figure out what bit-size to give the replacement
expression, and one in Python, which emulated the C one and tried to
prove that the C algorithm would never fail to correctly assign
bit-sizes. That seemed pretty fragile, and likely to fall over if we
make any changes. Furthermore, the C code was really just recomputing
more-or-less the same thing as the Python code every time. Instead, we
can just store the results of the Python algorithm in the C
data structure, and consult it to compute the bit-size of each value,
moving the "brains" entirely into Python. Since the Python algorithm no
longer has to match C, it's also a lot easier to change it to something
more closely approximating an actual type-inference algorithm. The
algorithm used is based on Hindley-Milner, although deliberately
weakened a little. It's a few more lines than the old one, judging by
the diffstat, but I think it's easier to verify that it's correct while
being as general as possible.
We could split this up into two changes, first making the C code use the
results of the Python code and then rewriting the Python algorithm, but
since the old algorithm never tracked which variable each equivalence
class corresponds to, it would mean we'd have to add some non-trivial code which would
then get thrown away. I think it's better to see the final state all at
once, although I could also try splitting it up.
v2:
- Replace instances of "== None" and "!= None" with "is None" and
"is not None".
- Rename first_src to first_unsized_src
- Only merge the destination with the first unsized source, since the
sources have already been merged.
- Add a comment explaining what nir_search_value::bit_size now means.
v3:
- Fix one last instance to use "is not" instead of !=
- Don't try to be so clever when choosing which error message to print
based on whether we're in the search or replace expression.
- Fix trailing whitespace.
Reviewed-by: Jason Ekstrand <[email protected]>
Reviewed-by: Dylan Baker <[email protected]>
Diffstat (limited to 'src/compiler/nir/nir_search.c')
-rw-r--r-- | src/compiler/nir/nir_search.c | 146 |
1 files changed, 17 insertions, 129 deletions
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index 0270302fd3d..a41fca876d5 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -118,7 +118,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, new_swizzle[i] = instr->src[src].swizzle[swizzle[i]]; /* If the value has a specific bit size and it doesn't match, bail */ - if (value->bit_size && + if (value->bit_size > 0 && nir_src_bit_size(instr->src[src].src) != value->bit_size) return false; @@ -228,7 +228,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, assert(instr->dest.dest.is_ssa); - if (expr->value.bit_size && + if (expr->value.bit_size > 0 && instr->dest.dest.ssa.bit_size != expr->value.bit_size) return false; @@ -290,128 +290,21 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, } } -typedef struct bitsize_tree { - unsigned num_srcs; - struct bitsize_tree *srcs[4]; - - unsigned common_size; - bool is_src_sized[4]; - bool is_dest_sized; - - unsigned dest_size; - unsigned src_size[4]; -} bitsize_tree; - -static bitsize_tree * -build_bitsize_tree(void *mem_ctx, struct match_state *state, - const nir_search_value *value) -{ - bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree); - - switch (value->type) { - case nir_search_value_expression: { - nir_search_expression *expr = nir_search_value_as_expression(value); - nir_op_info info = nir_op_infos[expr->opcode]; - tree->num_srcs = info.num_inputs; - tree->common_size = 0; - for (unsigned i = 0; i < info.num_inputs; i++) { - tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]); - if (tree->is_src_sized[i]) - tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]); - tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]); - } - tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type); - if (tree->is_dest_sized) - tree->dest_size = 
nir_alu_type_get_type_size(info.output_type); - break; - } - - case nir_search_value_variable: { - nir_search_variable *var = nir_search_value_as_variable(value); - tree->num_srcs = 0; - tree->is_dest_sized = true; - tree->dest_size = nir_src_bit_size(state->variables[var->variable].src); - break; - } - - case nir_search_value_constant: { - tree->num_srcs = 0; - tree->is_dest_sized = false; - tree->common_size = 0; - break; - } - } - - if (value->bit_size) { - assert(!tree->is_dest_sized || tree->dest_size == value->bit_size); - tree->common_size = value->bit_size; - } - - return tree; -} - static unsigned -bitsize_tree_filter_up(bitsize_tree *tree) +replace_bitsize(const nir_search_value *value, unsigned search_bitsize, + struct match_state *state) { - for (unsigned i = 0; i < tree->num_srcs; i++) { - unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]); - if (src_size == 0) - continue; - - if (tree->is_src_sized[i]) { - assert(src_size == tree->src_size[i]); - } else if (tree->common_size != 0) { - assert(src_size == tree->common_size); - tree->src_size[i] = src_size; - } else { - tree->common_size = src_size; - tree->src_size[i] = src_size; - } - } - - if (tree->num_srcs && tree->common_size) { - if (tree->dest_size == 0) - tree->dest_size = tree->common_size; - else if (!tree->is_dest_sized) - assert(tree->dest_size == tree->common_size); - - for (unsigned i = 0; i < tree->num_srcs; i++) { - if (!tree->src_size[i]) - tree->src_size[i] = tree->common_size; - } - } - - return tree->dest_size; -} - -static void -bitsize_tree_filter_down(bitsize_tree *tree, unsigned size) -{ - if (tree->dest_size) - assert(tree->dest_size == size); - else - tree->dest_size = size; - - if (!tree->is_dest_sized) { - if (tree->common_size) - assert(tree->common_size == size); - else - tree->common_size = size; - } - - for (unsigned i = 0; i < tree->num_srcs; i++) { - if (!tree->src_size[i]) { - assert(tree->common_size); - tree->src_size[i] = tree->common_size; - } - 
bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]); - } + if (value->bit_size > 0) + return value->bit_size; + if (value->bit_size < 0) + return nir_src_bit_size(state->variables[-value->bit_size - 1].src); + return search_bitsize; } static nir_alu_src construct_value(nir_builder *build, const nir_search_value *value, - unsigned num_components, bitsize_tree *bitsize, + unsigned num_components, unsigned search_bitsize, struct match_state *state, nir_instr *instr) { @@ -424,7 +317,7 @@ construct_value(nir_builder *build, nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode); nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, - bitsize->dest_size, NULL); + replace_bitsize(value, search_bitsize, state), NULL); alu->dest.write_mask = (1 << num_components) - 1; alu->dest.saturate = false; @@ -443,7 +336,7 @@ construct_value(nir_builder *build, num_components = nir_op_infos[alu->op].input_sizes[i]; alu->src[i] = construct_value(build, expr->srcs[i], - num_components, bitsize->srcs[i], + num_components, search_bitsize, state, instr); } @@ -472,16 +365,17 @@ construct_value(nir_builder *build, case nir_search_value_constant: { const nir_search_constant *c = nir_search_value_as_constant(value); + unsigned bit_size = replace_bitsize(value, search_bitsize, state); nir_ssa_def *cval; switch (c->type) { case nir_type_float: - cval = nir_imm_floatN_t(build, c->data.d, bitsize->dest_size); + cval = nir_imm_floatN_t(build, c->data.d, bit_size); break; case nir_type_int: case nir_type_uint: - cval = nir_imm_intN_t(build, c->data.i, bitsize->dest_size); + cval = nir_imm_intN_t(build, c->data.i, bit_size); break; case nir_type_bool: @@ -526,16 +420,12 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, swizzle, &state)) return NULL; - void *bitsize_ctx = ralloc_context(NULL); - bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace); - bitsize_tree_filter_up(tree); - bitsize_tree_filter_down(tree, 
instr->dest.dest.ssa.bit_size); - build->cursor = nir_before_instr(&instr->instr); nir_alu_src val = construct_value(build, replace, instr->dest.dest.ssa.num_components, - tree, &state, &instr->instr); + instr->dest.dest.ssa.bit_size, + &state, &instr->instr); /* Inserting a mov may be unnecessary. However, it's much easier to * simply let copy propagation clean this up than to try to go through @@ -551,7 +441,5 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, */ nir_instr_remove(&instr->instr); - ralloc_free(bitsize_ctx); - return ssa_val; } |