summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorIago Toral Quiroga <[email protected]>2018-04-18 10:14:11 +0200
committerIago Toral Quiroga <[email protected]>2019-01-02 07:54:05 +0100
commit7d3c34197a5d357473f8b3090dd24d6e0dfea2e4 (patch)
treefb212a4b56d2fe8a40c131af4c1af2bb6f4a985b
parent88663ba67c3438d7bac003fc6060a602f5189c39 (diff)
compiler/spirv: implement 16-bit hyperbolic trigonometric functions
v2: - use nir_fadd_imm and nir_fmul_imm helpers (Jason) v3: - since we need to define one for fsub use it for fdiv too (Jason) Reviewed-by: Jason Ekstrand <[email protected]>
-rw-r--r--src/compiler/spirv/vtn_glsl450.c44
1 files changed, 26 insertions, 18 deletions
diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c
index 7984c7cc776..396ec641562 100644
--- a/src/compiler/spirv/vtn_glsl450.c
+++ b/src/compiler/spirv/vtn_glsl450.c
@@ -654,17 +654,17 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
case GLSLstd450Sinh:
/* 0.5 * (e^x - e^(-x)) */
val->ssa->def =
- nir_fmul(nb, nir_imm_float(nb, 0.5f),
- nir_fsub(nb, build_exp(nb, src[0]),
- build_exp(nb, nir_fneg(nb, src[0]))));
+ nir_fmul_imm(nb, nir_fsub(nb, build_exp(nb, src[0]),
+ build_exp(nb, nir_fneg(nb, src[0]))),
+ 0.5f);
return;
case GLSLstd450Cosh:
/* 0.5 * (e^x + e^(-x)) */
val->ssa->def =
- nir_fmul(nb, nir_imm_float(nb, 0.5f),
- nir_fadd(nb, build_exp(nb, src[0]),
- build_exp(nb, nir_fneg(nb, src[0]))));
+ nir_fmul_imm(nb, nir_fadd(nb, build_exp(nb, src[0]),
+ build_exp(nb, nir_fneg(nb, src[0]))),
+ 0.5f);
return;
case GLSLstd450Tanh: {
@@ -675,30 +675,38 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
* We clamp x to (-inf, +10] to avoid precision problems. When x > 10,
* e^2x is so much larger than 1.0 that 1.0 gets flushed to zero in the
* computation e^2x +/- 1 so it can be ignored.
+ *
+ * For 16-bit precision we clamp x to (-inf, +4.2] since the maximum
+ * representable number is only 65,504 and e^(2*6) exceeds that. Also,
+ * if x > 4.2, tanh(x) will return 1.0 in fp16.
*/
- nir_ssa_def *x = nir_fmin(nb, src[0], nir_imm_float(nb, 10));
- nir_ssa_def *exp2x = build_exp(nb, nir_fmul(nb, x, nir_imm_float(nb, 2)));
- val->ssa->def = nir_fdiv(nb, nir_fsub(nb, exp2x, nir_imm_float(nb, 1)),
- nir_fadd(nb, exp2x, nir_imm_float(nb, 1)));
+ const uint32_t bit_size = src[0]->bit_size;
+ const double clamped_x = bit_size > 16 ? 10.0 : 4.2;
+ nir_ssa_def *x = nir_fmin(nb, src[0],
+ nir_imm_floatN_t(nb, clamped_x, bit_size));
+ nir_ssa_def *exp2x = build_exp(nb, nir_fmul_imm(nb, x, 2.0));
+ val->ssa->def = nir_fdiv(nb, nir_fadd_imm(nb, exp2x, -1.0),
+ nir_fadd_imm(nb, exp2x, 1.0));
return;
}
case GLSLstd450Asinh:
val->ssa->def = nir_fmul(nb, nir_fsign(nb, src[0]),
build_log(nb, nir_fadd(nb, nir_fabs(nb, src[0]),
- nir_fsqrt(nb, nir_fadd(nb, nir_fmul(nb, src[0], src[0]),
- nir_imm_float(nb, 1.0f))))));
+ nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]),
+ 1.0f)))));
return;
case GLSLstd450Acosh:
val->ssa->def = build_log(nb, nir_fadd(nb, src[0],
- nir_fsqrt(nb, nir_fsub(nb, nir_fmul(nb, src[0], src[0]),
- nir_imm_float(nb, 1.0f)))));
+ nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]),
+ -1.0f))));
return;
case GLSLstd450Atanh: {
- nir_ssa_def *one = nir_imm_float(nb, 1.0);
- val->ssa->def = nir_fmul(nb, nir_imm_float(nb, 0.5f),
- build_log(nb, nir_fdiv(nb, nir_fadd(nb, one, src[0]),
- nir_fsub(nb, one, src[0]))));
+ nir_ssa_def *one = nir_imm_floatN_t(nb, 1.0, src[0]->bit_size);
+ val->ssa->def =
+ nir_fmul_imm(nb, build_log(nb, nir_fdiv(nb, nir_fadd(nb, src[0], one),
+ nir_fsub(nb, one, src[0]))),
+ 0.5f);
return;
}