aboutsummaryrefslogtreecommitdiffstats
path: root/src/amd/llvm
diff options
context:
space:
mode:
Diffstat (limited to 'src/amd/llvm')
-rw-r--r--src/amd/llvm/ac_llvm_helper.cpp31
-rw-r--r--src/amd/llvm/ac_llvm_util.h2
-rw-r--r--src/amd/llvm/ac_nir_to_llvm.c7
3 files changed, 40 insertions, 0 deletions
diff --git a/src/amd/llvm/ac_llvm_helper.cpp b/src/amd/llvm/ac_llvm_helper.cpp
index 578521a6f2d..f5383344dd4 100644
--- a/src/amd/llvm/ac_llvm_helper.cpp
+++ b/src/amd/llvm/ac_llvm_helper.cpp
@@ -96,6 +96,11 @@ LLVMBuilderRef ac_create_builder(LLVMContextRef ctx,
*/
flags.setAllowReciprocal(); /* arcp */
+ /* Allow floating-point contraction (e.g. fusing a multiply
+ * followed by an addition into a fused multiply-and-add).
+ */
+ flags.setAllowContract(); /* contract */
+
llvm::unwrap(builder)->setFastMathFlags(flags);
break;
}
@@ -103,6 +108,32 @@ LLVMBuilderRef ac_create_builder(LLVMContextRef ctx,
return builder;
}
+/* Return the original state of inexact math. */
+bool ac_disable_inexact_math(LLVMBuilderRef builder)
+{
+ auto *b = llvm::unwrap(builder);
+ llvm::FastMathFlags flags = b->getFastMathFlags();
+
+ if (!flags.allowContract())
+ return false;
+
+ flags.setAllowContract(false);
+ b->setFastMathFlags(flags);
+ return true;
+}
+
+void ac_restore_inexact_math(LLVMBuilderRef builder, bool value)
+{
+ auto *b = llvm::unwrap(builder);
+ llvm::FastMathFlags flags = b->getFastMathFlags();
+
+ if (flags.allowContract() == value)
+ return;
+
+ flags.setAllowContract(value);
+ b->setFastMathFlags(flags);
+}
+
LLVMTargetLibraryInfoRef
ac_create_target_library_info(const char *triple)
{
diff --git a/src/amd/llvm/ac_llvm_util.h b/src/amd/llvm/ac_llvm_util.h
index 4cfb3b55388..f9650bdf4f1 100644
--- a/src/amd/llvm/ac_llvm_util.h
+++ b/src/amd/llvm/ac_llvm_util.h
@@ -109,6 +109,8 @@ LLVMModuleRef ac_create_module(LLVMTargetMachineRef tm, LLVMContextRef ctx);
LLVMBuilderRef ac_create_builder(LLVMContextRef ctx,
enum ac_float_mode float_mode);
+bool ac_disable_inexact_math(LLVMBuilderRef builder);
+void ac_restore_inexact_math(LLVMBuilderRef builder, bool value);
void
ac_llvm_add_target_dep_function_attr(LLVMValueRef F,
diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c
index 627f5d2d931..03717191e24 100644
--- a/src/amd/llvm/ac_nir_to_llvm.c
+++ b/src/amd/llvm/ac_nir_to_llvm.c
@@ -589,6 +589,10 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
unsigned num_components = instr->dest.dest.ssa.num_components;
unsigned src_components;
LLVMTypeRef def_type = get_def_type(ctx, &instr->dest.dest.ssa);
+ bool saved_inexact = false;
+
+ if (instr->exact)
+ saved_inexact = ac_disable_inexact_math(ctx->ac.builder);
assert(nir_op_infos[instr->op].num_inputs <= ARRAY_SIZE(src));
switch (instr->op) {
@@ -1182,6 +1186,9 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
result = ac_to_integer_or_pointer(&ctx->ac, result);
ctx->ssa_defs[instr->dest.dest.ssa.index] = result;
}
+
+ if (instr->exact)
+ ac_restore_inexact_math(ctx->ac.builder, saved_inexact);
}
static void visit_load_const(struct ac_nir_context *ctx,