-rw-r--r--  src/amd/compiler/aco_instruction_selection.cpp       | 400
-rw-r--r--  src/amd/compiler/aco_instruction_selection_setup.cpp |  63
-rw-r--r--  src/amd/compiler/aco_lower_bool_phis.cpp             |   8
3 files changed, 184 insertions(+), 287 deletions(-)
diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index ab34a068671..a7c3c703403 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -130,6 +130,8 @@ Temp emit_wqm(isel_context *ctx, Temp src, Temp dst=Temp(0, s1), bool program_ne
if (!dst.id())
dst = bld.tmp(src.regClass());
+ assert(src.size() == dst.size());
+
if (ctx->stage != fragment_fs) {
if (!dst.id())
return src;
@@ -331,33 +333,31 @@ void expand_vector(isel_context* ctx, Temp vec_src, Temp dst, unsigned num_compo
ctx->allocated_vec.emplace(dst.id(), elems);
}
-Temp as_divergent_bool(isel_context *ctx, Temp val, bool vcc_hint)
+Temp bool_to_vector_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s2))
{
- if (val.regClass() == s2) {
- return val;
- } else {
- assert(val.regClass() == s1);
- Builder bld(ctx->program, ctx->block);
- Definition& def = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2),
- Operand((uint32_t) -1), Operand(0u), bld.scc(val)).def(0);
- if (vcc_hint)
- def.setHint(vcc);
- return def.getTemp();
- }
+ Builder bld(ctx->program, ctx->block);
+ if (!dst.id())
+ dst = bld.tmp(s2);
+
+ assert(val.regClass() == s1);
+ assert(dst.regClass() == s2);
+
+ return bld.sop2(aco_opcode::s_cselect_b64, bld.hint_vcc(Definition(dst)), Operand((uint32_t) -1), Operand(0u), bld.scc(val));
}
-Temp as_uniform_bool(isel_context *ctx, Temp val)
+Temp bool_to_scalar_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s1))
{
- if (val.regClass() == s1) {
- return val;
- } else {
- assert(val.regClass() == s2);
- Builder bld(ctx->program, ctx->block);
- /* if we're currently in WQM mode, ensure that the source is also computed in WQM */
- Temp tmp = bld.tmp(s1);
- bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(Definition(tmp)), val, Operand(exec, s2)).def(1).getTemp();
- return emit_wqm(ctx, tmp);
- }
+ Builder bld(ctx->program, ctx->block);
+ if (!dst.id())
+ dst = bld.tmp(s1);
+
+ assert(val.regClass() == s2);
+ assert(dst.regClass() == s1);
+
+ /* if we're currently in WQM mode, ensure that the source is also computed in WQM */
+ Temp tmp = bld.tmp(s1);
+ bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(Definition(tmp)), val, Operand(exec, s2));
+ return emit_wqm(ctx, tmp, dst);
}
Temp get_alu_src(struct isel_context *ctx, nir_alu_src src, unsigned size=1)
@@ -526,27 +526,44 @@ void emit_vopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o
src1 = as_vgpr(ctx, src1);
}
}
+
Builder bld(ctx->program, ctx->block);
- bld.vopc(op, Definition(dst), src0, src1).def(0).setHint(vcc);
+ bld.vopc(op, bld.hint_vcc(Definition(dst)), src0, src1);
}
-void emit_comparison(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst)
+void emit_sopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst)
{
- if (dst.regClass() == s2) {
- emit_vopc_instruction(ctx, instr, op, dst);
- if (!ctx->divergent_vals[instr->dest.dest.ssa.index])
- emit_split_vector(ctx, dst, 2);
- } else if (dst.regClass() == s1) {
- Temp src0 = get_alu_src(ctx, instr->src[0]);
- Temp src1 = get_alu_src(ctx, instr->src[1]);
- assert(src0.type() == RegType::sgpr && src1.type() == RegType::sgpr);
+ Temp src0 = get_alu_src(ctx, instr->src[0]);
+ Temp src1 = get_alu_src(ctx, instr->src[1]);
- Builder bld(ctx->program, ctx->block);
- bld.sopc(op, bld.scc(Definition(dst)), src0, src1);
+ assert(dst.regClass() == s2);
+ assert(src0.type() == RegType::sgpr);
+ assert(src1.type() == RegType::sgpr);
- } else {
- assert(false);
- }
+ Builder bld(ctx->program, ctx->block);
+ /* Emit the SALU comparison instruction */
+ Temp cmp = bld.sopc(op, bld.scc(bld.def(s1)), src0, src1);
+ /* Turn the result into a per-lane bool */
+ bool_to_vector_condition(ctx, cmp, dst);
+}
+
+void emit_comparison(isel_context *ctx, nir_alu_instr *instr, Temp dst,
+ aco_opcode v32_op, aco_opcode v64_op, aco_opcode s32_op = aco_opcode::last_opcode, aco_opcode s64_op = aco_opcode::last_opcode)
+{
+ aco_opcode s_op = instr->src[0].src.ssa->bit_size == 64 ? s64_op : s32_op;
+ aco_opcode v_op = instr->src[0].src.ssa->bit_size == 64 ? v64_op : v32_op;
+ bool divergent_vals = ctx->divergent_vals[instr->dest.dest.ssa.index];
+ bool use_valu = s_op == aco_opcode::last_opcode ||
+ divergent_vals ||
+ ctx->allocated[instr->src[0].src.ssa->index].type() == RegType::vgpr ||
+ ctx->allocated[instr->src[1].src.ssa->index].type() == RegType::vgpr;
+ aco_opcode op = use_valu ? v_op : s_op;
+ assert(op != aco_opcode::last_opcode);
+
+ if (use_valu)
+ emit_vopc_instruction(ctx, instr, op, dst);
+ else
+ emit_sopc_instruction(ctx, instr, op, dst);
}
void emit_boolean_logic(isel_context *ctx, nir_alu_instr *instr, aco_opcode op32, aco_opcode op64, Temp dst)
@@ -554,16 +571,13 @@ void emit_boolean_logic(isel_context *ctx, nir_alu_instr *instr, aco_opcode op32
Builder bld(ctx->program, ctx->block);
Temp src0 = get_alu_src(ctx, instr->src[0]);
Temp src1 = get_alu_src(ctx, instr->src[1]);
- if (dst.regClass() == s2) {
- bld.sop2(op64, Definition(dst), bld.def(s1, scc),
- as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false));
- } else {
- assert(dst.regClass() == s1);
- bld.sop2(op32, bld.def(s1), bld.scc(Definition(dst)),
- as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1));
- }
-}
+ assert(dst.regClass() == s2);
+ assert(src0.regClass() == s2);
+ assert(src1.regClass() == s2);
+
+ bld.sop2(op64, Definition(dst), bld.def(s1, scc), src0, src1);
+}
void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
{
@@ -572,9 +586,9 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
Temp then = get_alu_src(ctx, instr->src[1]);
Temp els = get_alu_src(ctx, instr->src[2]);
- if (dst.type() == RegType::vgpr) {
- cond = as_divergent_bool(ctx, cond, true);
+ assert(cond.regClass() == s2);
+ if (dst.type() == RegType::vgpr) {
aco_ptr<Instruction> bcsel;
if (dst.size() == 1) {
then = as_vgpr(ctx, then);
@@ -599,11 +613,17 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
return;
}
- if (instr->dest.dest.ssa.bit_size != 1) { /* uniform condition and values in sgpr */
+ if (instr->dest.dest.ssa.bit_size == 1) {
+ assert(dst.regClass() == s2);
+ assert(then.regClass() == s2);
+ assert(els.regClass() == s2);
+ }
+
+ if (!ctx->divergent_vals[instr->src[0].src.ssa->index]) { /* uniform condition and values in sgpr */
if (dst.regClass() == s1 || dst.regClass() == s2) {
assert((then.regClass() == s1 || then.regClass() == s2) && els.regClass() == then.regClass());
aco_opcode op = dst.regClass() == s1 ? aco_opcode::s_cselect_b32 : aco_opcode::s_cselect_b64;
- bld.sop2(op, Definition(dst), then, els, bld.scc(as_uniform_bool(ctx, cond)));
+ bld.sop2(op, Definition(dst), then, els, bld.scc(bool_to_scalar_condition(ctx, cond)));
} else {
fprintf(stderr, "Unimplemented uniform bcsel bit size: ");
nir_print_instr(&instr->instr, stderr);
@@ -612,34 +632,10 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
return;
}
- /* boolean bcsel */
- assert(instr->dest.dest.ssa.bit_size == 1);
-
- if (dst.regClass() == s1)
- cond = as_uniform_bool(ctx, cond);
-
- if (cond.regClass() == s1) { /* uniform selection */
- aco_opcode op;
- if (dst.regClass() == s2) {
- op = aco_opcode::s_cselect_b64;
- then = as_divergent_bool(ctx, then, false);
- els = as_divergent_bool(ctx, els, false);
- } else {
- assert(dst.regClass() == s1);
- op = aco_opcode::s_cselect_b32;
- then = as_uniform_bool(ctx, then);
- els = as_uniform_bool(ctx, els);
- }
- bld.sop2(op, Definition(dst), then, els, bld.scc(cond));
- return;
- }
-
/* divergent boolean bcsel
* this implements bcsel on bools: dst = s0 ? s1 : s2
* are going to be: dst = (s0 & s1) | (~s0 & s2) */
- assert (dst.regClass() == s2);
- then = as_divergent_bool(ctx, then, false);
- els = as_divergent_bool(ctx, els, false);
+ assert(instr->dest.dest.ssa.bit_size == 1);
if (cond.id() != then.id())
then = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), cond, then);
@@ -700,16 +696,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
}
case nir_op_inot: {
Temp src = get_alu_src(ctx, instr->src[0]);
- /* uniform booleans */
- if (instr->dest.dest.ssa.bit_size == 1 && dst.regClass() == s1) {
- if (src.regClass() == s1) {
- /* in this case, src is either 1 or 0 */
- bld.sop2(aco_opcode::s_xor_b32, bld.def(s1), bld.scc(Definition(dst)), Operand(1u), src);
- } else {
- /* src is either exec_mask or 0 */
- assert(src.regClass() == s2);
- bld.sopc(aco_opcode::s_cmp_eq_u64, bld.scc(Definition(dst)), Operand(0u), src);
- }
+ if (instr->dest.dest.ssa.bit_size == 1) {
+ assert(src.regClass() == s2);
+ assert(dst.regClass() == s2);
+ bld.sop2(aco_opcode::s_andn2_b64, Definition(dst), bld.def(s1, scc), Operand(exec, s2), src);
} else if (dst.regClass() == v1) {
emit_vop1_instruction(ctx, instr, aco_opcode::v_not_b32, dst);
} else if (dst.type() == RegType::sgpr) {
@@ -1919,12 +1909,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
}
case nir_op_b2f32: {
Temp src = get_alu_src(ctx, instr->src[0]);
+ assert(src.regClass() == s2);
+
if (dst.regClass() == s1) {
- src = as_uniform_bool(ctx, src);
+ src = bool_to_scalar_condition(ctx, src);
bld.sop2(aco_opcode::s_mul_i32, Definition(dst), Operand(0x3f800000u), src);
} else if (dst.regClass() == v1) {
- bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(0x3f800000u),
- as_divergent_bool(ctx, src, true));
+ bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(0x3f800000u), src);
} else {
unreachable("Wrong destination register class for nir_op_b2f32.");
}
@@ -1932,13 +1923,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
}
case nir_op_b2f64: {
Temp src = get_alu_src(ctx, instr->src[0]);
+ assert(src.regClass() == s2);
+
if (dst.regClass() == s2) {
- src = as_uniform_bool(ctx, src);
+ src = bool_to_scalar_condition(ctx, src);
bld.sop2(aco_opcode::s_cselect_b64, Definition(dst), Operand(0x3f800000u), Operand(0u), bld.scc(src));
} else if (dst.regClass() == v2) {
Temp one = bld.vop1(aco_opcode::v_mov_b32, bld.def(v2), Operand(0x3FF00000u));
- Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one,
- as_divergent_bool(ctx, src, true));
+ Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one, src);
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), Operand(0u), upper);
} else {
unreachable("Wrong destination register class for nir_op_b2f64.");
@@ -2000,29 +1992,31 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
}
case nir_op_b2i32: {
Temp src = get_alu_src(ctx, instr->src[0]);
+ assert(src.regClass() == s2);
+
if (dst.regClass() == s1) {
- if (src.regClass() == s1) {
- bld.copy(Definition(dst), src);
- } else {
- // TODO: in a post-RA optimization, we can check if src is in VCC, and directly use VCCNZ
- assert(src.regClass() == s2);
- bld.sopc(aco_opcode::s_cmp_lg_u64, bld.scc(Definition(dst)), Operand(0u), src);
- }
- } else {
- assert(dst.regClass() == v1 && src.regClass() == s2);
+ // TODO: in a post-RA optimization, we can check if src is in VCC, and directly use VCCNZ
+ bool_to_scalar_condition(ctx, src, dst);
+ } else if (dst.regClass() == v1) {
bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(1u), src);
+ } else {
+ unreachable("Invalid register class for b2i32");
}
break;
}
case nir_op_i2b1: {
Temp src = get_alu_src(ctx, instr->src[0]);
- if (dst.regClass() == s2) {
+ assert(dst.regClass() == s2);
+
+ if (src.type() == RegType::vgpr) {
assert(src.regClass() == v1 || src.regClass() == v2);
bld.vopc(src.size() == 2 ? aco_opcode::v_cmp_lg_u64 : aco_opcode::v_cmp_lg_u32,
Definition(dst), Operand(0u), src).def(0).setHint(vcc);
} else {
- assert(src.regClass() == s1 && dst.regClass() == s1);
- bld.sopc(aco_opcode::s_cmp_lg_u32, bld.scc(Definition(dst)), Operand(0u), src);
+ assert(src.regClass() == s1 || src.regClass() == s2);
+ Temp tmp = bld.sopc(src.size() == 2 ? aco_opcode::s_cmp_lg_u64 : aco_opcode::s_cmp_lg_u32,
+ bld.scc(bld.def(s1)), Operand(0u), src);
+ bool_to_vector_condition(ctx, tmp, dst);
}
break;
}
@@ -2228,119 +2222,49 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}
case nir_op_flt: {
- if (instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_f32, dst);
- else if (instr->src[0].src.ssa->bit_size == 64)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_f64, dst);
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_f32, aco_opcode::v_cmp_lt_f64);
break;
}
case nir_op_fge: {
- if (instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_f32, dst);
- else if (instr->src[0].src.ssa->bit_size == 64)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_f64, dst);
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_f32, aco_opcode::v_cmp_ge_f64);
break;
}
case nir_op_feq: {
- if (instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_f32, dst);
- else if (instr->src[0].src.ssa->bit_size == 64)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_f64, dst);
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_f32, aco_opcode::v_cmp_eq_f64);
break;
}
case nir_op_fne: {
- if (instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_neq_f32, dst);
- else if (instr->src[0].src.ssa->bit_size == 64)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_neq_f64, dst);
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_neq_f32, aco_opcode::v_cmp_neq_f64);
break;
}
case nir_op_ilt: {
- if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_i32, dst);
- else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::s_cmp_lt_i32, dst);
- else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_i64, dst);
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_i32, aco_opcode::v_cmp_lt_i64, aco_opcode::s_cmp_lt_i32);
break;
}
case nir_op_ige: {
- if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_i32, dst);
- else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::s_cmp_ge_i32, dst);
- else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_i64, dst);
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_i32, aco_opcode::v_cmp_ge_i64, aco_opcode::s_cmp_ge_i32);
break;
}
case nir_op_ieq: {
- if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) {
- emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_i32, dst);
- } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) {
- emit_comparison(ctx, instr, aco_opcode::s_cmp_eq_i32, dst);
- } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) {
- emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_i64, dst);
- } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 64) {
- emit_comparison(ctx, instr, aco_opcode::s_cmp_eq_u64, dst);
- } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 1) {
- Temp src0 = get_alu_src(ctx, instr->src[0]);
- Temp src1 = get_alu_src(ctx, instr->src[1]);
- bld.sopc(aco_opcode::s_cmp_eq_i32, bld.scc(Definition(dst)),
- as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1));
- } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 1) {
- Temp src0 = get_alu_src(ctx, instr->src[0]);
- Temp src1 = get_alu_src(ctx, instr->src[1]);
- bld.sop2(aco_opcode::s_xnor_b64, Definition(dst), bld.def(s1, scc),
- as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false));
- } else {
- fprintf(stderr, "Unimplemented NIR instr bit size: ");
- nir_print_instr(&instr->instr, stderr);
- fprintf(stderr, "\n");
- }
+ if (instr->src[0].src.ssa->bit_size == 1)
+ emit_boolean_logic(ctx, instr, aco_opcode::s_xnor_b32, aco_opcode::s_xnor_b64, dst);
+ else
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_i32, aco_opcode::v_cmp_eq_i64, aco_opcode::s_cmp_eq_i32, aco_opcode::s_cmp_eq_u64);
break;
}
case nir_op_ine: {
- if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) {
- emit_comparison(ctx, instr, aco_opcode::v_cmp_lg_i32, dst);
- } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) {
- emit_comparison(ctx, instr, aco_opcode::v_cmp_lg_i64, dst);
- } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) {
- emit_comparison(ctx, instr, aco_opcode::s_cmp_lg_i32, dst);
- } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 64) {
- emit_comparison(ctx, instr, aco_opcode::s_cmp_lg_u64, dst);
- } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 1) {
- Temp src0 = get_alu_src(ctx, instr->src[0]);
- Temp src1 = get_alu_src(ctx, instr->src[1]);
- bld.sopc(aco_opcode::s_cmp_lg_i32, bld.scc(Definition(dst)),
- as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1));
- } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 1) {
- Temp src0 = get_alu_src(ctx, instr->src[0]);
- Temp src1 = get_alu_src(ctx, instr->src[1]);
- bld.sop2(aco_opcode::s_xor_b64, Definition(dst), bld.def(s1, scc),
- as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false));
- } else {
- fprintf(stderr, "Unimplemented NIR instr bit size: ");
- nir_print_instr(&instr->instr, stderr);
- fprintf(stderr, "\n");
- }
+ if (instr->src[0].src.ssa->bit_size == 1)
+ emit_boolean_logic(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::s_xor_b64, dst);
+ else
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lg_i32, aco_opcode::v_cmp_lg_i64, aco_opcode::s_cmp_lg_i32, aco_opcode::s_cmp_lg_u64);
break;
}
case nir_op_ult: {
- if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_u32, dst);
- else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::s_cmp_lt_u32, dst);
- else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_u64, dst);
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_u32, aco_opcode::v_cmp_lt_u64, aco_opcode::s_cmp_lt_u32);
break;
}
case nir_op_uge: {
- if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_u32, dst);
- else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32)
- emit_comparison(ctx, instr, aco_opcode::s_cmp_ge_u32, dst);
- else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64)
- emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_u64, dst);
+ emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_u32, aco_opcode::v_cmp_ge_u64, aco_opcode::s_cmp_ge_u32);
break;
}
case nir_op_fddx:
@@ -2387,9 +2311,13 @@ void visit_load_const(isel_context *ctx, nir_load_const_instr *instr)
assert(instr->def.num_components == 1 && "Vector load_const should be lowered to scalar.");
assert(dst.type() == RegType::sgpr);
- if (dst.size() == 1)
- {
- Builder(ctx->program, ctx->block).copy(Definition(dst), Operand(instr->value[0].u32));
+ Builder bld(ctx->program, ctx->block);
+
+ if (instr->def.bit_size == 1) {
+ assert(dst.regClass() == s2);
+ bld.sop1(aco_opcode::s_mov_b64, Definition(dst), Operand((uint64_t)(instr->value[0].b ? -1 : 0)));
+ } else if (dst.size() == 1) {
+ bld.copy(Definition(dst), Operand(instr->value[0].u32));
} else {
assert(dst.size() != 1);
aco_ptr<Pseudo_instruction> vec{create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector, Format::PSEUDO, dst.size(), 1)};
@@ -3577,7 +3505,8 @@ void visit_discard_if(isel_context *ctx, nir_intrinsic_instr *instr)
// TODO: optimize uniform conditions
Builder bld(ctx->program, ctx->block);
- Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false);
+ Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
+ assert(src.regClass() == s2);
src = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2));
bld.pseudo(aco_opcode::p_discard_if, src);
ctx->block->kind |= block_kind_uses_discard_if;
@@ -5114,15 +5043,17 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te
} else if (op == nir_op_iand && cluster_size == 64) {
//subgroupAnd(val) -> (exec & ~val) == 0
Temp tmp = bld.sop2(aco_opcode::s_andn2_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp();
- return bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), tmp, Operand(0u));
+ return bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(0u), Operand(-1u), bld.scc(tmp));
} else if (op == nir_op_ior && cluster_size == 64) {
//subgroupOr(val) -> (val & exec) != 0
- return bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)).def(1).getTemp();
+ Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)).def(1).getTemp();
+ return bool_to_vector_condition(ctx, tmp);
} else if (op == nir_op_ixor && cluster_size == 64) {
//subgroupXor(val) -> s_bcnt1_i32_b64(val & exec) & 1
Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2));
tmp = bld.sop1(aco_opcode::s_bcnt1_i32_b64, bld.def(s2), bld.def(s1, scc), tmp);
- return bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), tmp, Operand(1u)).def(1).getTemp();
+ tmp = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), tmp, Operand(1u)).def(1).getTemp();
+ return bool_to_vector_condition(ctx, tmp);
} else {
//subgroupClustered{And,Or,Xor}(val, n) ->
//lane_id = v_mbcnt_hi_u32_b32(-1, v_mbcnt_lo_u32_b32(-1, 0))
@@ -5221,8 +5152,6 @@ void emit_uniform_subgroup(isel_context *ctx, nir_intrinsic_instr *instr, Temp s
Definition dst(get_ssa_temp(ctx, &instr->dest.ssa));
if (src.regClass().type() == RegType::vgpr) {
bld.pseudo(aco_opcode::p_as_uniform, dst, src);
- } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
- bld.sopc(aco_opcode::s_cmp_lg_u64, bld.scc(dst), Operand(0u), Operand(src));
} else if (src.regClass() == s1) {
bld.sop1(aco_opcode::s_mov_b32, dst, src);
} else if (src.regClass() == s2) {
@@ -5541,10 +5470,9 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
case nir_intrinsic_ballot: {
Definition tmp = bld.def(s2);
Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
- if (instr->src[0].ssa->bit_size == 1 && src.regClass() == s2) {
+ if (instr->src[0].ssa->bit_size == 1) {
+ assert(src.regClass() == s2);
bld.sop2(aco_opcode::s_and_b64, tmp, bld.def(s1, scc), Operand(exec, s2), src);
- } else if (instr->src[0].ssa->bit_size == 1 && src.regClass() == s1) {
- bld.sop2(aco_opcode::s_cselect_b64, tmp, Operand(exec, s2), Operand(0u), bld.scc(src));
} else if (instr->src[0].ssa->bit_size == 32 && src.regClass() == v1) {
bld.vopc(aco_opcode::v_cmp_lg_u32, tmp, Operand(0u), src);
} else if (instr->src[0].ssa->bit_size == 64 && src.regClass() == v2) {
@@ -5576,9 +5504,12 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
hi = emit_wqm(ctx, emit_bpermute(ctx, bld, tid, hi));
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
emit_split_vector(ctx, dst, 2);
- } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2 && tid.regClass() == s1) {
- emit_wqm(ctx, bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, tid), dst);
- } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
+ } else if (instr->dest.ssa.bit_size == 1 && tid.regClass() == s1) {
+ assert(src.regClass() == s2);
+ Temp tmp = bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, tid);
+ bool_to_vector_condition(ctx, emit_wqm(ctx, tmp), dst);
+ } else if (instr->dest.ssa.bit_size == 1 && tid.regClass() == v1) {
+ assert(src.regClass() == s2);
Temp tmp = bld.vop3(aco_opcode::v_lshrrev_b64, bld.def(v2), tid, src);
tmp = emit_extract_vector(ctx, tmp, 0, v1);
tmp = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(1u), tmp);
@@ -5614,11 +5545,11 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
hi = emit_wqm(ctx, bld.vop1(aco_opcode::v_readfirstlane_b32, bld.def(s1), hi));
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
emit_split_vector(ctx, dst, 2);
- } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
- emit_wqm(ctx,
- bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src,
- bld.sop1(aco_opcode::s_ff1_i32_b64, bld.def(s1), Operand(exec, s2))),
- dst);
+ } else if (instr->dest.ssa.bit_size == 1) {
+ assert(src.regClass() == s2);
+ Temp tmp = bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src,
+ bld.sop1(aco_opcode::s_ff1_i32_b64, bld.def(s1), Operand(exec, s2)));
+ bool_to_vector_condition(ctx, emit_wqm(ctx, tmp), dst);
} else if (src.regClass() == s1) {
bld.sop1(aco_opcode::s_mov_b32, Definition(dst), src);
} else if (src.regClass() == s2) {
@@ -5631,27 +5562,25 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
break;
}
case nir_intrinsic_vote_all: {
- Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false);
+ Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
assert(src.regClass() == s2);
- assert(dst.regClass() == s1);
+ assert(dst.regClass() == s2);
- Definition tmp = bld.def(s1);
- bld.sopc(aco_opcode::s_cmp_eq_u64, bld.scc(tmp),
- bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)),
- Operand(exec, s2));
- emit_wqm(ctx, tmp.getTemp(), dst);
+ Temp tmp = bld.sop2(aco_opcode::s_andn2_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp();
+ Temp val = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(0u), Operand(-1u), bld.scc(tmp));
+ emit_wqm(ctx, val, dst);
break;
}
case nir_intrinsic_vote_any: {
- Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false);
+ Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
assert(src.regClass() == s2);
- assert(dst.regClass() == s1);
+ assert(dst.regClass() == s2);
- Definition tmp = bld.def(s1);
- bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(tmp), src, Operand(exec, s2));
- emit_wqm(ctx, tmp.getTemp(), dst);
+ Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp();
+ Temp val = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(-1u), Operand(0u), bld.scc(tmp));
+ emit_wqm(ctx, val, dst);
break;
}
case nir_intrinsic_reduce:
@@ -5752,7 +5681,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
} else {
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
unsigned lane = nir_src_as_const_value(instr->src[1])->u32;
- if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
+ if (instr->dest.ssa.bit_size == 1) {
+ assert(src.regClass() == s2);
uint32_t half_mask = 0x11111111u << lane;
Temp mask_tmp = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2), Operand(half_mask), Operand(half_mask));
Temp tmp = bld.tmp(s2);
@@ -5809,7 +5739,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
}
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
- if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
+ if (instr->dest.ssa.bit_size == 1) {
+ assert(src.regClass() == s2);
src = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), Operand((uint32_t)-1), src);
src = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl);
Temp tmp = bld.vopc(aco_opcode::v_cmp_lg_u32, bld.def(s2), Operand(0u), src);
@@ -5912,9 +5843,9 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
ctx->program->needs_exact = true;
break;
case nir_intrinsic_demote_if: {
- Temp cond = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc),
- as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false),
- Operand(exec, s2));
+ Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
+ assert(src.regClass() == s2);
+ Temp cond = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2));
bld.pseudo(aco_opcode::p_demote_to_helper, cond);
ctx->block->kind |= block_kind_uses_demote;
ctx->program->needs_exact = true;
@@ -6520,7 +6451,9 @@ void visit_tex(isel_context *ctx, nir_tex_instr *instr)
Operand((uint32_t)V_008F14_IMG_NUM_FORMAT_SINT),
bld.scc(compare_cube_wa));
}
- tg4_compare_cube_wa64 = as_divergent_bool(ctx, compare_cube_wa, true);
+ tg4_compare_cube_wa64 = bld.tmp(s2);
+ bool_to_vector_condition(ctx, compare_cube_wa, tg4_compare_cube_wa64);
+
nfmt = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), nfmt, Operand(26u));
desc[1] = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), desc[1],
@@ -6770,6 +6703,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr)
aco_ptr<Pseudo_instruction> phi;
unsigned num_src = exec_list_length(&instr->srcs);
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
+ assert(instr->dest.ssa.bit_size != 1 || dst.regClass() == s2);
aco_opcode opcode = !dst.is_linear() || ctx->divergent_vals[instr->dest.ssa.index] ? aco_opcode::p_phi : aco_opcode::p_linear_phi;
@@ -6797,7 +6731,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr)
}
/* try to scalarize vector phis */
- if (dst.size() > 1) {
+ if (instr->dest.ssa.bit_size != 1 && dst.size() > 1) {
// TODO: scalarize linear phis on divergent ifs
bool can_scalarize = (opcode == aco_opcode::p_phi || !(ctx->block->kind & block_kind_merge));
std::array<Temp, 4> new_vec;
@@ -7265,10 +7199,10 @@ static void visit_if(isel_context *ctx, nir_if *if_stmt)
ctx->block->kind |= block_kind_uniform;
/* emit branch */
- if (cond.regClass() == s2) {
- // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction
- cond = as_uniform_bool(ctx, cond);
- }
+ assert(cond.regClass() == s2);
+ // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction
+ cond = bool_to_scalar_condition(ctx, cond);
+
branch.reset(create_instruction<Pseudo_branch_instruction>(aco_opcode::p_cbranch_z, Format::PSEUDO_BRANCH, 1, 0));
branch->operands[0] = Operand(cond);
branch->operands[0].setFixed(scc);
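
Illustrative note (not part of the patch): with this change every 1-bit NIR value is selected into an s2 lane mask, so boolean ALU operations become plain 64-bit bitwise operations on wavefront masks. The standalone sketch below (host-side C++, using a hypothetical 64-lane exec mask and lane_mask type instead of real ACO IR and Builder calls) mirrors the bit-level semantics that the new emit_bcsel, nir_op_inot and emit_boolean_reduce paths rely on.

#include <cassert>
#include <cstdint>

// Hypothetical stand-ins: a wave64 exec mask and per-lane boolean masks.
using lane_mask = uint64_t;

// bcsel on booleans, as emitted after this patch:
// dst = (cond & then) | (~cond & els)
lane_mask bcsel_bool(lane_mask cond, lane_mask then, lane_mask els) {
    return (cond & then) | (~cond & els);
}

// nir_op_inot on booleans: s_andn2_b64 dst, exec, src
// (only lanes that are active and false become true).
lane_mask inot_bool(lane_mask exec, lane_mask src) {
    return exec & ~src;
}

// subgroupAnd(val): (exec & ~val) == 0, widened back to a full lane mask
// the way s_cselect_b64 does it.
lane_mask vote_all(lane_mask exec, lane_mask val) {
    return (exec & ~val) == 0 ? ~0ull : 0ull;
}

// subgroupOr(val) / vote_any: (val & exec) != 0.
lane_mask vote_any(lane_mask exec, lane_mask val) {
    return (val & exec) != 0 ? ~0ull : 0ull;
}

int main() {
    const lane_mask exec = ~0ull;                    // all 64 lanes active
    const lane_mask even = 0x5555555555555555ull;    // true in even lanes only

    assert(bcsel_bool(even, exec, 0) == even);
    assert(inot_bool(exec, even) == ~even);
    assert(vote_all(exec, even) == 0);               // not all lanes true
    assert(vote_any(exec, even) == ~0ull);           // at least one lane true
    return 0;
}
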
diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp
index 2c349635799..cdc8103497b 100644
--- a/src/amd/compiler/aco_instruction_selection_setup.cpp
+++ b/src/amd/compiler/aco_instruction_selection_setup.cpp
@@ -244,25 +244,14 @@ void init_context(isel_context *ctx, nir_shader *shader)
case nir_op_fge:
case nir_op_feq:
case nir_op_fne:
- size = 2;
- break;
case nir_op_ilt:
case nir_op_ige:
case nir_op_ult:
case nir_op_uge:
- size = alu_instr->src[0].src.ssa->bit_size == 64 ? 2 : 1;
- /* fallthrough */
case nir_op_ieq:
case nir_op_ine:
case nir_op_i2b1:
- if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) {
- size = 2;
- } else {
- for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {
- if (allocated[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr)
- size = 2;
- }
- }
+ size = 2;
break;
case nir_op_f2i64:
case nir_op_f2u64:
@@ -274,13 +263,7 @@ void init_context(isel_context *ctx, nir_shader *shader)
break;
case nir_op_bcsel:
if (alu_instr->dest.dest.ssa.bit_size == 1) {
- if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index])
- size = 2;
- else if (allocated[alu_instr->src[1].src.ssa->index].regClass() == s2 &&
- allocated[alu_instr->src[2].src.ssa->index].regClass() == s2)
- size = 2;
- else
- size = 1;
+ size = 2;
} else {
if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) {
type = RegType::vgpr;
@@ -298,32 +281,14 @@ void init_context(isel_context *ctx, nir_shader *shader)
break;
case nir_op_mov:
if (alu_instr->dest.dest.ssa.bit_size == 1) {
- size = allocated[alu_instr->src[0].src.ssa->index].size();
+ size = 2;
} else {
type = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? RegType::vgpr : RegType::sgpr;
}
break;
- case nir_op_inot:
- case nir_op_ixor:
- if (alu_instr->dest.dest.ssa.bit_size == 1) {
- size = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? 2 : 1;
- break;
- } else {
- /* fallthrough */
- }
default:
if (alu_instr->dest.dest.ssa.bit_size == 1) {
- if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) {
- size = 2;
- } else {
- size = 2;
- for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {
- if (allocated[alu_instr->src[i].src.ssa->index].regClass() == s1) {
- size = 1;
- break;
- }
- }
- }
+ size = 2;
} else {
for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {
if (allocated[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr)
@@ -339,6 +304,8 @@ void init_context(isel_context *ctx, nir_shader *shader)
unsigned size = nir_instr_as_load_const(instr)->def.num_components;
if (nir_instr_as_load_const(instr)->def.bit_size == 64)
size *= 2;
+ else if (nir_instr_as_load_const(instr)->def.bit_size == 1)
+ size *= 2;
allocated[nir_instr_as_load_const(instr)->def.index] = Temp(0, RegClass(RegType::sgpr, size));
break;
}
@@ -365,6 +332,8 @@ void init_context(isel_context *ctx, nir_shader *shader)
case nir_intrinsic_read_invocation:
case nir_intrinsic_first_invocation:
type = RegType::sgpr;
+ if (intrinsic->dest.ssa.bit_size == 1)
+ size = 2;
break;
case nir_intrinsic_ballot:
type = RegType::sgpr;
@@ -433,11 +402,11 @@ void init_context(isel_context *ctx, nir_shader *shader)
case nir_intrinsic_masked_swizzle_amd:
case nir_intrinsic_inclusive_scan:
case nir_intrinsic_exclusive_scan:
- if (!ctx->divergent_vals[intrinsic->dest.ssa.index]) {
+ if (intrinsic->dest.ssa.bit_size == 1) {
+ size = 2;
type = RegType::sgpr;
- } else if (intrinsic->src[0].ssa->bit_size == 1) {
+ } else if (!ctx->divergent_vals[intrinsic->dest.ssa.index]) {
type = RegType::sgpr;
- size = 2;
} else {
type = RegType::vgpr;
}
@@ -452,12 +421,12 @@ void init_context(isel_context *ctx, nir_shader *shader)
size = 2;
break;
case nir_intrinsic_reduce:
- if (nir_intrinsic_cluster_size(intrinsic) == 0 ||
- !ctx->divergent_vals[intrinsic->dest.ssa.index]) {
+ if (intrinsic->dest.ssa.bit_size == 1) {
+ size = 2;
type = RegType::sgpr;
- } else if (intrinsic->src[0].ssa->bit_size == 1) {
+ } else if (nir_intrinsic_cluster_size(intrinsic) == 0 ||
+ !ctx->divergent_vals[intrinsic->dest.ssa.index]) {
type = RegType::sgpr;
- size = 2;
} else {
type = RegType::vgpr;
}
@@ -554,7 +523,7 @@ void init_context(isel_context *ctx, nir_shader *shader)
if (phi->dest.ssa.bit_size == 1) {
assert(size == 1 && "multiple components not yet supported on boolean phis.");
type = RegType::sgpr;
- size *= ctx->divergent_vals[phi->dest.ssa.index] ? 2 : 1;
+ size *= 2;
allocated[phi->dest.ssa.index] = Temp(0, RegClass(type, size));
break;
}
diff --git a/src/amd/compiler/aco_lower_bool_phis.cpp b/src/amd/compiler/aco_lower_bool_phis.cpp
index ac4663a2ce1..9e5374fe6a0 100644
--- a/src/amd/compiler/aco_lower_bool_phis.cpp
+++ b/src/amd/compiler/aco_lower_bool_phis.cpp
@@ -150,13 +150,6 @@ void lower_divergent_bool_phi(Program *program, Block *block, aco_ptr<Instructio
assert(phi->operands[i].isTemp());
Temp phi_src = phi->operands[i].getTemp();
- if (phi_src.regClass() == s1) {
- Temp new_phi_src = bld.tmp(s2);
- insert_before_logical_end(pred,
- bld.sop2(aco_opcode::s_cselect_b64, Definition(new_phi_src),
- Operand((uint32_t)-1), Operand(0u), bld.scc(phi_src)).get_ptr());
- phi_src = new_phi_src;
- }
assert(phi_src.regClass() == s2);
Operand cur = get_ssa(program, pred->index, &state);
@@ -218,6 +211,7 @@ void lower_bool_phis(Program* program)
for (Block& block : program->blocks) {
for (aco_ptr<Instruction>& phi : block.instructions) {
if (phi->opcode == aco_opcode::p_phi) {
+ assert(phi->definitions[0].regClass() != s1);
if (phi->definitions[0].regClass() == s2)
lower_divergent_bool_phi(program, &block, phi);
} else if (phi->opcode == aco_opcode::p_linear_phi) {
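
Illustrative note (not part of the patch): the two new helpers replace as_uniform_bool/as_divergent_bool with explicit, one-directional conversions between the scalar SCC-style boolean (0 or 1) and the per-lane s2 mask, which is why the s1 fix-up removed above is no longer needed in lower_divergent_bool_phi. A rough host-side C++ model of those conversions follows, with a hypothetical exec value standing in for hardware state; the real bool_to_scalar_condition additionally routes its result through emit_wqm in fragment shaders.

#include <cassert>
#include <cstdint>

using lane_mask = uint64_t;

// bool_to_vector_condition: s_cselect_b64 dst, -1, 0, scc
// turns a scalar 0/1 into an all-lanes mask.
lane_mask to_vector_condition(bool scc) {
    return scc ? ~0ull : 0ull;
}

// bool_to_scalar_condition: s_and_b64 with exec, keeping only SCC.
// For a uniform boolean (same value in every active lane) this
// recovers that value as 0/1.
bool to_scalar_condition(lane_mask exec, lane_mask val) {
    return (val & exec) != 0;
}

int main() {
    const lane_mask exec = 0x00000000ffffffffull;    // 32 active lanes

    assert(to_vector_condition(true) == ~0ull);
    assert(to_vector_condition(false) == 0);
    // Bits set only in inactive lanes read back as false.
    assert(!to_scalar_condition(exec, 0xffffffff00000000ull));
    assert(to_scalar_condition(exec, 1));
    return 0;
}
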