summaryrefslogtreecommitdiffstats
path: root/src/compiler/nir/nir_search.c
diff options
context:
space:
mode:
authorConnor Abbott <[email protected]>2018-11-23 17:34:19 +0100
committerConnor Abbott <[email protected]>2018-12-05 17:57:40 +0100
commit29a1450e288c57727a5cfe22fa4463a53f9cc8bf (patch)
treeb04dfdcd888f0f6ee7d73422049135db818e08d0 /src/compiler/nir/nir_search.c
parent49ef89073337ad3c3aefd47592148e6bef0b5ae3 (diff)
nir/algebraic: Rewrite bit-size inference
Before this commit, there were two copies of the algorithm: one in C, that we would use to figure out what bit-size to give the replacement expression, and one in Python, that emulated the C one and tried to prove that the C algorithm would never fail to correctly assign bit-sizes. That seemed pretty fragile, and likely to fall over if we make any changes. Furthermore, the C code was really just recomputing more-or-less the same thing as the Python code every time. Instead, we can just store the results of the Python algorithm in the C datastructure, and consult it to compute the bitsize of each value, moving the "brains" entirely into Python. Since the Python algorithm no longer has to match C, it's also a lot easier to change it to something more closely approximating an actual type-inference algorithm. The algorithm used is based on Hindley-Milner, although deliberately weakened a little. It's a few more lines than the old one, judging by the diffstat, but I think it's easier to verify that it's correct while being as general as possible. We could split this up into two changes, first making the C code use the results of the Python code and then rewriting the Python algorithm, but since the old algorithm never tracked which variable each equivalence class, it would mean we'd have to add some non-trivial code which would then get thrown away. I think it's better to see the final state all at once, although I could also try splitting it up. v2: - Replace instances of "== None" and "!= None" with "is None" and "is not None". - Rename first_src to first_unsized_src - Only merge the destination with the first unsized source, since the sources have already been merged. - Add a comment explaining what nir_search_value::bit_size now means. v3: - Fix one last instance to use "is not" instead of != - Don't try to be so clever when choosing which error message to print based on whether we're in the search or replace expression. - Fix trailing whitespace. Reviewed-by: Jason Ekstrand <[email protected]> Reviewed-by: Dylan Baker <[email protected]>
Diffstat (limited to 'src/compiler/nir/nir_search.c')
-rw-r--r--src/compiler/nir/nir_search.c146
1 files changed, 17 insertions, 129 deletions
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 0270302fd3d..a41fca876d5 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -118,7 +118,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
/* If the value has a specific bit size and it doesn't match, bail */
- if (value->bit_size &&
+ if (value->bit_size > 0 &&
nir_src_bit_size(instr->src[src].src) != value->bit_size)
return false;
@@ -228,7 +228,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
assert(instr->dest.dest.is_ssa);
- if (expr->value.bit_size &&
+ if (expr->value.bit_size > 0 &&
instr->dest.dest.ssa.bit_size != expr->value.bit_size)
return false;
@@ -290,128 +290,21 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
}
}
-typedef struct bitsize_tree {
- unsigned num_srcs;
- struct bitsize_tree *srcs[4];
-
- unsigned common_size;
- bool is_src_sized[4];
- bool is_dest_sized;
-
- unsigned dest_size;
- unsigned src_size[4];
-} bitsize_tree;
-
-static bitsize_tree *
-build_bitsize_tree(void *mem_ctx, struct match_state *state,
- const nir_search_value *value)
-{
- bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
-
- switch (value->type) {
- case nir_search_value_expression: {
- nir_search_expression *expr = nir_search_value_as_expression(value);
- nir_op_info info = nir_op_infos[expr->opcode];
- tree->num_srcs = info.num_inputs;
- tree->common_size = 0;
- for (unsigned i = 0; i < info.num_inputs; i++) {
- tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
- if (tree->is_src_sized[i])
- tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
- tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
- }
- tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
- if (tree->is_dest_sized)
- tree->dest_size = nir_alu_type_get_type_size(info.output_type);
- break;
- }
-
- case nir_search_value_variable: {
- nir_search_variable *var = nir_search_value_as_variable(value);
- tree->num_srcs = 0;
- tree->is_dest_sized = true;
- tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
- break;
- }
-
- case nir_search_value_constant: {
- tree->num_srcs = 0;
- tree->is_dest_sized = false;
- tree->common_size = 0;
- break;
- }
- }
-
- if (value->bit_size) {
- assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
- tree->common_size = value->bit_size;
- }
-
- return tree;
-}
-
static unsigned
-bitsize_tree_filter_up(bitsize_tree *tree)
+replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
+ struct match_state *state)
{
- for (unsigned i = 0; i < tree->num_srcs; i++) {
- unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
- if (src_size == 0)
- continue;
-
- if (tree->is_src_sized[i]) {
- assert(src_size == tree->src_size[i]);
- } else if (tree->common_size != 0) {
- assert(src_size == tree->common_size);
- tree->src_size[i] = src_size;
- } else {
- tree->common_size = src_size;
- tree->src_size[i] = src_size;
- }
- }
-
- if (tree->num_srcs && tree->common_size) {
- if (tree->dest_size == 0)
- tree->dest_size = tree->common_size;
- else if (!tree->is_dest_sized)
- assert(tree->dest_size == tree->common_size);
-
- for (unsigned i = 0; i < tree->num_srcs; i++) {
- if (!tree->src_size[i])
- tree->src_size[i] = tree->common_size;
- }
- }
-
- return tree->dest_size;
-}
-
-static void
-bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
-{
- if (tree->dest_size)
- assert(tree->dest_size == size);
- else
- tree->dest_size = size;
-
- if (!tree->is_dest_sized) {
- if (tree->common_size)
- assert(tree->common_size == size);
- else
- tree->common_size = size;
- }
-
- for (unsigned i = 0; i < tree->num_srcs; i++) {
- if (!tree->src_size[i]) {
- assert(tree->common_size);
- tree->src_size[i] = tree->common_size;
- }
- bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
- }
+ if (value->bit_size > 0)
+ return value->bit_size;
+ if (value->bit_size < 0)
+ return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
+ return search_bitsize;
}
static nir_alu_src
construct_value(nir_builder *build,
const nir_search_value *value,
- unsigned num_components, bitsize_tree *bitsize,
+ unsigned num_components, unsigned search_bitsize,
struct match_state *state,
nir_instr *instr)
{
@@ -424,7 +317,7 @@ construct_value(nir_builder *build,
nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode);
nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
- bitsize->dest_size, NULL);
+ replace_bitsize(value, search_bitsize, state), NULL);
alu->dest.write_mask = (1 << num_components) - 1;
alu->dest.saturate = false;
@@ -443,7 +336,7 @@ construct_value(nir_builder *build,
num_components = nir_op_infos[alu->op].input_sizes[i];
alu->src[i] = construct_value(build, expr->srcs[i],
- num_components, bitsize->srcs[i],
+ num_components, search_bitsize,
state, instr);
}
@@ -472,16 +365,17 @@ construct_value(nir_builder *build,
case nir_search_value_constant: {
const nir_search_constant *c = nir_search_value_as_constant(value);
+ unsigned bit_size = replace_bitsize(value, search_bitsize, state);
nir_ssa_def *cval;
switch (c->type) {
case nir_type_float:
- cval = nir_imm_floatN_t(build, c->data.d, bitsize->dest_size);
+ cval = nir_imm_floatN_t(build, c->data.d, bit_size);
break;
case nir_type_int:
case nir_type_uint:
- cval = nir_imm_intN_t(build, c->data.i, bitsize->dest_size);
+ cval = nir_imm_intN_t(build, c->data.i, bit_size);
break;
case nir_type_bool:
@@ -526,16 +420,12 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
swizzle, &state))
return NULL;
- void *bitsize_ctx = ralloc_context(NULL);
- bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
- bitsize_tree_filter_up(tree);
- bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
-
build->cursor = nir_before_instr(&instr->instr);
nir_alu_src val = construct_value(build, replace,
instr->dest.dest.ssa.num_components,
- tree, &state, &instr->instr);
+ instr->dest.dest.ssa.bit_size,
+ &state, &instr->instr);
/* Inserting a mov may be unnecessary. However, it's much easier to
* simply let copy propagation clean this up than to try to go through
@@ -551,7 +441,5 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
*/
nir_instr_remove(&instr->instr);
- ralloc_free(bitsize_ctx);
-
return ssa_val;
}