diff options
author | Ian Romanick <[email protected]> | 2019-08-12 17:28:35 -0700 |
---|---|---|
committer | Ian Romanick <[email protected]> | 2019-09-25 15:37:01 -0700 |
commit | 99ddb41e2d46de613b1b15c9da8fd08a12e24658 (patch) | |
tree | b9a62f6b36f653a323f05b3439e867c2740b84b7 | |
parent | 018d2b524ad399ef9ea6d31de98de7078d702980 (diff) |
nir/range-analysis: Use types in the hash key
This allows the reslut of mov and bcsel to be separately interpreted as
float or int depending on the use.
Reviewed-by: Caio Marcelo de Oliveira Filho <[email protected]>
-rw-r--r-- | src/compiler/nir/nir_range_analysis.c | 136 |
1 files changed, 98 insertions, 38 deletions
diff --git a/src/compiler/nir/nir_range_analysis.c b/src/compiler/nir/nir_range_analysis.c index 298d9946b56..e1f3eb14bce 100644 --- a/src/compiler/nir/nir_range_analysis.c +++ b/src/compiler/nir/nir_range_analysis.c @@ -51,8 +51,41 @@ unpack_data(const void *p) return (struct ssa_result_range){v & 0xff, (v & 0x0ff00) != 0}; } +static void * +pack_key(const struct nir_alu_instr *instr, nir_alu_type type) +{ + uintptr_t type_encoding; + uintptr_t ptr = (uintptr_t) instr; + + /* The low 2 bits have to be zero or this whole scheme falls apart. */ + assert((ptr & 0x3) == 0); + + /* NIR is typeless in the sense that sequences of bits have whatever + * meaning is attached to them by the instruction that consumes them. + * However, the number of bits must match between producer and consumer. + * As a result, the number of bits does not need to be encoded here. + */ + switch (nir_alu_type_get_base_type(type)) { + case nir_type_int: type_encoding = 0; break; + case nir_type_uint: type_encoding = 1; break; + case nir_type_bool: type_encoding = 2; break; + case nir_type_float: type_encoding = 3; break; + default: unreachable("Invalid base type."); + } + + return (void *)(ptr | type_encoding); +} + +static nir_alu_type +nir_alu_src_type(const nir_alu_instr *instr, unsigned src) +{ + return nir_alu_type_get_base_type(nir_op_infos[instr->op].input_types[src]) | + nir_src_bit_size(instr->src[src].src); +} + static struct ssa_result_range -analyze_constant(const struct nir_alu_instr *instr, unsigned src) +analyze_constant(const struct nir_alu_instr *instr, unsigned src, + nir_alu_type use_type) { uint8_t swizzle[4] = { 0, 1, 2, 3 }; @@ -69,7 +102,7 @@ analyze_constant(const struct nir_alu_instr *instr, unsigned src) struct ssa_result_range r = { unknown, false }; - switch (nir_op_infos[instr->op].input_types[src]) { + switch (nir_alu_type_get_base_type(use_type)) { case nir_type_float: { double min_value = DBL_MAX; double max_value = -DBL_MAX; @@ -321,13 +354,13 @@ union_ranges(enum ssa_ranges a, enum ssa_ranges b) */ static struct ssa_result_range analyze_expression(const nir_alu_instr *instr, unsigned src, - struct hash_table *ht) + struct hash_table *ht, nir_alu_type use_type) { if (!instr->src[src].src.is_ssa) return (struct ssa_result_range){unknown, false}; if (nir_src_is_const(instr->src[src].src)) - return analyze_constant(instr, src); + return analyze_constant(instr, src, use_type); if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) return (struct ssa_result_range){unknown, false}; @@ -335,8 +368,6 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, const struct nir_alu_instr *const alu = nir_instr_as_alu(instr->src[src].src.ssa->parent_instr); - const nir_alu_type use_type = nir_op_infos[instr->op].input_types[src]; - /* Bail if the type of the instruction generating the value does not match * the type the value will be interpreted as. int/uint/bool can be * reinterpreted trivially. The most important cases are between float and @@ -355,7 +386,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } } - struct hash_entry *he = _mesa_hash_table_search(ht, alu); + struct hash_entry *he = _mesa_hash_table_search(ht, pack_key(alu, use_type)); if (he != NULL) return unpack_data(he->data); @@ -466,8 +497,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_bcsel: { - const struct ssa_result_range left = analyze_expression(alu, 1, ht); - const struct ssa_result_range right = analyze_expression(alu, 2, ht); + const struct ssa_result_range left = + analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); + const struct ssa_result_range right = + analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2)); /* If either source is a constant load that is not zero, punt. The type * will always be uint regardless of the actual type. We can't even @@ -545,7 +578,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, case nir_op_i2f32: case nir_op_u2f32: - r = analyze_expression(alu, 0, ht); + r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); r.is_integral = true; @@ -555,7 +588,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_fabs: - r = analyze_expression(alu, 0, ht); + r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); switch (r.range) { case unknown: @@ -577,8 +610,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_fadd: { - const struct ssa_result_range left = analyze_expression(alu, 0, ht); - const struct ssa_result_range right = analyze_expression(alu, 1, ht); + const struct ssa_result_range left = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range right = + analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); r.is_integral = left.is_integral && right.is_integral; r.range = fadd_table[left.range][right.range]; @@ -595,7 +630,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, ge_zero, ge_zero, ge_zero, gt_zero, gt_zero, ge_zero, gt_zero }; - r = analyze_expression(alu, 0, ht); + r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(table); ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(table); @@ -606,8 +641,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_fmax: { - const struct ssa_result_range left = analyze_expression(alu, 0, ht); - const struct ssa_result_range right = analyze_expression(alu, 1, ht); + const struct ssa_result_range left = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range right = + analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); r.is_integral = left.is_integral && right.is_integral; @@ -669,8 +706,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_fmin: { - const struct ssa_result_range left = analyze_expression(alu, 0, ht); - const struct ssa_result_range right = analyze_expression(alu, 1, ht); + const struct ssa_result_range left = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range right = + analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); r.is_integral = left.is_integral && right.is_integral; @@ -732,8 +771,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_fmul: { - const struct ssa_result_range left = analyze_expression(alu, 0, ht); - const struct ssa_result_range right = analyze_expression(alu, 1, ht); + const struct ssa_result_range left = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range right = + analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); r.is_integral = left.is_integral && right.is_integral; @@ -753,11 +794,15 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_frcp: - r = (struct ssa_result_range){analyze_expression(alu, 0, ht).range, false}; + r = (struct ssa_result_range){ + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range, + false + }; break; case nir_op_mov: { - const struct ssa_result_range left = analyze_expression(alu, 0, ht); + const struct ssa_result_range left = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); /* See commentary in nir_op_bcsel for the reasons this is necessary. */ if (nir_src_is_const(alu->src[0].src) && left.range != eq_zero) @@ -768,13 +813,13 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_fneg: - r = analyze_expression(alu, 0, ht); + r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); r.range = fneg_table[r.range]; break; case nir_op_fsat: - r = analyze_expression(alu, 0, ht); + r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); switch (r.range) { case le_zero: @@ -799,7 +844,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_fsign: - r = (struct ssa_result_range){analyze_expression(alu, 0, ht).range, true}; + r = (struct ssa_result_range){ + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range, + true + }; break; case nir_op_fsqrt: @@ -808,7 +856,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_ffloor: { - const struct ssa_result_range left = analyze_expression(alu, 0, ht); + const struct ssa_result_range left = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); r.is_integral = true; @@ -823,7 +872,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_fceil: { - const struct ssa_result_range left = analyze_expression(alu, 0, ht); + const struct ssa_result_range left = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); r.is_integral = true; @@ -838,7 +888,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_ftrunc: { - const struct ssa_result_range left = analyze_expression(alu, 0, ht); + const struct ssa_result_range left = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); r.is_integral = true; @@ -919,8 +970,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, /* eq_zero */ { ge_zero, gt_zero, gt_zero, eq_zero, ge_zero, ge_zero, gt_zero }, }; - const struct ssa_result_range left = analyze_expression(alu, 0, ht); - const struct ssa_result_range right = analyze_expression(alu, 1, ht); + const struct ssa_result_range left = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range right = + analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(table); ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(table); @@ -932,9 +985,12 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_ffma: { - const struct ssa_result_range first = analyze_expression(alu, 0, ht); - const struct ssa_result_range second = analyze_expression(alu, 1, ht); - const struct ssa_result_range third = analyze_expression(alu, 2, ht); + const struct ssa_result_range first = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range second = + analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); + const struct ssa_result_range third = + analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2)); r.is_integral = first.is_integral && second.is_integral && third.is_integral; @@ -957,9 +1013,12 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_flrp: { - const struct ssa_result_range first = analyze_expression(alu, 0, ht); - const struct ssa_result_range second = analyze_expression(alu, 1, ht); - const struct ssa_result_range third = analyze_expression(alu, 2, ht); + const struct ssa_result_range first = + analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range second = + analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); + const struct ssa_result_range third = + analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2)); r.is_integral = first.is_integral && second.is_integral && third.is_integral; @@ -983,7 +1042,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, if (r.range == eq_zero) r.is_integral = true; - _mesa_hash_table_insert(ht, alu, pack_data(r)); + _mesa_hash_table_insert(ht, pack_key(alu, use_type), pack_data(r)); return r; } @@ -994,7 +1053,8 @@ nir_analyze_range(const nir_alu_instr *instr, unsigned src) { struct hash_table *ht = _mesa_pointer_hash_table_create(NULL); - const struct ssa_result_range r = analyze_expression(instr, src, ht); + const struct ssa_result_range r = + analyze_expression(instr, src, ht, nir_alu_src_type(instr, src)); _mesa_hash_table_destroy(ht, NULL); |