aboutsummaryrefslogtreecommitdiffstats
path: root/src/amd/compiler
diff options
context:
space:
mode:
authorRhys Perry <[email protected]>2019-12-04 20:18:05 +0000
committerDaniel Schürmann <[email protected]>2020-04-03 23:13:15 +0100
commitb84d59af50a53959fcde232ee2682e77569a7da2 (patch)
tree3995f677855421458587159651c26eddb9270c38 /src/amd/compiler
parent00312f3c95d9ef2f545a8479d6ad289bc791974b (diff)
aco: add SDWA_instruction
Signed-off-by: Rhys Perry <[email protected]> Reviewed-by: Daniel Schürmann <[email protected]> Reviewed-By: Timur Kristóf <[email protected]> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4002>
Diffstat (limited to 'src/amd/compiler')
-rw-r--r--src/amd/compiler/aco_assembler.cpp43
-rw-r--r--src/amd/compiler/aco_ir.h54
-rw-r--r--src/amd/compiler/aco_print_asm.cpp3
-rw-r--r--src/amd/compiler/aco_print_ir.cpp62
-rw-r--r--src/amd/compiler/aco_validate.cpp52
5 files changed, 207 insertions, 7 deletions
diff --git a/src/amd/compiler/aco_assembler.cpp b/src/amd/compiler/aco_assembler.cpp
index c46208b13b4..33bc612bac4 100644
--- a/src/amd/compiler/aco_assembler.cpp
+++ b/src/amd/compiler/aco_assembler.cpp
@@ -547,7 +547,7 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
/* first emit the instruction without the DPP operand */
Operand dpp_op = instr->operands[0];
instr->operands[0] = Operand(PhysReg{250}, v1);
- instr->format = (Format) ((uint32_t) instr->format & ~(1 << 14));
+ instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::DPP);
emit_instruction(ctx, out, instr);
DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
uint32_t encoding = (0xF & dpp->row_mask) << 28;
@@ -561,6 +561,47 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
encoding |= (0xFF) & dpp_op.physReg();
out.push_back(encoding);
return;
+ } else if (instr->isSDWA()) {
+ /* first emit the instruction without the SDWA operand */
+ Operand sdwa_op = instr->operands[0];
+ instr->operands[0] = Operand(PhysReg{249}, v1);
+ instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::SDWA);
+ emit_instruction(ctx, out, instr);
+
+ SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
+ uint32_t encoding = 0;
+
+ if ((uint16_t)instr->format & (uint16_t)Format::VOPC) {
+ if (instr->definitions[0].physReg() != vcc) {
+ encoding |= instr->definitions[0].physReg() << 8;
+ encoding |= 1 << 15;
+ }
+ encoding |= (sdwa->clamp ? 1 : 0) << 13;
+ } else {
+ encoding |= (uint32_t)(sdwa->dst_sel & sdwa_asuint) << 8;
+ uint32_t dst_u = sdwa->dst_sel & sdwa_sext ? 1 : 0;
+ encoding |= dst_u << 11;
+ encoding |= (sdwa->clamp ? 1 : 0) << 13;
+ encoding |= sdwa->omod << 14;
+ }
+
+ encoding |= (uint32_t)(sdwa->sel[0] & sdwa_asuint) << 16;
+ encoding |= sdwa->sel[0] & sdwa_sext ? 1 << 19 : 0;
+ encoding |= sdwa->abs[0] << 21;
+ encoding |= sdwa->neg[0] << 20;
+
+ if (instr->operands.size() >= 2) {
+ encoding |= (uint32_t)(sdwa->sel[1] & sdwa_asuint) << 24;
+ encoding |= sdwa->sel[1] & sdwa_sext ? 1 << 27 : 0;
+ encoding |= sdwa->abs[1] << 29;
+ encoding |= sdwa->neg[1] << 28;
+ }
+
+ encoding |= 0xFF & sdwa_op.physReg();
+ encoding |= (sdwa_op.physReg() < 256) << 23;
+ if (instr->operands.size() >= 2)
+ encoding |= (instr->operands[1].physReg() < 256) << 31;
+ out.push_back(encoding);
} else {
unreachable("unimplemented instruction format");
}
diff --git a/src/amd/compiler/aco_ir.h b/src/amd/compiler/aco_ir.h
index 05a9754c6b1..c8b5c00e1f2 100644
--- a/src/amd/compiler/aco_ir.h
+++ b/src/amd/compiler/aco_ir.h
@@ -169,6 +169,11 @@ constexpr Format asVOP3(Format format) {
return (Format) ((uint32_t) Format::VOP3 | (uint32_t) format);
};
+constexpr Format asSDWA(Format format) {
+ assert(format == Format::VOP1 || format == Format::VOP2 || format == Format::VOPC);
+ return (Format) ((uint32_t) Format::SDWA | (uint32_t) format);
+}
+
enum class RegType {
none = 0,
sgpr,
@@ -841,6 +846,55 @@ struct DPP_instruction : public Instruction {
bool bound_ctrl : 1;
};
+enum sdwa_sel : uint8_t {
+ /* masks */
+ sdwa_wordnum = 0x1,
+ sdwa_bytenum = 0x3,
+ sdwa_asuint = 0x7,
+
+ /* flags */
+ sdwa_isword = 0x4,
+ sdwa_sext = 0x8,
+
+ /* specific values */
+ sdwa_ubyte0 = 0,
+ sdwa_ubyte1 = 1,
+ sdwa_ubyte2 = 2,
+ sdwa_ubyte3 = 3,
+ sdwa_uword0 = sdwa_isword | 0,
+ sdwa_uword1 = sdwa_isword | 1,
+ sdwa_udword = 6,
+
+ sdwa_sbyte0 = sdwa_ubyte0 | sdwa_sext,
+ sdwa_sbyte1 = sdwa_ubyte1 | sdwa_sext,
+ sdwa_sbyte2 = sdwa_ubyte2 | sdwa_sext,
+ sdwa_sbyte3 = sdwa_ubyte3 | sdwa_sext,
+ sdwa_sword0 = sdwa_uword0 | sdwa_sext,
+ sdwa_sword1 = sdwa_uword1 | sdwa_sext,
+ sdwa_sdword = sdwa_udword | sdwa_sext,
+};
+
+/**
+ * Sub-Dword Addressing Format:
+ * This format can be used for VOP1, VOP2 or VOPC instructions.
+ *
+ * omod and SGPR/constant operands are only available on GFX9+. For VOPC,
+ * the definition doesn't have to be VCC on GFX9+.
+ *
+ */
+struct SDWA_instruction : public Instruction {
+ /* these destination modifiers aren't available with VOPC except for
+ * clamp on GFX8 */
+ unsigned dst_sel:4;
+ bool dst_preserve:1;
+ bool clamp:1;
+ unsigned omod:2; /* GFX9+ */
+
+ unsigned sel[2];
+ bool neg[2];
+ bool abs[2];
+};
+
struct Interp_instruction : public Instruction {
uint8_t attribute;
uint8_t component;
diff --git a/src/amd/compiler/aco_print_asm.cpp b/src/amd/compiler/aco_print_asm.cpp
index fead382c7cf..e2dbc5bd8b6 100644
--- a/src/amd/compiler/aco_print_asm.cpp
+++ b/src/amd/compiler/aco_print_asm.cpp
@@ -140,6 +140,9 @@ void print_asm(Program *program, std::vector<uint32_t>& binary,
if (!l && program->chip_class == GFX9 && ((binary[pos] & 0xffff8000) == 0xd1348000)) { /* not actually an invalid instruction */
out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_add_u32_e64 + clamp";
new_pos = pos + 2;
+ } else if (program->chip_class == GFX10 && l == 4 && ((binary[pos] & 0xfe0001ff) == 0x020000f9)) {
+ out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_cndmask_b32 + sdwa";
+ new_pos = pos + 2;
} else if (!l) {
out << std::left << std::setw(align_width) << std::setfill(' ') << "(invalid instruction)";
new_pos = pos + 1;
diff --git a/src/amd/compiler/aco_print_ir.cpp b/src/amd/compiler/aco_print_ir.cpp
index 7564b52c17c..43afe0a77c0 100644
--- a/src/amd/compiler/aco_print_ir.cpp
+++ b/src/amd/compiler/aco_print_ir.cpp
@@ -528,7 +528,38 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
if (dpp->bound_ctrl)
fprintf(output, " bound_ctrl:1");
} else if ((int)instr->format & (int)Format::SDWA) {
- fprintf(output, " (printing unimplemented)");
+ SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
+ switch (sdwa->omod) {
+ case 1:
+ fprintf(output, " *2");
+ break;
+ case 2:
+ fprintf(output, " *4");
+ break;
+ case 3:
+ fprintf(output, " *0.5");
+ break;
+ }
+ if (sdwa->clamp)
+ fprintf(output, " clamp");
+ switch (sdwa->dst_sel & sdwa_asuint) {
+ case sdwa_udword:
+ break;
+ case sdwa_ubyte0:
+ case sdwa_ubyte1:
+ case sdwa_ubyte2:
+ case sdwa_ubyte3:
+ fprintf(output, " dst_sel:%sbyte%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
+ sdwa->dst_sel & sdwa_bytenum);
+ break;
+ case sdwa_uword0:
+ case sdwa_uword1:
+ fprintf(output, " dst_sel:%sword%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
+ sdwa->dst_sel & sdwa_wordnum);
+ break;
+ }
+ if (sdwa->dst_preserve)
+ fprintf(output, " dst_preserve");
}
}
@@ -546,23 +577,33 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
if (instr->operands.size()) {
bool abs[instr->operands.size()];
bool neg[instr->operands.size()];
+ uint8_t sel[instr->operands.size()];
if ((int)instr->format & (int)Format::VOP3A) {
VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*>(instr);
for (unsigned i = 0; i < instr->operands.size(); ++i) {
abs[i] = vop3->abs[i];
neg[i] = vop3->neg[i];
+ sel[i] = sdwa_udword;
}
} else if (instr->isDPP()) {
DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
- assert(instr->operands.size() <= 2);
for (unsigned i = 0; i < instr->operands.size(); ++i) {
- abs[i] = dpp->abs[i];
- neg[i] = dpp->neg[i];
+ abs[i] = i < 2 ? dpp->abs[i] : false;
+ neg[i] = i < 2 ? dpp->neg[i] : false;
+ sel[i] = sdwa_udword;
+ }
+ } else if (instr->isSDWA()) {
+ SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
+ for (unsigned i = 0; i < instr->operands.size(); ++i) {
+ abs[i] = i < 2 ? sdwa->abs[i] : false;
+ neg[i] = i < 2 ? sdwa->neg[i] : false;
+ sel[i] = i < 2 ? sdwa->sel[i] : sdwa_udword;
}
} else {
for (unsigned i = 0; i < instr->operands.size(); ++i) {
abs[i] = false;
neg[i] = false;
+ sel[i] = sdwa_udword;
}
}
for (unsigned i = 0; i < instr->operands.size(); ++i) {
@@ -575,7 +616,20 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
fprintf(output, "-");
if (abs[i])
fprintf(output, "|");
+ if (sel[i] & sdwa_sext)
+ fprintf(output, "sext(");
print_operand(&instr->operands[i], output);
+ if (sel[i] & sdwa_sext)
+ fprintf(output, ")");
+ if ((sel[i] & sdwa_asuint) == sdwa_udword) {
+ /* print nothing */
+ } else if (sel[i] & sdwa_isword) {
+ unsigned index = sel[i] & sdwa_wordnum;
+ fprintf(output, "[%u:%u]", index * 16, index * 16 + 15);
+ } else {
+ unsigned index = sel[i] & sdwa_bytenum;
+ fprintf(output, "[%u:%u]", index * 8, index * 8 + 7);
+ }
if (abs[i])
fprintf(output, "|");
}
diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp
index e967f0ca9e7..4bbce14a86a 100644
--- a/src/amd/compiler/aco_validate.cpp
+++ b/src/amd/compiler/aco_validate.cpp
@@ -93,6 +93,50 @@ void validate(Program* program, FILE * output)
"Format cannot have VOP3A/VOP3B applied", instr.get());
}
+ /* check SDWA */
+ if (instr->isSDWA()) {
+ check(base_format == Format::VOP2 ||
+ base_format == Format::VOP1 ||
+ base_format == Format::VOPC,
+ "Format cannot have SDWA applied", instr.get());
+
+ check(program->chip_class >= GFX8, "SDWA is GFX8+ only", instr.get());
+
+ SDWA_instruction *sdwa = static_cast<SDWA_instruction*>(instr.get());
+ check(sdwa->omod == 0 || program->chip_class >= GFX9, "SDWA omod only supported on GFX9+", instr.get());
+ if (base_format == Format::VOPC) {
+ check(sdwa->clamp == false || program->chip_class == GFX8, "SDWA VOPC clamp only supported on GFX8", instr.get());
+ check((instr->definitions[0].isFixed() && instr->definitions[0].physReg() == vcc) ||
+ program->chip_class >= GFX9,
+ "SDWA+VOPC definition must be fixed to vcc on GFX8", instr.get());
+ }
+
+ if (instr->operands.size() >= 3) {
+ check(instr->operands[2].isFixed() && instr->operands[2].physReg() == vcc,
+ "3rd operand must be fixed to vcc with SDWA", instr.get());
+ }
+ if (instr->definitions.size() >= 2) {
+ check(instr->definitions[1].isFixed() && instr->definitions[1].physReg() == vcc,
+ "2nd definition must be fixed to vcc with SDWA", instr.get());
+ }
+
+ check(instr->opcode != aco_opcode::v_madmk_f32 &&
+ instr->opcode != aco_opcode::v_madak_f32 &&
+ instr->opcode != aco_opcode::v_madmk_f16 &&
+ instr->opcode != aco_opcode::v_madak_f16 &&
+ instr->opcode != aco_opcode::v_readfirstlane_b32 &&
+ instr->opcode != aco_opcode::v_clrexcp &&
+ instr->opcode != aco_opcode::v_swap_b32,
+ "SDWA can't be used with this opcode", instr.get());
+ if (program->chip_class != GFX8) {
+ check(instr->opcode != aco_opcode::v_mac_f32 &&
+ instr->opcode != aco_opcode::v_mac_f16 &&
+ instr->opcode != aco_opcode::v_fmac_f32 &&
+ instr->opcode != aco_opcode::v_fmac_f16,
+ "SDWA can't be used with this opcode", instr.get());
+ }
+ }
+
/* check for undefs */
for (unsigned i = 0; i < instr->operands.size(); i++) {
if (instr->operands[i].isUndefined()) {
@@ -137,6 +181,10 @@ void validate(Program* program, FILE * output)
if (program->chip_class >= GFX10 && !is_shift64)
const_bus_limit = 2;
+ uint32_t scalar_mask = instr->isVOP3() ? 0x7 : 0x5;
+ if (instr->isSDWA())
+ scalar_mask = program->chip_class >= GFX9 ? 0x7 : 0x4;
+
check(instr->definitions[0].getTemp().type() == RegType::vgpr ||
(int) instr->format & (int) Format::VOPC ||
instr->opcode == aco_opcode::v_readfirstlane_b32 ||
@@ -158,7 +206,7 @@ void validate(Program* program, FILE * output)
continue;
}
if (op.isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) {
- check(i != 1 || instr->isVOP3(), "Wrong source position for SGPR argument", instr.get());
+ check(scalar_mask & (1 << i), "Wrong source position for SGPR argument", instr.get());
if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
if (num_sgprs < 2)
@@ -167,7 +215,7 @@ void validate(Program* program, FILE * output)
}
if (op.isConstant() && !op.isLiteral())
- check(i == 0 || instr->isVOP3(), "Wrong source position for constant argument", instr.get());
+ check(scalar_mask & (1 << i), "Wrong source position for constant argument", instr.get());
}
check(num_sgprs + (literal.isUndefined() ? 0 : 1) <= const_bus_limit, "Too many SGPRs/literals", instr.get());
}