-rw-r--r-- | src/amd/compiler/aco_instruction_selection.cpp       | 400
-rw-r--r-- | src/amd/compiler/aco_instruction_selection_setup.cpp |  63
-rw-r--r-- | src/amd/compiler/aco_lower_bool_phis.cpp             |   8
3 files changed, 184 insertions, 287 deletions
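The patch below moves 1-bit NIR booleans in the ACO instruction selector to a single representation: an s2 lane mask with one bit per invocation, replacing the earlier mix of s1 (SCC-style) and s2 booleans. The helpers as_divergent_bool/as_uniform_bool give way to bool_to_vector_condition and bool_to_scalar_condition, which convert between a scalar condition and a lane mask. A minimal scalar model of what those two conversions compute follows, assuming a 64-lane wave; it is an illustrative sketch, not the Builder code in the hunks below, which emits s_cselect_b64 and s_and_b64 (plus, for the scalar direction, an emit_wqm fix-up that the model omits).

#include <cstdint>
#include <cassert>

// bool_to_vector_condition: widen a scalar condition (SCC, 0 or 1) into a
// per-lane mask -- all 64 bits set when the condition is true.
// Models the emitted s_cselect_b64 -1, 0, scc.
std::uint64_t bool_to_vector_condition(bool scc)
{
    return scc ? ~0ull : 0ull;
}

// bool_to_scalar_condition: collapse a per-lane mask into a scalar condition
// by testing it against the active lanes. Models s_and_b64 val, exec -> SCC;
// the WQM handling done via emit_wqm in fragment shaders is left out.
bool bool_to_scalar_condition(std::uint64_t lane_mask, std::uint64_t exec)
{
    return (lane_mask & exec) != 0;
}

int main()
{
    std::uint64_t exec = 0xffffull;                    // 16 active lanes
    assert(bool_to_vector_condition(true) == ~0ull);
    assert(bool_to_vector_condition(false) == 0);
    assert(bool_to_scalar_condition(0x4ull, exec));    // lane 2 true and active
    assert(!bool_to_scalar_condition(0x0ull, exec));
    return 0;
}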
diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index ab34a068671..a7c3c703403 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -130,6 +130,8 @@ Temp emit_wqm(isel_context *ctx, Temp src, Temp dst=Temp(0, s1), bool program_ne if (!dst.id()) dst = bld.tmp(src.regClass()); + assert(src.size() == dst.size()); + if (ctx->stage != fragment_fs) { if (!dst.id()) return src; @@ -331,33 +333,31 @@ void expand_vector(isel_context* ctx, Temp vec_src, Temp dst, unsigned num_compo ctx->allocated_vec.emplace(dst.id(), elems); } -Temp as_divergent_bool(isel_context *ctx, Temp val, bool vcc_hint) +Temp bool_to_vector_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s2)) { - if (val.regClass() == s2) { - return val; - } else { - assert(val.regClass() == s1); - Builder bld(ctx->program, ctx->block); - Definition& def = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), - Operand((uint32_t) -1), Operand(0u), bld.scc(val)).def(0); - if (vcc_hint) - def.setHint(vcc); - return def.getTemp(); - } + Builder bld(ctx->program, ctx->block); + if (!dst.id()) + dst = bld.tmp(s2); + + assert(val.regClass() == s1); + assert(dst.regClass() == s2); + + return bld.sop2(aco_opcode::s_cselect_b64, bld.hint_vcc(Definition(dst)), Operand((uint32_t) -1), Operand(0u), bld.scc(val)); } -Temp as_uniform_bool(isel_context *ctx, Temp val) +Temp bool_to_scalar_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s1)) { - if (val.regClass() == s1) { - return val; - } else { - assert(val.regClass() == s2); - Builder bld(ctx->program, ctx->block); - /* if we're currently in WQM mode, ensure that the source is also computed in WQM */ - Temp tmp = bld.tmp(s1); - bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(Definition(tmp)), val, Operand(exec, s2)).def(1).getTemp(); - return emit_wqm(ctx, tmp); - } + Builder bld(ctx->program, ctx->block); + if (!dst.id()) + dst = bld.tmp(s1); + + assert(val.regClass() == s2); + assert(dst.regClass() == s1); + + /* if we're currently in WQM mode, ensure that the source is also computed in WQM */ + Temp tmp = bld.tmp(s1); + bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(Definition(tmp)), val, Operand(exec, s2)); + return emit_wqm(ctx, tmp, dst); } Temp get_alu_src(struct isel_context *ctx, nir_alu_src src, unsigned size=1) @@ -526,27 +526,44 @@ void emit_vopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o src1 = as_vgpr(ctx, src1); } } + Builder bld(ctx->program, ctx->block); - bld.vopc(op, Definition(dst), src0, src1).def(0).setHint(vcc); + bld.vopc(op, bld.hint_vcc(Definition(dst)), src0, src1); } -void emit_comparison(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst) +void emit_sopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst) { - if (dst.regClass() == s2) { - emit_vopc_instruction(ctx, instr, op, dst); - if (!ctx->divergent_vals[instr->dest.dest.ssa.index]) - emit_split_vector(ctx, dst, 2); - } else if (dst.regClass() == s1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - assert(src0.type() == RegType::sgpr && src1.type() == RegType::sgpr); + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = get_alu_src(ctx, instr->src[1]); - Builder bld(ctx->program, ctx->block); - bld.sopc(op, bld.scc(Definition(dst)), src0, src1); + assert(dst.regClass() == s2); + assert(src0.type() == RegType::sgpr); + assert(src1.type() == RegType::sgpr); - 
} else { - assert(false); - } + Builder bld(ctx->program, ctx->block); + /* Emit the SALU comparison instruction */ + Temp cmp = bld.sopc(op, bld.scc(bld.def(s1)), src0, src1); + /* Turn the result into a per-lane bool */ + bool_to_vector_condition(ctx, cmp, dst); +} + +void emit_comparison(isel_context *ctx, nir_alu_instr *instr, Temp dst, + aco_opcode v32_op, aco_opcode v64_op, aco_opcode s32_op = aco_opcode::last_opcode, aco_opcode s64_op = aco_opcode::last_opcode) +{ + aco_opcode s_op = instr->src[0].src.ssa->bit_size == 64 ? s64_op : s32_op; + aco_opcode v_op = instr->src[0].src.ssa->bit_size == 64 ? v64_op : v32_op; + bool divergent_vals = ctx->divergent_vals[instr->dest.dest.ssa.index]; + bool use_valu = s_op == aco_opcode::last_opcode || + divergent_vals || + ctx->allocated[instr->src[0].src.ssa->index].type() == RegType::vgpr || + ctx->allocated[instr->src[1].src.ssa->index].type() == RegType::vgpr; + aco_opcode op = use_valu ? v_op : s_op; + assert(op != aco_opcode::last_opcode); + + if (use_valu) + emit_vopc_instruction(ctx, instr, op, dst); + else + emit_sopc_instruction(ctx, instr, op, dst); } void emit_boolean_logic(isel_context *ctx, nir_alu_instr *instr, aco_opcode op32, aco_opcode op64, Temp dst) @@ -554,16 +571,13 @@ void emit_boolean_logic(isel_context *ctx, nir_alu_instr *instr, aco_opcode op32 Builder bld(ctx->program, ctx->block); Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); - if (dst.regClass() == s2) { - bld.sop2(op64, Definition(dst), bld.def(s1, scc), - as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false)); - } else { - assert(dst.regClass() == s1); - bld.sop2(op32, bld.def(s1), bld.scc(Definition(dst)), - as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1)); - } -} + assert(dst.regClass() == s2); + assert(src0.regClass() == s2); + assert(src1.regClass() == s2); + + bld.sop2(op64, Definition(dst), bld.def(s1, scc), src0, src1); +} void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) { @@ -572,9 +586,9 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) Temp then = get_alu_src(ctx, instr->src[1]); Temp els = get_alu_src(ctx, instr->src[2]); - if (dst.type() == RegType::vgpr) { - cond = as_divergent_bool(ctx, cond, true); + assert(cond.regClass() == s2); + if (dst.type() == RegType::vgpr) { aco_ptr<Instruction> bcsel; if (dst.size() == 1) { then = as_vgpr(ctx, then); @@ -599,11 +613,17 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) return; } - if (instr->dest.dest.ssa.bit_size != 1) { /* uniform condition and values in sgpr */ + if (instr->dest.dest.ssa.bit_size == 1) { + assert(dst.regClass() == s2); + assert(then.regClass() == s2); + assert(els.regClass() == s2); + } + + if (!ctx->divergent_vals[instr->src[0].src.ssa->index]) { /* uniform condition and values in sgpr */ if (dst.regClass() == s1 || dst.regClass() == s2) { assert((then.regClass() == s1 || then.regClass() == s2) && els.regClass() == then.regClass()); aco_opcode op = dst.regClass() == s1 ? 
aco_opcode::s_cselect_b32 : aco_opcode::s_cselect_b64; - bld.sop2(op, Definition(dst), then, els, bld.scc(as_uniform_bool(ctx, cond))); + bld.sop2(op, Definition(dst), then, els, bld.scc(bool_to_scalar_condition(ctx, cond))); } else { fprintf(stderr, "Unimplemented uniform bcsel bit size: "); nir_print_instr(&instr->instr, stderr); @@ -612,34 +632,10 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) return; } - /* boolean bcsel */ - assert(instr->dest.dest.ssa.bit_size == 1); - - if (dst.regClass() == s1) - cond = as_uniform_bool(ctx, cond); - - if (cond.regClass() == s1) { /* uniform selection */ - aco_opcode op; - if (dst.regClass() == s2) { - op = aco_opcode::s_cselect_b64; - then = as_divergent_bool(ctx, then, false); - els = as_divergent_bool(ctx, els, false); - } else { - assert(dst.regClass() == s1); - op = aco_opcode::s_cselect_b32; - then = as_uniform_bool(ctx, then); - els = as_uniform_bool(ctx, els); - } - bld.sop2(op, Definition(dst), then, els, bld.scc(cond)); - return; - } - /* divergent boolean bcsel * this implements bcsel on bools: dst = s0 ? s1 : s2 * are going to be: dst = (s0 & s1) | (~s0 & s2) */ - assert (dst.regClass() == s2); - then = as_divergent_bool(ctx, then, false); - els = as_divergent_bool(ctx, els, false); + assert(instr->dest.dest.ssa.bit_size == 1); if (cond.id() != then.id()) then = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), cond, then); @@ -700,16 +696,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_inot: { Temp src = get_alu_src(ctx, instr->src[0]); - /* uniform booleans */ - if (instr->dest.dest.ssa.bit_size == 1 && dst.regClass() == s1) { - if (src.regClass() == s1) { - /* in this case, src is either 1 or 0 */ - bld.sop2(aco_opcode::s_xor_b32, bld.def(s1), bld.scc(Definition(dst)), Operand(1u), src); - } else { - /* src is either exec_mask or 0 */ - assert(src.regClass() == s2); - bld.sopc(aco_opcode::s_cmp_eq_u64, bld.scc(Definition(dst)), Operand(0u), src); - } + if (instr->dest.dest.ssa.bit_size == 1) { + assert(src.regClass() == s2); + assert(dst.regClass() == s2); + bld.sop2(aco_opcode::s_andn2_b64, Definition(dst), bld.def(s1, scc), Operand(exec, s2), src); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_not_b32, dst); } else if (dst.type() == RegType::sgpr) { @@ -1919,12 +1909,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_b2f32: { Temp src = get_alu_src(ctx, instr->src[0]); + assert(src.regClass() == s2); + if (dst.regClass() == s1) { - src = as_uniform_bool(ctx, src); + src = bool_to_scalar_condition(ctx, src); bld.sop2(aco_opcode::s_mul_i32, Definition(dst), Operand(0x3f800000u), src); } else if (dst.regClass() == v1) { - bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(0x3f800000u), - as_divergent_bool(ctx, src, true)); + bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(0x3f800000u), src); } else { unreachable("Wrong destination register class for nir_op_b2f32."); } @@ -1932,13 +1923,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_b2f64: { Temp src = get_alu_src(ctx, instr->src[0]); + assert(src.regClass() == s2); + if (dst.regClass() == s2) { - src = as_uniform_bool(ctx, src); + src = bool_to_scalar_condition(ctx, src); bld.sop2(aco_opcode::s_cselect_b64, Definition(dst), Operand(0x3f800000u), Operand(0u), bld.scc(src)); } else if (dst.regClass() == v2) { Temp one = bld.vop1(aco_opcode::v_mov_b32, 
bld.def(v2), Operand(0x3FF00000u)); - Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one, - as_divergent_bool(ctx, src, true)); + Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one, src); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), Operand(0u), upper); } else { unreachable("Wrong destination register class for nir_op_b2f64."); @@ -2000,29 +1992,31 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_b2i32: { Temp src = get_alu_src(ctx, instr->src[0]); + assert(src.regClass() == s2); + if (dst.regClass() == s1) { - if (src.regClass() == s1) { - bld.copy(Definition(dst), src); - } else { - // TODO: in a post-RA optimization, we can check if src is in VCC, and directly use VCCNZ - assert(src.regClass() == s2); - bld.sopc(aco_opcode::s_cmp_lg_u64, bld.scc(Definition(dst)), Operand(0u), src); - } - } else { - assert(dst.regClass() == v1 && src.regClass() == s2); + // TODO: in a post-RA optimization, we can check if src is in VCC, and directly use VCCNZ + bool_to_scalar_condition(ctx, src, dst); + } else if (dst.regClass() == v1) { bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(1u), src); + } else { + unreachable("Invalid register class for b2i32"); } break; } case nir_op_i2b1: { Temp src = get_alu_src(ctx, instr->src[0]); - if (dst.regClass() == s2) { + assert(dst.regClass() == s2); + + if (src.type() == RegType::vgpr) { assert(src.regClass() == v1 || src.regClass() == v2); bld.vopc(src.size() == 2 ? aco_opcode::v_cmp_lg_u64 : aco_opcode::v_cmp_lg_u32, Definition(dst), Operand(0u), src).def(0).setHint(vcc); } else { - assert(src.regClass() == s1 && dst.regClass() == s1); - bld.sopc(aco_opcode::s_cmp_lg_u32, bld.scc(Definition(dst)), Operand(0u), src); + assert(src.regClass() == s1 || src.regClass() == s2); + Temp tmp = bld.sopc(src.size() == 2 ? 
aco_opcode::s_cmp_lg_u64 : aco_opcode::s_cmp_lg_u32, + bld.scc(bld.def(s1)), Operand(0u), src); + bool_to_vector_condition(ctx, tmp, dst); } break; } @@ -2228,119 +2222,49 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_flt: { - if (instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_f32, dst); - else if (instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_f64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_f32, aco_opcode::v_cmp_lt_f64); break; } case nir_op_fge: { - if (instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_f32, dst); - else if (instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_f64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_f32, aco_opcode::v_cmp_ge_f64); break; } case nir_op_feq: { - if (instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_f32, dst); - else if (instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_f64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_f32, aco_opcode::v_cmp_eq_f64); break; } case nir_op_fne: { - if (instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_neq_f32, dst); - else if (instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_neq_f64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_neq_f32, aco_opcode::v_cmp_neq_f64); break; } case nir_op_ilt: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_i32, dst); - else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::s_cmp_lt_i32, dst); - else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_i64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_i32, aco_opcode::v_cmp_lt_i64, aco_opcode::s_cmp_lt_i32); break; } case nir_op_ige: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_i32, dst); - else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::s_cmp_ge_i32, dst); - else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_i64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_i32, aco_opcode::v_cmp_ge_i64, aco_opcode::s_cmp_ge_i32); break; } case nir_op_ieq: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) { - emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_i32, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) { - emit_comparison(ctx, instr, aco_opcode::s_cmp_eq_i32, dst); - } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) { - emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_i64, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 64) { - emit_comparison(ctx, instr, aco_opcode::s_cmp_eq_u64, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - bld.sopc(aco_opcode::s_cmp_eq_i32, bld.scc(Definition(dst)), - as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1)); - } else 
if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - bld.sop2(aco_opcode::s_xnor_b64, Definition(dst), bld.def(s1, scc), - as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false)); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + if (instr->src[0].src.ssa->bit_size == 1) + emit_boolean_logic(ctx, instr, aco_opcode::s_xnor_b32, aco_opcode::s_xnor_b64, dst); + else + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_i32, aco_opcode::v_cmp_eq_i64, aco_opcode::s_cmp_eq_i32, aco_opcode::s_cmp_eq_u64); break; } case nir_op_ine: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) { - emit_comparison(ctx, instr, aco_opcode::v_cmp_lg_i32, dst); - } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) { - emit_comparison(ctx, instr, aco_opcode::v_cmp_lg_i64, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) { - emit_comparison(ctx, instr, aco_opcode::s_cmp_lg_i32, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 64) { - emit_comparison(ctx, instr, aco_opcode::s_cmp_lg_u64, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - bld.sopc(aco_opcode::s_cmp_lg_i32, bld.scc(Definition(dst)), - as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1)); - } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - bld.sop2(aco_opcode::s_xor_b64, Definition(dst), bld.def(s1, scc), - as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false)); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + if (instr->src[0].src.ssa->bit_size == 1) + emit_boolean_logic(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::s_xor_b64, dst); + else + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lg_i32, aco_opcode::v_cmp_lg_i64, aco_opcode::s_cmp_lg_i32, aco_opcode::s_cmp_lg_u64); break; } case nir_op_ult: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_u32, dst); - else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::s_cmp_lt_u32, dst); - else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_u64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_u32, aco_opcode::v_cmp_lt_u64, aco_opcode::s_cmp_lt_u32); break; } case nir_op_uge: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_u32, dst); - else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::s_cmp_ge_u32, dst); - else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_u64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_u32, aco_opcode::v_cmp_ge_u64, aco_opcode::s_cmp_ge_u32); break; } case nir_op_fddx: @@ -2387,9 +2311,13 @@ void visit_load_const(isel_context *ctx, nir_load_const_instr *instr) 
assert(instr->def.num_components == 1 && "Vector load_const should be lowered to scalar."); assert(dst.type() == RegType::sgpr); - if (dst.size() == 1) - { - Builder(ctx->program, ctx->block).copy(Definition(dst), Operand(instr->value[0].u32)); + Builder bld(ctx->program, ctx->block); + + if (instr->def.bit_size == 1) { + assert(dst.regClass() == s2); + bld.sop1(aco_opcode::s_mov_b64, Definition(dst), Operand((uint64_t)(instr->value[0].b ? -1 : 0))); + } else if (dst.size() == 1) { + bld.copy(Definition(dst), Operand(instr->value[0].u32)); } else { assert(dst.size() != 1); aco_ptr<Pseudo_instruction> vec{create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector, Format::PSEUDO, dst.size(), 1)}; @@ -3577,7 +3505,8 @@ void visit_discard_if(isel_context *ctx, nir_intrinsic_instr *instr) // TODO: optimize uniform conditions Builder bld(ctx->program, ctx->block); - Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false); + Temp src = get_ssa_temp(ctx, instr->src[0].ssa); + assert(src.regClass() == s2); src = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)); bld.pseudo(aco_opcode::p_discard_if, src); ctx->block->kind |= block_kind_uses_discard_if; @@ -5114,15 +5043,17 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te } else if (op == nir_op_iand && cluster_size == 64) { //subgroupAnd(val) -> (exec & ~val) == 0 Temp tmp = bld.sop2(aco_opcode::s_andn2_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp(); - return bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), tmp, Operand(0u)); + return bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(0u), Operand(-1u), bld.scc(tmp)); } else if (op == nir_op_ior && cluster_size == 64) { //subgroupOr(val) -> (val & exec) != 0 - return bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)).def(1).getTemp(); + Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)).def(1).getTemp(); + return bool_to_vector_condition(ctx, tmp); } else if (op == nir_op_ixor && cluster_size == 64) { //subgroupXor(val) -> s_bcnt1_i32_b64(val & exec) & 1 Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)); tmp = bld.sop1(aco_opcode::s_bcnt1_i32_b64, bld.def(s2), bld.def(s1, scc), tmp); - return bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), tmp, Operand(1u)).def(1).getTemp(); + tmp = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), tmp, Operand(1u)).def(1).getTemp(); + return bool_to_vector_condition(ctx, tmp); } else { //subgroupClustered{And,Or,Xor}(val, n) -> //lane_id = v_mbcnt_hi_u32_b32(-1, v_mbcnt_lo_u32_b32(-1, 0)) @@ -5221,8 +5152,6 @@ void emit_uniform_subgroup(isel_context *ctx, nir_intrinsic_instr *instr, Temp s Definition dst(get_ssa_temp(ctx, &instr->dest.ssa)); if (src.regClass().type() == RegType::vgpr) { bld.pseudo(aco_opcode::p_as_uniform, dst, src); - } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { - bld.sopc(aco_opcode::s_cmp_lg_u64, bld.scc(dst), Operand(0u), Operand(src)); } else if (src.regClass() == s1) { bld.sop1(aco_opcode::s_mov_b32, dst, src); } else if (src.regClass() == s2) { @@ -5541,10 +5470,9 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) case nir_intrinsic_ballot: { Definition tmp = bld.def(s2); Temp src = get_ssa_temp(ctx, instr->src[0].ssa); - if (instr->src[0].ssa->bit_size == 1 && src.regClass() == s2) { + if 
(instr->src[0].ssa->bit_size == 1) { + assert(src.regClass() == s2); bld.sop2(aco_opcode::s_and_b64, tmp, bld.def(s1, scc), Operand(exec, s2), src); - } else if (instr->src[0].ssa->bit_size == 1 && src.regClass() == s1) { - bld.sop2(aco_opcode::s_cselect_b64, tmp, Operand(exec, s2), Operand(0u), bld.scc(src)); } else if (instr->src[0].ssa->bit_size == 32 && src.regClass() == v1) { bld.vopc(aco_opcode::v_cmp_lg_u32, tmp, Operand(0u), src); } else if (instr->src[0].ssa->bit_size == 64 && src.regClass() == v2) { @@ -5576,9 +5504,12 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) hi = emit_wqm(ctx, emit_bpermute(ctx, bld, tid, hi)); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi); emit_split_vector(ctx, dst, 2); - } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2 && tid.regClass() == s1) { - emit_wqm(ctx, bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, tid), dst); - } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { + } else if (instr->dest.ssa.bit_size == 1 && tid.regClass() == s1) { + assert(src.regClass() == s2); + Temp tmp = bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, tid); + bool_to_vector_condition(ctx, emit_wqm(ctx, tmp), dst); + } else if (instr->dest.ssa.bit_size == 1 && tid.regClass() == v1) { + assert(src.regClass() == s2); Temp tmp = bld.vop3(aco_opcode::v_lshrrev_b64, bld.def(v2), tid, src); tmp = emit_extract_vector(ctx, tmp, 0, v1); tmp = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(1u), tmp); @@ -5614,11 +5545,11 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) hi = emit_wqm(ctx, bld.vop1(aco_opcode::v_readfirstlane_b32, bld.def(s1), hi)); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi); emit_split_vector(ctx, dst, 2); - } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { - emit_wqm(ctx, - bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, - bld.sop1(aco_opcode::s_ff1_i32_b64, bld.def(s1), Operand(exec, s2))), - dst); + } else if (instr->dest.ssa.bit_size == 1) { + assert(src.regClass() == s2); + Temp tmp = bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, + bld.sop1(aco_opcode::s_ff1_i32_b64, bld.def(s1), Operand(exec, s2))); + bool_to_vector_condition(ctx, emit_wqm(ctx, tmp), dst); } else if (src.regClass() == s1) { bld.sop1(aco_opcode::s_mov_b32, Definition(dst), src); } else if (src.regClass() == s2) { @@ -5631,27 +5562,25 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) break; } case nir_intrinsic_vote_all: { - Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false); + Temp src = get_ssa_temp(ctx, instr->src[0].ssa); Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); assert(src.regClass() == s2); - assert(dst.regClass() == s1); + assert(dst.regClass() == s2); - Definition tmp = bld.def(s1); - bld.sopc(aco_opcode::s_cmp_eq_u64, bld.scc(tmp), - bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)), - Operand(exec, s2)); - emit_wqm(ctx, tmp.getTemp(), dst); + Temp tmp = bld.sop2(aco_opcode::s_andn2_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp(); + Temp val = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(0u), Operand(-1u), bld.scc(tmp)); + emit_wqm(ctx, val, dst); break; } case nir_intrinsic_vote_any: { - Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false); + Temp src = get_ssa_temp(ctx, instr->src[0].ssa); Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); 
assert(src.regClass() == s2); - assert(dst.regClass() == s1); + assert(dst.regClass() == s2); - Definition tmp = bld.def(s1); - bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(tmp), src, Operand(exec, s2)); - emit_wqm(ctx, tmp.getTemp(), dst); + Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp(); + Temp val = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(-1u), Operand(0u), bld.scc(tmp)); + emit_wqm(ctx, val, dst); break; } case nir_intrinsic_reduce: @@ -5752,7 +5681,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) } else { Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); unsigned lane = nir_src_as_const_value(instr->src[1])->u32; - if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { + if (instr->dest.ssa.bit_size == 1) { + assert(src.regClass() == s2); uint32_t half_mask = 0x11111111u << lane; Temp mask_tmp = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2), Operand(half_mask), Operand(half_mask)); Temp tmp = bld.tmp(s2); @@ -5809,7 +5739,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) } Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); - if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { + if (instr->dest.ssa.bit_size == 1) { + assert(src.regClass() == s2); src = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), Operand((uint32_t)-1), src); src = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl); Temp tmp = bld.vopc(aco_opcode::v_cmp_lg_u32, bld.def(s2), Operand(0u), src); @@ -5912,9 +5843,9 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) ctx->program->needs_exact = true; break; case nir_intrinsic_demote_if: { - Temp cond = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), - as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false), - Operand(exec, s2)); + Temp src = get_ssa_temp(ctx, instr->src[0].ssa); + assert(src.regClass() == s2); + Temp cond = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)); bld.pseudo(aco_opcode::p_demote_to_helper, cond); ctx->block->kind |= block_kind_uses_demote; ctx->program->needs_exact = true; @@ -6520,7 +6451,9 @@ void visit_tex(isel_context *ctx, nir_tex_instr *instr) Operand((uint32_t)V_008F14_IMG_NUM_FORMAT_SINT), bld.scc(compare_cube_wa)); } - tg4_compare_cube_wa64 = as_divergent_bool(ctx, compare_cube_wa, true); + tg4_compare_cube_wa64 = bld.tmp(s2); + bool_to_vector_condition(ctx, compare_cube_wa, tg4_compare_cube_wa64); + nfmt = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), nfmt, Operand(26u)); desc[1] = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), desc[1], @@ -6770,6 +6703,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr) aco_ptr<Pseudo_instruction> phi; unsigned num_src = exec_list_length(&instr->srcs); Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); + assert(instr->dest.ssa.bit_size != 1 || dst.regClass() == s2); aco_opcode opcode = !dst.is_linear() || ctx->divergent_vals[instr->dest.ssa.index] ? 
aco_opcode::p_phi : aco_opcode::p_linear_phi; @@ -6797,7 +6731,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr) } /* try to scalarize vector phis */ - if (dst.size() > 1) { + if (instr->dest.ssa.bit_size != 1 && dst.size() > 1) { // TODO: scalarize linear phis on divergent ifs bool can_scalarize = (opcode == aco_opcode::p_phi || !(ctx->block->kind & block_kind_merge)); std::array<Temp, 4> new_vec; @@ -7265,10 +7199,10 @@ static void visit_if(isel_context *ctx, nir_if *if_stmt) ctx->block->kind |= block_kind_uniform; /* emit branch */ - if (cond.regClass() == s2) { - // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction - cond = as_uniform_bool(ctx, cond); - } + assert(cond.regClass() == s2); + // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction + cond = bool_to_scalar_condition(ctx, cond); + branch.reset(create_instruction<Pseudo_branch_instruction>(aco_opcode::p_cbranch_z, Format::PSEUDO_BRANCH, 1, 0)); branch->operands[0] = Operand(cond); branch->operands[0].setFixed(scc); diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp index 2c349635799..cdc8103497b 100644 --- a/src/amd/compiler/aco_instruction_selection_setup.cpp +++ b/src/amd/compiler/aco_instruction_selection_setup.cpp @@ -244,25 +244,14 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_op_fge: case nir_op_feq: case nir_op_fne: - size = 2; - break; case nir_op_ilt: case nir_op_ige: case nir_op_ult: case nir_op_uge: - size = alu_instr->src[0].src.ssa->bit_size == 64 ? 2 : 1; - /* fallthrough */ case nir_op_ieq: case nir_op_ine: case nir_op_i2b1: - if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) { - size = 2; - } else { - for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) { - if (allocated[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr) - size = 2; - } - } + size = 2; break; case nir_op_f2i64: case nir_op_f2u64: @@ -274,13 +263,7 @@ void init_context(isel_context *ctx, nir_shader *shader) break; case nir_op_bcsel: if (alu_instr->dest.dest.ssa.bit_size == 1) { - if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) - size = 2; - else if (allocated[alu_instr->src[1].src.ssa->index].regClass() == s2 && - allocated[alu_instr->src[2].src.ssa->index].regClass() == s2) - size = 2; - else - size = 1; + size = 2; } else { if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) { type = RegType::vgpr; @@ -298,32 +281,14 @@ void init_context(isel_context *ctx, nir_shader *shader) break; case nir_op_mov: if (alu_instr->dest.dest.ssa.bit_size == 1) { - size = allocated[alu_instr->src[0].src.ssa->index].size(); + size = 2; } else { type = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? RegType::vgpr : RegType::sgpr; } break; - case nir_op_inot: - case nir_op_ixor: - if (alu_instr->dest.dest.ssa.bit_size == 1) { - size = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? 
2 : 1; - break; - } else { - /* fallthrough */ - } default: if (alu_instr->dest.dest.ssa.bit_size == 1) { - if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) { - size = 2; - } else { - size = 2; - for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) { - if (allocated[alu_instr->src[i].src.ssa->index].regClass() == s1) { - size = 1; - break; - } - } - } + size = 2; } else { for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) { if (allocated[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr) @@ -339,6 +304,8 @@ void init_context(isel_context *ctx, nir_shader *shader) unsigned size = nir_instr_as_load_const(instr)->def.num_components; if (nir_instr_as_load_const(instr)->def.bit_size == 64) size *= 2; + else if (nir_instr_as_load_const(instr)->def.bit_size == 1) + size *= 2; allocated[nir_instr_as_load_const(instr)->def.index] = Temp(0, RegClass(RegType::sgpr, size)); break; } @@ -365,6 +332,8 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_intrinsic_read_invocation: case nir_intrinsic_first_invocation: type = RegType::sgpr; + if (intrinsic->dest.ssa.bit_size == 1) + size = 2; break; case nir_intrinsic_ballot: type = RegType::sgpr; @@ -433,11 +402,11 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_intrinsic_masked_swizzle_amd: case nir_intrinsic_inclusive_scan: case nir_intrinsic_exclusive_scan: - if (!ctx->divergent_vals[intrinsic->dest.ssa.index]) { + if (intrinsic->dest.ssa.bit_size == 1) { + size = 2; type = RegType::sgpr; - } else if (intrinsic->src[0].ssa->bit_size == 1) { + } else if (!ctx->divergent_vals[intrinsic->dest.ssa.index]) { type = RegType::sgpr; - size = 2; } else { type = RegType::vgpr; } @@ -452,12 +421,12 @@ void init_context(isel_context *ctx, nir_shader *shader) size = 2; break; case nir_intrinsic_reduce: - if (nir_intrinsic_cluster_size(intrinsic) == 0 || - !ctx->divergent_vals[intrinsic->dest.ssa.index]) { + if (intrinsic->dest.ssa.bit_size == 1) { + size = 2; type = RegType::sgpr; - } else if (intrinsic->src[0].ssa->bit_size == 1) { + } else if (nir_intrinsic_cluster_size(intrinsic) == 0 || + !ctx->divergent_vals[intrinsic->dest.ssa.index]) { type = RegType::sgpr; - size = 2; } else { type = RegType::vgpr; } @@ -554,7 +523,7 @@ void init_context(isel_context *ctx, nir_shader *shader) if (phi->dest.ssa.bit_size == 1) { assert(size == 1 && "multiple components not yet supported on boolean phis."); type = RegType::sgpr; - size *= ctx->divergent_vals[phi->dest.ssa.index] ? 
2 : 1; + size *= 2; allocated[phi->dest.ssa.index] = Temp(0, RegClass(type, size)); break; } diff --git a/src/amd/compiler/aco_lower_bool_phis.cpp b/src/amd/compiler/aco_lower_bool_phis.cpp index ac4663a2ce1..9e5374fe6a0 100644 --- a/src/amd/compiler/aco_lower_bool_phis.cpp +++ b/src/amd/compiler/aco_lower_bool_phis.cpp @@ -150,13 +150,6 @@ void lower_divergent_bool_phi(Program *program, Block *block, aco_ptr<Instructio assert(phi->operands[i].isTemp()); Temp phi_src = phi->operands[i].getTemp(); - if (phi_src.regClass() == s1) { - Temp new_phi_src = bld.tmp(s2); - insert_before_logical_end(pred, - bld.sop2(aco_opcode::s_cselect_b64, Definition(new_phi_src), - Operand((uint32_t)-1), Operand(0u), bld.scc(phi_src)).get_ptr()); - phi_src = new_phi_src; - } assert(phi_src.regClass() == s2); Operand cur = get_ssa(program, pred->index, &state); @@ -218,6 +211,7 @@ void lower_bool_phis(Program* program) for (Block& block : program->blocks) { for (aco_ptr<Instruction>& phi : block.instructions) { if (phi->opcode == aco_opcode::p_phi) { + assert(phi->definitions[0].regClass() != s1); if (phi->definitions[0].regClass() == s2) lower_divergent_bool_phi(program, &block, phi); } else if (phi->opcode == aco_opcode::p_linear_phi) { |
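With booleans held as lane masks, the wave-wide vote and boolean-reduce paths above compute their scalar outcome directly on the mask and then widen it back to an s2 condition via bool_to_vector_condition or an explicit s_cselect_b64. Below is a scalar model of the three 64-wide cases touched by this patch, again assuming a 64-lane wave (C++20 for std::popcount); the function names are illustrative, and the real code emits s_andn2_b64/s_and_b64/s_bcnt1_i32_b64 followed by the select, as in the hunks above.

#include <bit>
#include <cstdint>
#include <cassert>

using mask = std::uint64_t;

// subgroupAnd / vote_all: true for every lane iff no active lane is false.
// Folds s_andn2_b64 exec, val and the following s_cselect_b64 0, -1, scc.
mask vote_all(mask val, mask exec)
{
    return (exec & ~val) == 0 ? ~0ull : 0ull;
}

// subgroupOr / vote_any: true for every lane iff some active lane is true.
// Folds s_and_b64 exec, val and s_cselect_b64 -1, 0, scc.
mask vote_any(mask val, mask exec)
{
    return (exec & val) != 0 ? ~0ull : 0ull;
}

// subgroupXor over the whole wave: parity of the true bits among active lanes.
// Folds s_bcnt1_i32_b64, s_and_b32 ..., 1 and the final widening select.
mask reduce_xor(mask val, mask exec)
{
    return (std::popcount(val & exec) & 1) ? ~0ull : 0ull;
}

int main()
{
    mask exec = 0xfull;                        // four active lanes
    assert(vote_all(0xfull, exec) == ~0ull);
    assert(vote_all(0x7ull, exec) == 0);
    assert(vote_any(0x10ull, exec) == 0);      // only an inactive lane is set
    assert(reduce_xor(0x7ull, exec) == ~0ull); // three true lanes -> odd parity
    return 0;
}

The operand order of the select differs between the two votes in the diff (0, -1 for vote_all versus -1, 0 for vote_any) because vote_all tests the negated mask: its SCC is set when some active lane is false.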