aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r-- src/compiler/nir/nir.h | 3
-rw-r--r-- src/compiler/nir/nir_lower_double_ops.c | 58
2 files changed, 60 insertions, 1 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index ac96727cad6..317d71636f4 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -2420,7 +2420,8 @@ typedef enum {
nir_lower_dtrunc = (1 << 3),
nir_lower_dfloor = (1 << 4),
nir_lower_dceil = (1 << 5),
- nir_lower_dfract = (1 << 6)
+ nir_lower_dfract = (1 << 6),
+ nir_lower_dround_even = (1 << 7)
} nir_lower_doubles_options;
void nir_lower_doubles(nir_shader *shader, nir_lower_doubles_options options);
diff --git a/src/compiler/nir/nir_lower_double_ops.c b/src/compiler/nir/nir_lower_double_ops.c
index f1fa2c3fd2d..7505fa394bf 100644
--- a/src/compiler/nir/nir_lower_double_ops.c
+++ b/src/compiler/nir/nir_lower_double_ops.c
@@ -389,6 +389,55 @@ lower_fract(nir_builder *b, nir_ssa_def *src)
return nir_fsub(b, src, nir_ffloor(b, src));
}
+static nir_ssa_def *
+lower_round_even(nir_builder *b, nir_ssa_def *src)
+{
+   /* Build code that rounds the double src to the nearest integer, with ties
+    * (fract(src) == 0.5) going to the even neighbour.
+    * Ties are decided using mod(abs(src), 2), which we compute as:
+    *    abs(src) - 2.0 * floor(abs(src) / 2.0)
+    * A result < 1 means the even neighbour is the one towards zero.
+    */
+   nir_ssa_def *half = nir_imm_double(b, 0.5);
+   nir_ssa_def *abs_src = nir_fabs(b, src);
+   nir_ssa_def *mod =
+      nir_fsub(b,
+               abs_src,
+               nir_fmul(b,
+                        nir_imm_double(b, 2.0),
+                        nir_ffloor(b, nir_fmul(b, abs_src, half))));
+
+   /* If fract(src) != 0.5 we take floor(src), plus one when the fraction is
+    * above 0.5. We deliberately avoid floor(src + 0.5): with a
+    * round-to-nearest-even add the sum itself rounds, which is wrong for
+    * the largest double below 0.5 (src + 0.5 becomes 1.0) and for odd
+    * integers >= 2^52 (src + 0.5 becomes src + 1).
+    */
+   nir_ssa_def *fract = nir_ffract(b, src);
+   nir_ssa_def *floor_src = nir_ffloor(b, src);
+   nir_ssa_def *nearest =
+      nir_bcsel(b,
+                nir_flt(b, half, fract),
+                nir_fadd(b, floor_src, nir_imm_double(b, 1.0)),
+                floor_src);
+
+   /* If fract(src) == 0.5 we check the modulo. If it is < 1 we truncate:
+    *    0.5 -> 0, -0.5 -> -0, 2.5 -> 2, -2.5 -> -2
+    * otherwise we move away from zero by 0.5, following the sign of src:
+    *    1.5 -> 2, -1.5 -> -2, 3.5 -> 4, -3.5 -> -4
+    */
+   nir_ssa_def *tie =
+      nir_bcsel(b,
+                nir_flt(b, mod, nir_imm_double(b, 1.0)),
+                nir_ftrunc(b, src),
+                nir_bcsel(b,
+                          nir_fge(b, src, nir_imm_double(b, 0.0)),
+                          nir_fadd(b, src, half),
+                          nir_fsub(b, src, half)));
+
+   return nir_bcsel(b, nir_fne(b, fract, half), nearest, tie);
+}
+
static void
lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options)
{
@@ -432,6 +481,11 @@ lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options)
return;
break;
+ case nir_op_fround_even:
+ if (!(options & nir_lower_dround_even))
+ return;
+ break;
+
default:
return;
}
@@ -467,6 +521,10 @@ lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options)
case nir_op_ffract:
result = lower_fract(&bld, src);
break;
+ case nir_op_fround_even:
+ result = lower_round_even(&bld, src);
+ break;
+
default:
unreachable("unhandled opcode");
}