summaryrefslogtreecommitdiffstats
path: root/src/gallium
diff options
context:
space:
mode:
Diffstat (limited to 'src/gallium')
-rw-r--r--src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp116
1 files changed, 116 insertions, 0 deletions
diff --git a/src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp b/src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp
index c99680613f1..d788b36e1df 100644
--- a/src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp
+++ b/src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp
@@ -2246,6 +2246,121 @@ LateAlgebraicOpt::visit(Instruction *i)
// =============================================================================
+// Split 64-bit MUL and MAD
+class Split64BitOpPreRA : public Pass
+{
+private:
+ virtual bool visit(BasicBlock *);
+ void split64MulMad(Function *, Instruction *, DataType);
+
+ BuildUtil bld;
+};
+
+bool
+Split64BitOpPreRA::visit(BasicBlock *bb)
+{
+ Instruction *i, *next;
+ Modifier mod;
+
+ for (i = bb->getEntry(); i; i = next) {
+ next = i->next;
+
+ DataType hTy;
+ switch (i->dType) {
+ case TYPE_U64: hTy = TYPE_U32; break;
+ case TYPE_S64: hTy = TYPE_S32; break;
+ default:
+ continue;
+ }
+
+ if (i->op == OP_MAD || i->op == OP_MUL)
+ split64MulMad(func, i, hTy);
+ }
+
+ return true;
+}
+
+void
+Split64BitOpPreRA::split64MulMad(Function *fn, Instruction *i, DataType hTy)
+{
+ assert(i->op == OP_MAD || i->op == OP_MUL);
+ assert(!isFloatType(i->dType) && !isFloatType(i->sType));
+ assert(typeSizeof(hTy) == 4);
+
+ bld.setPosition(i, true);
+
+ Value *zero = bld.mkImm(0u);
+ Value *carry = bld.getSSA(1, FILE_FLAGS);
+
+ // We want to compute `d = a * b (+ c)?`, where a, b, c and d are 64-bit
+ // values (a, b and c might be 32-bit values), using 32-bit operations. This
+ // gives the following operations:
+ // * `d.low = low(a.low * b.low) (+ c.low)?`
+ // * `d.high = low(a.high * b.low) + low(a.low * b.high)
+ // + high(a.low * b.low) (+ c.high)?`
+ //
+ // To compute the high bits, we can split in the following operations:
+ // * `tmp1 = low(a.high * b.low) (+ c.high)?`
+ // * `tmp2 = low(a.low * b.high) + tmp1`
+ // * `d.high = high(a.low * b.low) + tmp2`
+ //
+ // mkSplit put lower bits at index 0 and higher bits at index 1
+
+ Value *op1[2];
+ if (i->getSrc(0)->reg.size == 8)
+ bld.mkSplit(op1, 4, i->getSrc(0));
+ else {
+ op1[0] = i->getSrc(0);
+ op1[1] = zero;
+ }
+ Value *op2[2];
+ if (i->getSrc(1)->reg.size == 8)
+ bld.mkSplit(op2, 4, i->getSrc(1));
+ else {
+ op2[0] = i->getSrc(1);
+ op2[1] = zero;
+ }
+
+ Value *op3[2] = { NULL, NULL };
+ if (i->op == OP_MAD) {
+ if (i->getSrc(2)->reg.size == 8)
+ bld.mkSplit(op3, 4, i->getSrc(2));
+ else {
+ op3[0] = i->getSrc(2);
+ op3[1] = zero;
+ }
+ }
+
+ Value *tmpRes1Hi = bld.getSSA();
+ if (i->op == OP_MAD)
+ bld.mkOp3(OP_MAD, hTy, tmpRes1Hi, op1[1], op2[0], op3[1]);
+ else
+ bld.mkOp2(OP_MUL, hTy, tmpRes1Hi, op1[1], op2[0]);
+
+ Value *tmpRes2Hi = bld.mkOp3v(OP_MAD, hTy, bld.getSSA(), op1[0], op2[1], tmpRes1Hi);
+
+ Value *def[2] = { bld.getSSA(), bld.getSSA() };
+
+ // If it was a MAD, add the carry from the low bits
+ // It is not needed if it was a MUL, since we added high(a.low * b.low) to
+ // d.high
+ if (i->op == OP_MAD)
+ bld.mkOp3(OP_MAD, hTy, def[0], op1[0], op2[0], op3[0])->setFlagsDef(1, carry);
+ else
+ bld.mkOp2(OP_MUL, hTy, def[0], op1[0], op2[0]);
+
+ Instruction *hiPart3 = bld.mkOp3(OP_MAD, hTy, def[1], op1[0], op2[0], tmpRes2Hi);
+ hiPart3->subOp = NV50_IR_SUBOP_MUL_HIGH;
+ if (i->op == OP_MAD)
+ hiPart3->setFlagsSrc(3, carry);
+
+ bld.mkOp2(OP_MERGE, i->dType, i->getDef(0), def[0], def[1]);
+
+ delete_Instruction(fn->getProgram(), i);
+}
+
+// =============================================================================
+
static inline void
updateLdStOffset(Instruction *ldst, int32_t offset, Function *fn)
{
@@ -3552,6 +3667,7 @@ Program::optimizeSSA(int level)
RUN_PASS(2, ModifierFolding, run); // before load propagation -> less checks
RUN_PASS(1, ConstantFolding, foldAll);
RUN_PASS(2, LateAlgebraicOpt, run);
+ RUN_PASS(1, Split64BitOpPreRA, run);
RUN_PASS(1, LoadPropagation, run);
RUN_PASS(1, IndirectPropagation, run);
RUN_PASS(2, MemoryOpt, run);