diff options
Diffstat (limited to 'src/compiler/nir')
-rw-r--r-- | src/compiler/nir/nir.h | 3 | ||||
-rw-r--r-- | src/compiler/nir/nir_lower_double_ops.c | 29 |
2 files changed, 31 insertions, 1 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 32eb3b2ca07..de573b45c08 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -2425,7 +2425,8 @@ typedef enum { nir_lower_dfloor = (1 << 4), nir_lower_dceil = (1 << 5), nir_lower_dfract = (1 << 6), - nir_lower_dround_even = (1 << 7) + nir_lower_dround_even = (1 << 7), + nir_lower_dmod = (1 << 8) } nir_lower_doubles_options; void nir_lower_doubles(nir_shader *shader, nir_lower_doubles_options options); diff --git a/src/compiler/nir/nir_lower_double_ops.c b/src/compiler/nir/nir_lower_double_ops.c index 3f831dcf304..ae3a596216e 100644 --- a/src/compiler/nir/nir_lower_double_ops.c +++ b/src/compiler/nir/nir_lower_double_ops.c @@ -438,6 +438,24 @@ lower_round_even(nir_builder *b, nir_ssa_def *src) nir_fsub(b, src, nir_imm_double(b, 0.5))))); } +static nir_ssa_def * +lower_mod(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1) +{ + /* mod(x,y) = x - y * floor(x/y) + * + * If the division is lowered, it could add some rounding errors that make + * floor() to return the quotient minus one when x = N * y. If this is the + * case, we return zero because mod(x, y) output value is [0, y). + */ + nir_ssa_def *floor = nir_ffloor(b, nir_fdiv(b, src0, src1)); + nir_ssa_def *mod = nir_fsub(b, src0, nir_fmul(b, src1, floor)); + + return nir_bcsel(b, + nir_fne(b, mod, src1), + mod, + nir_imm_double(b, 0.0)); +} + static void lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options) { @@ -486,6 +504,11 @@ lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options) return; break; + case nir_op_fmod: + if (!(options & nir_lower_dmod)) + return; + break; + default: return; } @@ -525,6 +548,12 @@ lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options) result = lower_round_even(&bld, src); break; + case nir_op_fmod: { + nir_ssa_def *src1 = nir_fmov_alu(&bld, instr->src[1], + instr->dest.dest.ssa.num_components); + result = lower_mod(&bld, src, src1); + } + break; default: unreachable("unhandled opcode"); } |