summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/compiler/nir/nir.h3
-rw-r--r--src/compiler/nir/nir_lower_double_ops.c29
2 files changed, 31 insertions, 1 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 32eb3b2ca07..de573b45c08 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -2425,7 +2425,8 @@ typedef enum {
nir_lower_dfloor = (1 << 4),
nir_lower_dceil = (1 << 5),
nir_lower_dfract = (1 << 6),
- nir_lower_dround_even = (1 << 7)
+ nir_lower_dround_even = (1 << 7),
+ nir_lower_dmod = (1 << 8)
} nir_lower_doubles_options;
void nir_lower_doubles(nir_shader *shader, nir_lower_doubles_options options);
diff --git a/src/compiler/nir/nir_lower_double_ops.c b/src/compiler/nir/nir_lower_double_ops.c
index 3f831dcf304..ae3a596216e 100644
--- a/src/compiler/nir/nir_lower_double_ops.c
+++ b/src/compiler/nir/nir_lower_double_ops.c
@@ -438,6 +438,24 @@ lower_round_even(nir_builder *b, nir_ssa_def *src)
nir_fsub(b, src, nir_imm_double(b, 0.5)))));
}
+static nir_ssa_def *
+lower_mod(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1)
+{
+ /* mod(x,y) = x - y * floor(x/y)
+ *
+ * If the division is lowered, it could add some rounding errors that make
+ * floor() to return the quotient minus one when x = N * y. If this is the
+ * case, we return zero because mod(x, y) output value is [0, y).
+ */
+ nir_ssa_def *floor = nir_ffloor(b, nir_fdiv(b, src0, src1));
+ nir_ssa_def *mod = nir_fsub(b, src0, nir_fmul(b, src1, floor));
+
+ return nir_bcsel(b,
+ nir_fne(b, mod, src1),
+ mod,
+ nir_imm_double(b, 0.0));
+}
+
static void
lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options)
{
@@ -486,6 +504,11 @@ lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options)
return;
break;
+ case nir_op_fmod:
+ if (!(options & nir_lower_dmod))
+ return;
+ break;
+
default:
return;
}
@@ -525,6 +548,12 @@ lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options)
result = lower_round_even(&bld, src);
break;
+ case nir_op_fmod: {
+ nir_ssa_def *src1 = nir_fmov_alu(&bld, instr->src[1],
+ instr->dest.dest.ssa.num_components);
+ result = lower_mod(&bld, src, src1);
+ }
+ break;
default:
unreachable("unhandled opcode");
}