diff options
author | Karol Herbst <[email protected]> | 2018-12-07 09:44:55 +0100 |
---|---|---|
committer | Karol Herbst <[email protected]> | 2018-12-09 18:19:59 +0100 |
commit | 77944fb2b7c9b40539084f600b5df4fff18e9640 (patch) | |
tree | 18df186daef64746c8c119ff34a88035b570f3a4 /src/gallium/drivers | |
parent | d63a13308229b5c5a08358ccacdac83272596c78 (diff) |
nv50/ir: fix use-after-free in ConstantFolding::visit
opnd() might delete the passed in instruction, but it's used through
i->srcExists() later in visit
v2: use continue instead return
v3: use brackets for the outer if/else chain
Signed-off-by: Karol Herbst <[email protected]>
Reviewed-by: Ilia Mirkin <[email protected]>
Diffstat (limited to 'src/gallium/drivers')
-rw-r--r-- | src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp | 82 |
1 files changed, 49 insertions, 33 deletions
diff --git a/src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp b/src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp index 202faf0746a..5d3b4aba9cc 100644 --- a/src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp +++ b/src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp @@ -370,7 +370,8 @@ private: void expr(Instruction *, ImmediateValue&, ImmediateValue&); void expr(Instruction *, ImmediateValue&, ImmediateValue&, ImmediateValue&); - void opnd(Instruction *, ImmediateValue&, int s); + /* true if i was deleted */ + bool opnd(Instruction *i, ImmediateValue&, int s); void opnd3(Instruction *, ImmediateValue&); void unary(Instruction *, const ImmediateValue&); @@ -414,18 +415,21 @@ ConstantFolding::visit(BasicBlock *bb) if (i->srcExists(2) && i->src(0).getImmediate(src0) && i->src(1).getImmediate(src1) && - i->src(2).getImmediate(src2)) + i->src(2).getImmediate(src2)) { expr(i, src0, src1, src2); - else + } else if (i->srcExists(1) && - i->src(0).getImmediate(src0) && i->src(1).getImmediate(src1)) + i->src(0).getImmediate(src0) && i->src(1).getImmediate(src1)) { expr(i, src0, src1); - else - if (i->srcExists(0) && i->src(0).getImmediate(src0)) - opnd(i, src0, 0); - else - if (i->srcExists(1) && i->src(1).getImmediate(src1)) - opnd(i, src1, 1); + } else + if (i->srcExists(0) && i->src(0).getImmediate(src0)) { + if (opnd(i, src0, 0)) + continue; + } else + if (i->srcExists(1) && i->src(1).getImmediate(src1)) { + if (opnd(i, src1, 1)) + continue; + } if (i->srcExists(2) && i->src(2).getImmediate(src2)) opnd3(i, src2); } @@ -1011,12 +1015,13 @@ ConstantFolding::createMul(DataType ty, Value *def, Value *a, int64_t b, Value * return false; } -void +bool ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) { const int t = !s; const operation op = i->op; Instruction *newi = i; + bool deleted = false; switch (i->op) { case OP_SPLIT: { @@ -1036,6 +1041,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) val >>= bitsize; } delete_Instruction(prog, i); + deleted = true; break; } case OP_MUL: @@ -1050,6 +1056,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) newi = bld.mkCmp(OP_SET, CC_LT, TYPE_S32, i->getDef(0), TYPE_S32, i->getSrc(t), bld.mkImm(0)); delete_Instruction(prog, i); + deleted = true; } else if (imm0.isInteger(0) || imm0.isInteger(1)) { // The high bits can't be set in this case (either mul by 0 or // unsigned by 1) @@ -1101,8 +1108,10 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) if (!isFloatType(i->dType) && !i->src(t).mod) { bld.setPosition(i, false); int64_t b = typeSizeof(i->dType) == 8 ? imm0.reg.data.s64 : imm0.reg.data.s32; - if (createMul(i->dType, i->getDef(0), i->getSrc(t), b, NULL)) + if (createMul(i->dType, i->getDef(0), i->getSrc(t), b, NULL)) { delete_Instruction(prog, i); + deleted = true; + } } else if (i->postFactor && i->sType == TYPE_F32) { /* Can't emit a postfactor with an immediate, have to fold it in */ @@ -1139,8 +1148,10 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) if (!isFloatType(i->dType) && !i->subOp && !i->src(t).mod && !i->src(2).mod) { bld.setPosition(i, false); int64_t b = typeSizeof(i->dType) == 8 ? imm0.reg.data.s64 : imm0.reg.data.s32; - if (createMul(i->dType, i->getDef(0), i->getSrc(t), b, i->getSrc(2))) + if (createMul(i->dType, i->getDef(0), i->getSrc(t), b, i->getSrc(2))) { delete_Instruction(prog, i); + deleted = true; + } } break; case OP_SUB: @@ -1210,6 +1221,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) bld.mkOp2(OP_SHR, TYPE_U32, i->getDef(0), tB, bld.mkImm(s)); delete_Instruction(prog, i); + deleted = true; } else if (imm0.reg.data.s32 == -1) { i->op = OP_NEG; @@ -1242,6 +1254,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) bld.mkOp1(OP_NEG, TYPE_S32, i->getDef(0), tB); delete_Instruction(prog, i); + deleted = true; } break; @@ -1273,6 +1286,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) newi = bld.mkOp2(OP_UNION, TYPE_S32, i->getDef(0), v1, v2); delete_Instruction(prog, i); + deleted = true; } } else if (s == 1) { // In this case, we still want the optimized lowering that we get @@ -1289,6 +1303,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) newi->src(1).mod = Modifier(NV50_IR_MOD_NEG); delete_Instruction(prog, i); + deleted = true; } break; @@ -1301,7 +1316,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) CmpInstruction *si = findOriginForTestWithZero(i->getSrc(t)); CondCode cc, ccZ; if (imm0.reg.data.u32 != 0 || !si) - return; + return false; cc = si->setCond; ccZ = (CondCode)((unsigned int)i->asCmp()->setCond & ~CC_U); // We do everything assuming var (cmp) 0, reverse the condition if 0 is @@ -1327,7 +1342,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) case CC_GT: break; // bool > 0 -- bool case CC_NE: break; // bool != 0 -- bool default: - return; + return false; } // Update the condition of this SET to be identical to the origin set, @@ -1362,13 +1377,13 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) } else if (src->asCmp()) { CmpInstruction *cmp = src->asCmp(); if (!cmp || cmp->op == OP_SLCT || cmp->getDef(0)->refCount() > 1) - return; + return false; if (!prog->getTarget()->isOpSupported(cmp->op, TYPE_F32)) - return; + return false; if (imm0.reg.data.f32 != 1.0) - return; + return false; if (cmp->dType != TYPE_U32) - return; + return false; cmp->dType = TYPE_F32; if (i->src(t).mod != Modifier(0)) { @@ -1435,13 +1450,13 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) case OP_MUL: int muls; if (isFloatType(si->dType)) - return; + return false; if (si->src(1).getImmediate(imm1)) muls = 1; else if (si->src(0).getImmediate(imm1)) muls = 0; else - return; + return false; bld.setPosition(i, false); i->op = OP_MUL; @@ -1452,15 +1467,15 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) case OP_ADD: int adds; if (isFloatType(si->dType)) - return; + return false; if (si->op != OP_SUB && si->src(0).getImmediate(imm1)) adds = 0; else if (si->src(1).getImmediate(imm1)) adds = 1; else - return; + return false; if (si->src(!adds).mod != Modifier(0)) - return; + return false; // SHL(ADD(x, y), z) = ADD(SHL(x, z), SHL(y, z)) // This is more operations, but if one of x, y is an immediate, then @@ -1475,7 +1490,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) bld.mkImm(imm0.reg.data.u32))); break; default: - return; + return false; } } break; @@ -1500,7 +1515,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) case TYPE_S32: res = util_last_bit_signed(imm0.reg.data.s32) - 1; break; case TYPE_U32: res = util_last_bit(imm0.reg.data.u32) - 1; break; default: - return; + return false; } if (i->subOp == NV50_IR_SUBOP_BFIND_SAMT && res >= 0) res = 31 - res; @@ -1526,11 +1541,11 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) // TODO: handle 64-bit values properly if (typeSizeof(i->dType) == 8 || typeSizeof(i->sType) == 8) - return; + return false; // TODO: handle single byte/word extractions if (i->subOp) - return; + return false; bld.setPosition(i, true); /* make sure bld is init'ed */ @@ -1567,7 +1582,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) CLAMP(imm0.reg.data.u16, umin, umax) : \ imm0.reg.data.u16; \ break; \ - default: return; \ + default: return false; \ } \ i->setSrc(0, bld.mkImm(res.data.dst)); \ break @@ -1594,7 +1609,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) case TYPE_S16: res.data.f32 = (float) imm0.reg.data.s16; break; case TYPE_S32: res.data.f32 = (float) imm0.reg.data.s32; break; default: - return; + return false; } i->setSrc(0, bld.mkImm(res.data.f32)); break; @@ -1615,12 +1630,12 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) case TYPE_S16: res.data.f64 = (double) imm0.reg.data.s16; break; case TYPE_S32: res.data.f64 = (double) imm0.reg.data.s32; break; default: - return; + return false; } i->setSrc(0, bld.mkImm(res.data.f64)); break; default: - return; + return false; } #undef CASE @@ -1631,7 +1646,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) break; } default: - return; + return false; } // This can get left behind some of the optimizations which simplify @@ -1646,6 +1661,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s) if (newi->op != op) foldCount++; + return deleted; } // ============================================================================= |