summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/compiler/nir/nir.h 4
-rw-r--r-- src/compiler/nir/nir_instr_set.c 104
-rw-r--r-- src/compiler/nir/tests/negative_equal_tests.cpp 84
3 files changed, 192 insertions, 0 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index e123a59cca8..3ddf97bb12c 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -997,6 +997,10 @@ bool nir_const_value_negative_equal(const nir_const_value *c1,
bool nir_alu_srcs_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2,
unsigned src1, unsigned src2);
+bool nir_alu_srcs_negative_equal(const nir_alu_instr *alu1,
+ const nir_alu_instr *alu2,
+ unsigned src1, unsigned src2);
+
typedef enum {
nir_deref_type_var,
nir_deref_type_array,
diff --git a/src/compiler/nir/nir_instr_set.c b/src/compiler/nir/nir_instr_set.c
index 1307fe2f3c9..9aa1f3bbe5e 100644
--- a/src/compiler/nir/nir_instr_set.c
+++ b/src/compiler/nir/nir_instr_set.c
@@ -276,6 +276,20 @@ nir_srcs_equal(nir_src src1, nir_src src2)
}
}
+/**
+ * If \p s is an SSA value produced by a negation instruction (\c nir_op_fneg
+ * or \c nir_op_ineg), that instruction is returned as a \c nir_alu_instr.
+ * Otherwise \c NULL is returned.  The two opcodes are not distinguished.
+ */
+static const struct nir_alu_instr *
+get_neg_instr(const nir_src *s)
+{
+ const struct nir_alu_instr *const alu = nir_src_as_alu_instr_const(s);
+
+ return alu != NULL && (alu->op == nir_op_fneg || alu->op == nir_op_ineg)
+ ? alu : NULL;
+}
+
bool
nir_const_value_negative_equal(const nir_const_value *c1,
const nir_const_value *c2,
@@ -377,6 +391,96 @@ nir_const_value_negative_equal(const nir_const_value *c1,
return false;
}
+/**
+ * Shallow compare of ALU srcs to determine if one is the negation of the other
+ *
+ * Compares source \p src1 of \p alu1 with source \p src2 of \p alu2.  Returns
+ * true when both are constants that are negations of each other, or when both
+ * read the same SSA value with the same effective swizzle and differ by an
+ * odd number of negations (\c negate modifier, \c nir_op_fneg, \c nir_op_ineg).
+ *
+ * This function does not detect the general case when the sources are SSA
+ * values that are negations of each other (e.g., (a * b) and (-a * b)).
+ */
+bool
+nir_alu_srcs_negative_equal(const nir_alu_instr *alu1,
+                            const nir_alu_instr *alu2,
+                            unsigned src1, unsigned src2)
+{
+   if (alu1->src[src1].abs != alu2->src[src2].abs)
+      return false;
+
+   bool parity = alu1->src[src1].negate != alu2->src[src2].negate;
+
+   /* Handling load_const instructions is tricky. */
+
+   const nir_const_value *const const1 =
+      nir_src_as_const_value(alu1->src[src1].src);
+
+   if (const1 != NULL) {
+      /* Assume that constant folding will eliminate source mods and unary
+       * ops; a leftover negate modifier is conservatively reported as no match.
+       */
+      if (parity)
+         return false;
+
+      const nir_const_value *const const2 =
+         nir_src_as_const_value(alu2->src[src2].src);
+
+      if (const2 == NULL)
+         return false;
+      /* FINISHME: Apply the swizzle?  NOTE(review): the bit size passed below
+       * is that of alu1's destination, not of the source -- confirm. */
+      return nir_const_value_negative_equal(const1,
+                                            const2,
+                                            nir_ssa_alu_instr_src_components(alu1, src1),
+                                            nir_op_infos[alu1->op].input_types[src1],
+                                            alu1->dest.dest.ssa.bit_size);
+   }
+   /* Look through at most one fneg/ineg per side, recording its swizzle. */
+   uint8_t alu1_swizzle[4] = {0};
+   nir_src alu1_actual_src;
+   const struct nir_alu_instr *const neg1 = get_neg_instr(&alu1->src[src1].src);
+
+   if (neg1) {
+      parity = !parity;
+      alu1_actual_src = neg1->src[0].src;
+
+      for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(neg1, 0); i++)
+         alu1_swizzle[i] = neg1->src[0].swizzle[i];
+   } else {
+      alu1_actual_src = alu1->src[src1].src;
+      /* Identity for every lane: the source swizzle may select any of them. */
+      for (unsigned i = 0; i < sizeof(alu1_swizzle) / sizeof(alu1_swizzle[0]); i++)
+         alu1_swizzle[i] = i;
+   }
+
+   uint8_t alu2_swizzle[4] = {0};
+   nir_src alu2_actual_src;
+   const struct nir_alu_instr *const neg2 = get_neg_instr(&alu2->src[src2].src);
+
+   if (neg2) {
+      parity = !parity;
+      alu2_actual_src = neg2->src[0].src;
+
+      for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(neg2, 0); i++)
+         alu2_swizzle[i] = neg2->src[0].swizzle[i];
+   } else {
+      alu2_actual_src = alu2->src[src2].src;
+      /* Identity for every lane: the source swizzle may select any of them. */
+      for (unsigned i = 0; i < sizeof(alu2_swizzle) / sizeof(alu2_swizzle[0]); i++)
+         alu2_swizzle[i] = i;
+   }
+   /* Compose each source swizzle with its table and compare lane by lane. */
+   for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu1, src1); i++) {
+      if (alu1_swizzle[alu1->src[src1].swizzle[i]] !=
+          alu2_swizzle[alu2->src[src2].swizzle[i]])
+         return false;
+   }
+   /* Match only if the negations do not cancel and the base value is shared. */
+   return parity && nir_srcs_equal(alu1_actual_src, alu2_actual_src);
+}
+
bool
nir_alu_srcs_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2,
unsigned src1, unsigned src2)
diff --git a/src/compiler/nir/tests/negative_equal_tests.cpp b/src/compiler/nir/tests/negative_equal_tests.cpp
index e450a8172db..b38a0c10da5 100644
--- a/src/compiler/nir/tests/negative_equal_tests.cpp
+++ b/src/compiler/nir/tests/negative_equal_tests.cpp
@@ -22,6 +22,7 @@
*/
#include <gtest/gtest.h>
#include "nir.h"
+#include "nir_builder.h"
#include "util/half_float.h"
static nir_const_value count_sequence(nir_alu_type base_type, unsigned bits,
@@ -47,6 +48,21 @@ protected:
nir_const_value c2;
};
+class alu_srcs_negative_equal_test : public ::testing::Test { /* fixture owning one nir_builder per test */
+protected:
+ alu_srcs_negative_equal_test() /* initializes bld as a simple vertex shader with zeroed compiler options */
+ {
+ static const nir_shader_compiler_options options = { };
+ nir_builder_init_simple_shader(&bld, NULL, MESA_SHADER_VERTEX, &options);
+ }
+
+ ~alu_srcs_negative_equal_test() /* frees the shader held by bld */
+ {
+ ralloc_free(bld.shader);
+ }
+
+ struct nir_builder bld; /* builder the tests use to emit instructions */
+};
TEST_F(const_value_negative_equal_test, float32_zero)
{
@@ -130,6 +146,74 @@ compare_fewer_components(nir_type_uint, 32)
compare_fewer_components(nir_type_int, 64)
compare_fewer_components(nir_type_uint, 64)
+TEST_F(alu_srcs_negative_equal_test, trivial_float) /* float constants 2.0 and -2.0 as the two sources of one fadd */
+{
+ nir_ssa_def *two = nir_imm_float(&bld, 2.0f);
+ nir_ssa_def *negative_two = nir_imm_float(&bld, -2.0f);
+
+ nir_ssa_def *result = nir_fadd(&bld, two, negative_two);
+ nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
+
+ ASSERT_NE((void *) 0, instr);
+ EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); /* src 0 vs src 1: negations of each other */
+ EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0)); /* a source is not its own negation */
+ EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));
+}
+
+TEST_F(alu_srcs_negative_equal_test, trivial_int) /* integer constants 2 and -2 as the two sources of one iadd */
+{
+ nir_ssa_def *two = nir_imm_int(&bld, 2);
+ nir_ssa_def *negative_two = nir_imm_int(&bld, -2);
+
+ nir_ssa_def *result = nir_iadd(&bld, two, negative_two);
+ nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
+
+ ASSERT_NE((void *) 0, instr);
+ EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); /* src 0 vs src 1: negations of each other */
+ EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0)); /* a source is not its own negation */
+ EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));
+}
+
+TEST_F(alu_srcs_negative_equal_test, trivial_negation_float)
+{
+ /* Cannot just do the negation of a nir_load_const_instr because
+ * nir_alu_srcs_negative_equal expects that constant folding will convert
+ * fneg(2.0) to just -2.0, so negate a non-constant (2.0 + 2.0) instead.
+ */
+ nir_ssa_def *two = nir_imm_float(&bld, 2.0f);
+ nir_ssa_def *two_plus_two = nir_fadd(&bld, two, two);
+ nir_ssa_def *negation = nir_fneg(&bld, two_plus_two);
+
+ nir_ssa_def *result = nir_fadd(&bld, two_plus_two, negation); /* sources: x and fneg(x) */
+
+ nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
+
+ ASSERT_NE((void *) 0, instr);
+ EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); /* x vs fneg(x) */
+ EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0)); /* a source is not its own negation */
+ EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));
+}
+
+TEST_F(alu_srcs_negative_equal_test, trivial_negation_int)
+{
+ /* Cannot just do the negation of a nir_load_const_instr because
+ * nir_alu_srcs_negative_equal expects that constant folding will convert
+ * ineg(2) to just -2, so negate a non-constant (2 + 2) instead.
+ */
+ nir_ssa_def *two = nir_imm_int(&bld, 2);
+ nir_ssa_def *two_plus_two = nir_iadd(&bld, two, two);
+ nir_ssa_def *negation = nir_ineg(&bld, two_plus_two);
+
+ nir_ssa_def *result = nir_iadd(&bld, two_plus_two, negation); /* sources: x and ineg(x) */
+
+ nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
+
+ ASSERT_NE((void *) 0, instr);
+ EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); /* x vs ineg(x) */
+ EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0)); /* a source is not its own negation */
+ EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));
+}
+
static nir_const_value
count_sequence(nir_alu_type base_type, unsigned bits, int first)
{