summaryrefslogtreecommitdiffstats
path: root/src/compiler/nir
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler/nir')
-rw-r--r--src/compiler/nir/nir_opt_if.c128
1 files changed, 128 insertions, 0 deletions
diff --git a/src/compiler/nir/nir_opt_if.c b/src/compiler/nir/nir_opt_if.c
index 512fd92575b..5780ae3794b 100644
--- a/src/compiler/nir/nir_opt_if.c
+++ b/src/compiler/nir/nir_opt_if.c
@@ -23,6 +23,7 @@
#include "nir.h"
#include "nir/nir_builder.h"
+#include "nir_constant_expressions.h"
#include "nir_control_flow.h"
#include "nir_loop_analyze.h"
@@ -400,6 +401,119 @@ evaluate_if_condition(nir_if *nif, nir_cursor cursor, uint32_t *value)
}
}
+/*
+ * This propagates if condition evaluation down the chain of some alu
+ * instructions. For example by checking the use of some of the following alu
+ * instruction we can eventually replace ssa_107 with NIR_TRUE.
+ *
+ * loop {
+ * block block_1:
+ * vec1 32 ssa_85 = load_const (0x00000002)
+ * vec1 32 ssa_86 = ieq ssa_48, ssa_85
+ * vec1 32 ssa_87 = load_const (0x00000001)
+ * vec1 32 ssa_88 = ieq ssa_48, ssa_87
+ * vec1 32 ssa_89 = ior ssa_86, ssa_88
+ * vec1 32 ssa_90 = ieq ssa_48, ssa_0
+ * vec1 32 ssa_91 = ior ssa_89, ssa_90
+ * if ssa_86 {
+ * block block_2:
+ * ...
+ * break
+ * } else {
+ * block block_3:
+ * }
+ * block block_4:
+ * if ssa_88 {
+ * block block_5:
+ * ...
+ * break
+ * } else {
+ * block block_6:
+ * }
+ * block block_7:
+ * if ssa_90 {
+ * block block_8:
+ * ...
+ * break
+ * } else {
+ * block block_9:
+ * }
+ * block block_10:
+ * vec1 32 ssa_107 = inot ssa_91
+ * if ssa_107 {
+ * block block_11:
+ * break
+ * } else {
+ * block block_12:
+ * }
+ * }
+ */
+static bool
+propagate_condition_eval(nir_builder *b, nir_if *nif, nir_src *use_src,
+ nir_src *alu_use, nir_alu_instr *alu,
+ bool is_if_condition)
+{
+ bool progress = false;
+
+ nir_const_value bool_value;
+ b->cursor = nir_before_src(alu_use, is_if_condition);
+ if (nir_op_infos[alu->op].num_inputs == 1) {
+ assert(alu->op == nir_op_inot || alu->op == nir_op_b2i);
+
+ if (evaluate_if_condition(nif, b->cursor, &bool_value.u32[0])) {
+ assert(nir_src_bit_size(alu->src[0].src) == 32);
+
+ nir_const_value result =
+ nir_eval_const_opcode(alu->op, 1, 32, &bool_value);
+
+ replace_if_condition_use_with_const(b, alu_use, result,
+ is_if_condition);
+ progress = true;
+ }
+ } else {
+ assert(alu->op == nir_op_ior || alu->op == nir_op_iand);
+
+ if (evaluate_if_condition(nif, b->cursor, &bool_value.u32[0])) {
+ nir_ssa_def *def[2];
+ for (unsigned i = 0; i < 2; i++) {
+ if (alu->src[i].src.ssa == use_src->ssa) {
+ def[i] = nir_build_imm(b, 1, 32, bool_value);
+ } else {
+ def[i] = alu->src[i].src.ssa;
+ }
+ }
+
+ nir_ssa_def *nalu =
+ nir_build_alu(b, alu->op, def[0], def[1], NULL, NULL);
+
+ /* Rewrite use to use new alu instruction */
+ nir_src new_src = nir_src_for_ssa(nalu);
+
+ if (is_if_condition)
+ nir_if_rewrite_condition(alu_use->parent_if, new_src);
+ else
+ nir_instr_rewrite_src(alu_use->parent_instr, alu_use, new_src);
+
+ progress = true;
+ }
+ }
+
+ return progress;
+}
+
+static bool
+can_propagate_through_alu(nir_src *src)
+{
+ if (src->parent_instr->type == nir_instr_type_alu &&
+ (nir_instr_as_alu(src->parent_instr)->op == nir_op_ior ||
+ nir_instr_as_alu(src->parent_instr)->op == nir_op_iand ||
+ nir_instr_as_alu(src->parent_instr)->op == nir_op_inot ||
+ nir_instr_as_alu(src->parent_instr)->op == nir_op_b2i))
+ return true;
+
+ return false;
+}
+
static bool
evaluate_condition_use(nir_builder *b, nir_if *nif, nir_src *use_src,
bool is_if_condition)
@@ -414,6 +528,20 @@ evaluate_condition_use(nir_builder *b, nir_if *nif, nir_src *use_src,
progress = true;
}
+ if (!is_if_condition && can_propagate_through_alu(use_src)) {
+ nir_alu_instr *alu = nir_instr_as_alu(use_src->parent_instr);
+
+ nir_foreach_use_safe(alu_use, &alu->dest.dest.ssa) {
+ progress |= propagate_condition_eval(b, nif, use_src, alu_use, alu,
+ false);
+ }
+
+ nir_foreach_if_use_safe(alu_use, &alu->dest.dest.ssa) {
+ progress |= propagate_condition_eval(b, nif, use_src, alu_use, alu,
+ true);
+ }
+ }
+
return progress;
}