diff options
Diffstat (limited to 'src/amd')
-rw-r--r-- | src/amd/compiler/aco_instruction_selection.cpp | 113 |
1 files changed, 33 insertions, 80 deletions
diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index f9246247167..1ac3cda14a8 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -750,18 +750,12 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) if (dst.type() == RegType::vgpr) { aco_ptr<Instruction> bcsel; - if (dst.regClass() == v2b) { - then = as_vgpr(ctx, then); - els = as_vgpr(ctx, els); - - Temp tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), els, then, cond); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); - } else if (dst.regClass() == v1) { + if (dst.size() == 1) { then = as_vgpr(ctx, then); els = as_vgpr(ctx, els); bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), els, then, cond); - } else if (dst.regClass() == v2) { + } else if (dst.size() == 2) { Temp then_lo = bld.tmp(v1), then_hi = bld.tmp(v1); bld.pseudo(aco_opcode::p_split_vector, Definition(then_lo), Definition(then_hi), then); Temp else_lo = bld.tmp(v1), else_hi = bld.tmp(v1); @@ -1601,9 +1595,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (dst.regClass() == v2b) { - Temp tmp = bld.tmp(v1); - emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f16, tmp, true); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f32, dst, true); } else if (dst.regClass() == v2) { @@ -1619,9 +1611,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (dst.regClass() == v2b) { - Temp tmp = bld.tmp(v1); - emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f16, tmp, true); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f32, dst, true); } else if (dst.regClass() == v2) { @@ -1637,12 +1627,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); if (dst.regClass() == v2b) { - Temp tmp = bld.tmp(v1); if (src1.type() == RegType::vgpr || src0.type() != RegType::vgpr) - emit_vop2_instruction(ctx, instr, aco_opcode::v_sub_f16, tmp, false); + emit_vop2_instruction(ctx, instr, aco_opcode::v_sub_f16, dst, false); else - emit_vop2_instruction(ctx, instr, aco_opcode::v_subrev_f16, tmp, true); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop2_instruction(ctx, instr, aco_opcode::v_subrev_f16, dst, true); } else if (dst.regClass() == v1) { if (src1.type() == RegType::vgpr || src0.type() != RegType::vgpr) emit_vop2_instruction(ctx, instr, aco_opcode::v_sub_f32, dst, false); @@ -1665,9 +1653,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (dst.regClass() == v2b) { // TODO: check fp_mode.must_flush_denorms16_64 - Temp tmp = bld.tmp(v1); - emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f16, tmp, true); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { @@ -1689,9 +1675,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (dst.regClass() == v2b) { // TODO: check fp_mode.must_flush_denorms16_64 - Temp tmp = bld.tmp(v1); - emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f16, tmp, true); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { @@ -1710,9 +1694,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fmax3: { if (dst.regClass() == v2b) { - Temp tmp = bld.tmp(v1); - emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f16, tmp, false); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f16, dst, false); } else if (dst.regClass() == v1) { emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); } else { @@ -1724,9 +1706,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fmin3: { if (dst.regClass() == v2b) { - Temp tmp = bld.tmp(v1); - emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f16, tmp, false); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f16, dst, false); } else if (dst.regClass() == v1) { emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); } else { @@ -1738,9 +1718,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fmed3: { if (dst.regClass() == v2b) { - Temp tmp = bld.tmp(v1); - emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f16, tmp, false); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f16, dst, false); } else if (dst.regClass() == v1) { emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); } else { @@ -1839,8 +1817,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_frsq: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop1(aco_opcode::v_rsq_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f16, dst); } else if (dst.regClass() == v1) { emit_rsq(ctx, bld, Definition(dst), src); } else if (dst.regClass() == v2) { @@ -1855,8 +1832,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fneg: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop2(aco_opcode::v_xor_b32, bld.def(v1), Operand(0x8000u), as_vgpr(ctx, src)); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + bld.vop2(aco_opcode::v_xor_b32, Definition(dst), Operand(0x8000u), as_vgpr(ctx, src)); } else if (dst.regClass() == v1) { if (ctx->block->fp_mode.must_flush_denorms32) src = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0x3f800000u), as_vgpr(ctx, src)); @@ -1878,8 +1854,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fabs: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x7FFFu), as_vgpr(ctx, src)); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + bld.vop2(aco_opcode::v_and_b32, Definition(dst), Operand(0x7FFFu), as_vgpr(ctx, src)); } else if (dst.regClass() == v1) { if (ctx->block->fp_mode.must_flush_denorms32) src = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0x3f800000u), as_vgpr(ctx, src)); @@ -1901,8 +1876,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fsat: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop3(aco_opcode::v_med3_f16, bld.def(v1), Operand(0u), Operand(0x3f800000u), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + bld.vop3(aco_opcode::v_med3_f16, Definition(dst), Operand(0u), Operand(0x3f800000u), src); } else if (dst.regClass() == v1) { bld.vop3(aco_opcode::v_med3_f32, Definition(dst), Operand(0u), Operand(0x3f800000u), src); /* apparently, it is not necessary to flush denorms if this instruction is used with these operands */ @@ -1921,8 +1895,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_flog2: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop1(aco_opcode::v_log_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_log_f16, dst); } else if (dst.regClass() == v1) { emit_log2(ctx, bld, Definition(dst), src); } else { @@ -1935,8 +1908,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_frcp: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop1(aco_opcode::v_rcp_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f16, dst); } else if (dst.regClass() == v1) { emit_rcp(ctx, bld, Definition(dst), src); } else if (dst.regClass() == v2) { @@ -1950,9 +1922,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fexp2: { if (dst.regClass() == v2b) { - Temp src = get_alu_src(ctx, instr->src[0]); - Temp tmp = bld.vop1(aco_opcode::v_exp_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f16, dst); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f32, dst); } else { @@ -1965,8 +1935,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fsqrt: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop1(aco_opcode::v_sqrt_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f16, dst); } else if (dst.regClass() == v1) { emit_sqrt(ctx, bld, Definition(dst), src); } else if (dst.regClass() == v2) { @@ -1980,9 +1949,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_ffract: { if (dst.regClass() == v2b) { - Temp src = get_alu_src(ctx, instr->src[0]); - Temp tmp = bld.vop1(aco_opcode::v_fract_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_fract_f16, dst); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_fract_f32, dst); } else if (dst.regClass() == v2) { @@ -1997,8 +1964,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_ffloor: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop1(aco_opcode::v_floor_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_floor_f16, dst); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_floor_f32, dst); } else if (dst.regClass() == v2) { @@ -2013,8 +1979,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fceil: { Temp src0 = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop1(aco_opcode::v_ceil_f16, bld.def(v1), src0); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_ceil_f16, dst); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_ceil_f32, dst); } else if (dst.regClass() == v2) { @@ -2044,8 +2009,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_ftrunc: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop1(aco_opcode::v_trunc_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_trunc_f16, dst); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_trunc_f32, dst); } else if (dst.regClass() == v2) { @@ -2060,8 +2024,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fround_even: { Temp src0 = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop1(aco_opcode::v_rndne_f16, bld.def(v1), src0); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop1_instruction(ctx, instr, aco_opcode::v_rndne_f16, dst); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_rndne_f32, dst); } else if (dst.regClass() == v2) { @@ -2106,8 +2069,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.regClass() == v2b) { Temp tmp = bld.vop2(aco_opcode::v_mul_f16, bld.def(v1), half_pi, src); aco_opcode opcode = instr->op == nir_op_fsin ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16; - tmp = bld.vop1(opcode, bld.def(v1), tmp); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + bld.vop1(opcode, Definition(dst), tmp); } else if (dst.regClass() == v1) { Temp tmp = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), half_pi, src); @@ -2128,9 +2090,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); if (dst.regClass() == v2b) { - Temp tmp = bld.tmp(v1); - emit_vop2_instruction(ctx, instr, aco_opcode::v_ldexp_f16, tmp, false); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + emit_vop2_instruction(ctx, instr, aco_opcode::v_ldexp_f16, dst, false); } else if (dst.regClass() == v1) { bld.vop3(aco_opcode::v_ldexp_f32, Definition(dst), as_vgpr(ctx, src0), src1); } else if (dst.regClass() == v2) { @@ -2145,8 +2105,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_frexp_sig: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp tmp = bld.vop1(aco_opcode::v_frexp_mant_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + bld.vop1(aco_opcode::v_frexp_mant_f16, Definition(dst), src); } else if (dst.regClass() == v1) { bld.vop1(aco_opcode::v_frexp_mant_f32, Definition(dst), src); } else if (dst.regClass() == v2) { @@ -2183,8 +2142,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp cond = bld.vopc(aco_opcode::v_cmp_nlt_f16, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src); src = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), one, src, cond); cond = bld.vopc(aco_opcode::v_cmp_le_f16, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src); - Temp tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), minus_one, src, cond); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), minus_one, src, cond); } else if (dst.regClass() == v1) { Temp cond = bld.vopc(aco_opcode::v_cmp_nlt_f32, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src); src = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0x3f800000u), src, cond); @@ -2212,16 +2170,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 64) src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src); - src = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src); break; } case nir_op_f2f16_rtz: { Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 64) src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src); - src = bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32, bld.def(v1), src, Operand(0u)); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src, Operand(0u)); break; } case nir_op_f2f32: { @@ -2248,8 +2204,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 8) src = convert_int(bld, src, 8, 16, true); - Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_i16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + bld.vop1(aco_opcode::v_cvt_f16_i16, Definition(dst), src); break; } case nir_op_i2f32: { @@ -2288,8 +2243,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 8) src = convert_int(bld, src, 8, 16, false); - Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_u16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + bld.vop1(aco_opcode::v_cvt_f16_u16, Definition(dst), src); break; } case nir_op_u2f32: { @@ -2581,8 +2535,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) bld.sop2(aco_opcode::s_mul_i32, Definition(dst), Operand(0x3c00u), src); } else if (dst.regClass() == v2b) { Temp one = bld.copy(bld.def(v1), Operand(0x3c00u)); - Temp tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one, src); - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), one, src); } else { unreachable("Wrong destination register class for nir_op_b2f16."); } |