diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/compiler/nir/nir_loop_analyze.c | 289 |
1 files changed, 149 insertions, 140 deletions
diff --git a/src/compiler/nir/nir_loop_analyze.c b/src/compiler/nir/nir_loop_analyze.c index c64314aa378..587cf08fa02 100644 --- a/src/compiler/nir/nir_loop_analyze.c +++ b/src/compiler/nir/nir_loop_analyze.c @@ -32,7 +32,10 @@ typedef enum { basic_induction } nir_loop_variable_type; -struct nir_basic_induction_var; +typedef struct nir_basic_induction_var { + nir_alu_instr *alu; /* The def of the alu-operation */ + nir_ssa_def *def_outside_loop; /* The phi-src outside the loop */ +} nir_basic_induction_var; typedef struct { /* A link for the work list */ @@ -57,13 +60,6 @@ typedef struct { } nir_loop_variable; -typedef struct nir_basic_induction_var { - nir_op alu_op; /* The type of alu-operation */ - nir_loop_variable *alu_def; /* The def of the alu-operation */ - nir_loop_variable *invariant; /* The invariant alu-operand */ - nir_loop_variable *def_outside_loop; /* The phi-src outside the loop */ -} nir_basic_induction_var; - typedef struct { /* The loop we store information for */ nir_loop *loop; @@ -300,6 +296,19 @@ phi_instr_as_alu(nir_phi_instr *phi) } static bool +alu_src_has_identity_swizzle(nir_alu_instr *alu, unsigned src_idx) +{ + assert(nir_op_infos[alu->op].input_sizes[src_idx] == 0); + assert(alu->dest.dest.is_ssa); + for (unsigned i = 0; i < alu->dest.dest.ssa.num_components; i++) { + if (alu->src[src_idx].swizzle[i] != i) + return false; + } + + return true; +} + +static bool compute_induction_information(loop_info_state *state) { bool found_induction_var = false; @@ -320,15 +329,10 @@ compute_induction_information(loop_info_state *state) if (!is_var_phi(var)) continue; - /* We only handle scalars because none of the rest of the loop analysis - * code can properly handle swizzles. 
- */ - if (var->def->num_components > 1) - continue; - nir_phi_instr *phi = nir_instr_as_phi(var->def->parent_instr); nir_basic_induction_var *biv = rzalloc(state, nir_basic_induction_var); + nir_loop_variable *alu_src_var = NULL; nir_foreach_phi_src(src, phi) { nir_loop_variable *src_var = get_loop_var(src->src.ssa, state); @@ -352,32 +356,36 @@ compute_induction_information(loop_info_state *state) } } - if (!src_var->in_loop) { - biv->def_outside_loop = src_var; - } else if (is_var_alu(src_var)) { + if (!src_var->in_loop && !biv->def_outside_loop) { + biv->def_outside_loop = src_var->def; + } else if (is_var_alu(src_var) && !biv->alu) { + alu_src_var = src_var; nir_alu_instr *alu = nir_instr_as_alu(src_var->def->parent_instr); if (nir_op_infos[alu->op].num_inputs == 2) { - biv->alu_def = src_var; - biv->alu_op = alu->op; - for (unsigned i = 0; i < 2; i++) { - /* Is one of the operands const, and the other the phi */ - if (alu->src[i].src.ssa->parent_instr->type == nir_instr_type_load_const && - alu->src[i].swizzle[0] == 0 && - alu->src[1-i].src.ssa == &phi->dest.ssa) - assert(alu->src[1-i].swizzle[0] == 0); - biv->invariant = get_loop_var(alu->src[i].src.ssa, state); + /* Is one of the operands const, and the other the phi. The + * phi source can't be swizzled in any way. 
+ */ + if (nir_src_is_const(alu->src[i].src) && + alu->src[1-i].src.ssa == &phi->dest.ssa && + alu_src_has_identity_swizzle(alu, 1 - i)) + biv->alu = alu; } } + + if (!biv->alu) + break; + } else { + biv->alu = NULL; + break; } } - if (biv->alu_def && biv->def_outside_loop && biv->invariant && - is_var_constant(biv->def_outside_loop)) { - assert(is_var_constant(biv->invariant)); - biv->alu_def->type = basic_induction; - biv->alu_def->ind = biv; + if (biv->alu && biv->def_outside_loop && + biv->def_outside_loop->parent_instr->type == nir_instr_type_load_const) { + alu_src_var->type = basic_induction; + alu_src_var->ind = biv; var->type = basic_induction; var->ind = biv; @@ -504,7 +512,7 @@ find_array_access_via_induction(loop_info_state *state, static bool guess_loop_limit(loop_info_state *state, nir_const_value *limit_val, - nir_loop_variable *basic_ind) + nir_ssa_scalar basic_ind) { unsigned min_array_size = 0; @@ -525,8 +533,10 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val, find_array_access_via_induction(state, nir_src_as_deref(intrin->src[0]), &array_idx); - if (basic_ind == array_idx && + if (array_idx && basic_ind.def == array_idx->def && (min_array_size == 0 || min_array_size > array_size)) { + /* Array indices are scalars */ + assert(basic_ind.def->num_components == 1); min_array_size = array_size; } @@ -537,8 +547,10 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val, find_array_access_via_induction(state, nir_src_as_deref(intrin->src[1]), &array_idx); - if (basic_ind == array_idx && + if (array_idx && basic_ind.def == array_idx->def && (min_array_size == 0 || min_array_size > array_size)) { + /* Array indices are scalars */ + assert(basic_ind.def->num_components == 1); min_array_size = array_size; } } @@ -547,7 +559,7 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val, if (min_array_size) { *limit_val = nir_const_value_for_uint(min_array_size, - basic_ind->def->bit_size); + 
basic_ind.def->bit_size); return true; } @@ -555,33 +567,22 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val, } static bool -try_find_limit_of_alu(nir_loop_variable *limit, nir_const_value *limit_val, +try_find_limit_of_alu(nir_ssa_scalar limit, nir_const_value *limit_val, nir_loop_terminator *terminator, loop_info_state *state) { - if(!is_var_alu(limit)) + if (!nir_ssa_scalar_is_alu(limit)) return false; - nir_alu_instr *limit_alu = nir_instr_as_alu(limit->def->parent_instr); - - if (limit_alu->op == nir_op_imin || - limit_alu->op == nir_op_fmin) { - /* We don't handle swizzles here */ - if (limit_alu->src[0].swizzle[0] > 0 || limit_alu->src[1].swizzle[0] > 0) - return false; - - limit = get_loop_var(limit_alu->src[0].src.ssa, state); - - if (!is_var_constant(limit)) - limit = get_loop_var(limit_alu->src[1].src.ssa, state); - - if (!is_var_constant(limit)) - return false; - - *limit_val = nir_instr_as_load_const(limit->def->parent_instr)->value[0]; - - terminator->exact_trip_count_unknown = true; - - return true; + nir_op limit_op = nir_ssa_scalar_alu_op(limit); + if (limit_op == nir_op_imin || limit_op == nir_op_fmin) { + for (unsigned i = 0; i < 2; i++) { + nir_ssa_scalar src = nir_ssa_scalar_chase_alu_src(limit, i); + if (nir_ssa_scalar_is_const(src)) { + *limit_val = nir_ssa_scalar_as_const_value(src); + terminator->exact_trip_count_unknown = true; + return true; + } + } } return false; @@ -696,14 +697,12 @@ test_iterations(int32_t iter_int, nir_const_value *step, static int calculate_iterations(nir_const_value *initial, nir_const_value *step, - nir_const_value *limit, nir_loop_variable *alu_def, - nir_alu_instr *cond_alu, nir_op alu_op, bool limit_rhs, + nir_const_value *limit, nir_alu_instr *alu, + nir_ssa_scalar cond, nir_op alu_op, bool limit_rhs, bool invert_cond) { assert(initial != NULL && step != NULL && limit != NULL); - nir_alu_instr *alu = nir_instr_as_alu(alu_def->def->parent_instr); - /* nir_op_isub should have been lowered 
away by this point */ assert(alu->op != nir_op_isub); @@ -735,8 +734,9 @@ calculate_iterations(nir_const_value *initial, nir_const_value *step, * condition and if so we assume we need to step the initial value. */ unsigned trip_offset = 0; - if (cond_alu->src[0].src.ssa == alu_def->def || - cond_alu->src[1].src.ssa == alu_def->def) { + nir_alu_instr *cond_alu = nir_instr_as_alu(cond.def->parent_instr); + if (cond_alu->src[0].src.ssa == &alu->dest.dest.ssa || + cond_alu->src[1].src.ssa == &alu->dest.dest.ssa) { trip_offset = 1; } @@ -774,9 +774,9 @@ calculate_iterations(nir_const_value *initial, nir_const_value *step, } static nir_op -inverse_comparison(nir_alu_instr *alu) +inverse_comparison(nir_op alu_op) { - switch (alu->op) { + switch (alu_op) { case nir_op_fge: return nir_op_flt; case nir_op_ige: @@ -803,29 +803,33 @@ inverse_comparison(nir_alu_instr *alu) } static bool -is_supported_terminator_condition(nir_alu_instr *alu) +is_supported_terminator_condition(nir_ssa_scalar cond) { + if (!nir_ssa_scalar_is_alu(cond)) + return false; + + nir_alu_instr *alu = nir_instr_as_alu(cond.def->parent_instr); return nir_alu_instr_is_comparison(alu) && nir_op_infos[alu->op].num_inputs == 2; } static bool -get_induction_and_limit_vars(nir_alu_instr *alu, - nir_loop_variable **ind, - nir_loop_variable **limit, +get_induction_and_limit_vars(nir_ssa_scalar cond, + nir_ssa_scalar *ind, + nir_ssa_scalar *limit, bool *limit_rhs, loop_info_state *state) { - nir_loop_variable *rhs, *lhs; - lhs = get_loop_var(alu->src[0].src.ssa, state); - rhs = get_loop_var(alu->src[1].src.ssa, state); + nir_ssa_scalar rhs, lhs; + lhs = nir_ssa_scalar_chase_alu_src(cond, 0); + rhs = nir_ssa_scalar_chase_alu_src(cond, 1); - if (lhs->type == basic_induction) { + if (get_loop_var(lhs.def, state)->type == basic_induction) { *ind = lhs; *limit = rhs; *limit_rhs = true; return true; - } else if (rhs->type == basic_induction) { + } else if (get_loop_var(rhs.def, state)->type == basic_induction) { *ind = 
rhs; *limit = lhs; *limit_rhs = false; @@ -836,53 +840,40 @@ get_induction_and_limit_vars(nir_alu_instr *alu, } static bool -try_find_trip_count_vars_in_iand(nir_alu_instr **alu, - nir_loop_variable **ind, - nir_loop_variable **limit, +try_find_trip_count_vars_in_iand(nir_ssa_scalar *cond, + nir_ssa_scalar *ind, + nir_ssa_scalar *limit, bool *limit_rhs, loop_info_state *state) { - assert((*alu)->op == nir_op_ieq || (*alu)->op == nir_op_inot); - - nir_ssa_def *iand_def = (*alu)->src[0].src.ssa; - /* This is used directly in an if condition so it must be a scalar */ - assert(iand_def->num_components == 1); + const nir_op alu_op = nir_ssa_scalar_alu_op(*cond); + assert(alu_op == nir_op_ieq || alu_op == nir_op_inot); - if ((*alu)->op == nir_op_ieq) { - nir_ssa_def *zero_def = (*alu)->src[1].src.ssa; + nir_ssa_scalar iand = nir_ssa_scalar_chase_alu_src(*cond, 0); - /* We don't handle swizzles here */ - if ((*alu)->src[0].swizzle[0] > 0 || (*alu)->src[1].swizzle[0] > 0) - return false; - - if (iand_def->parent_instr->type != nir_instr_type_alu || - zero_def->parent_instr->type != nir_instr_type_load_const) { + if (alu_op == nir_op_ieq) { + nir_ssa_scalar zero = nir_ssa_scalar_chase_alu_src(*cond, 1); + if (!nir_ssa_scalar_is_alu(iand) || !nir_ssa_scalar_is_const(zero)) { /* Maybe we had it the wrong way, flip things around */ - iand_def = (*alu)->src[1].src.ssa; - zero_def = (*alu)->src[0].src.ssa; + nir_ssa_scalar tmp = zero; + zero = iand; + iand = tmp; /* If we still didn't find what we need then return */ - if (zero_def->parent_instr->type != nir_instr_type_load_const) + if (!nir_ssa_scalar_is_const(zero)) return false; } /* If the loop is not breaking on (x && y) == 0 then return */ - nir_const_value *zero = - nir_instr_as_load_const(zero_def->parent_instr)->value; - if (zero[0].i32 != 0) + if (nir_ssa_scalar_as_uint(zero) != 0) return false; } - if (iand_def->parent_instr->type != nir_instr_type_alu) - return false; - - nir_alu_instr *iand = 
nir_instr_as_alu(iand_def->parent_instr); - if (iand->op != nir_op_iand) + if (!nir_ssa_scalar_is_alu(iand)) return false; - /* We don't handle swizzles here */ - if ((*alu)->src[0].swizzle[0] > 0 || (*alu)->src[1].swizzle[0] > 0) + if (nir_ssa_scalar_alu_op(iand) != nir_op_iand) return false; /* Check if iand src is a terminator condition and try get induction var @@ -890,19 +881,15 @@ try_find_trip_count_vars_in_iand(nir_alu_instr **alu, */ bool found_induction_var = false; for (unsigned i = 0; i < 2; i++) { - nir_ssa_def *src = iand->src[i].src.ssa; - if (src->parent_instr->type == nir_instr_type_alu) { - nir_alu_instr *src_alu = nir_instr_as_alu(src->parent_instr); - if (is_supported_terminator_condition(src_alu) && - get_induction_and_limit_vars(src_alu, ind, limit, - limit_rhs, state)) { - *alu = src_alu; - found_induction_var = true; - - /* If we've found one with a constant limit, stop. */ - if (is_var_constant(*limit)) - return true; - } + nir_ssa_scalar src = nir_ssa_scalar_chase_alu_src(iand, i); + if (is_supported_terminator_condition(src) && + get_induction_and_limit_vars(src, ind, limit, limit_rhs, state)) { + *cond = src; + found_induction_var = true; + + /* If we've found one with a constant limit, stop. */ + if (nir_ssa_scalar_is_const(*limit)) + return true; } } @@ -926,8 +913,10 @@ find_trip_count(loop_info_state *state) list_for_each_entry(nir_loop_terminator, terminator, &state->loop->info->loop_terminator_list, loop_terminator_link) { + assert(terminator->nif->condition.is_ssa); + nir_ssa_scalar cond = { terminator->nif->condition.ssa, 0 }; - if (terminator->conditional_instr->type != nir_instr_type_alu) { + if (!nir_ssa_scalar_is_alu(cond)) { /* If we get here the loop is dead and will get cleaned up by the * nir_opt_dead_cf pass. 
*/ @@ -935,27 +924,27 @@ find_trip_count(loop_info_state *state) continue; } - nir_alu_instr *alu = nir_instr_as_alu(terminator->conditional_instr); - nir_op alu_op = alu->op; + nir_op alu_op = nir_ssa_scalar_alu_op(cond); bool limit_rhs; - nir_loop_variable *basic_ind = NULL; - nir_loop_variable *limit; - if ((alu->op == nir_op_inot || alu->op == nir_op_ieq) && - try_find_trip_count_vars_in_iand(&alu, &basic_ind, &limit, + nir_ssa_scalar basic_ind = { NULL, 0 }; + nir_ssa_scalar limit; + if ((alu_op == nir_op_inot || alu_op == nir_op_ieq) && + try_find_trip_count_vars_in_iand(&cond, &basic_ind, &limit, &limit_rhs, state)) { + /* The loop is exiting on (x && y) == 0 so we need to get the * inverse of x or y (i.e. which ever contained the induction var) in * order to compute the trip count. */ - alu_op = inverse_comparison(alu); + alu_op = inverse_comparison(nir_ssa_scalar_alu_op(cond)); trip_count_known = false; terminator->exact_trip_count_unknown = true; } - if (!basic_ind) { - if (is_supported_terminator_condition(alu)) { - get_induction_and_limit_vars(alu, &basic_ind, + if (!basic_ind.def) { + if (is_supported_terminator_condition(cond)) { + get_induction_and_limit_vars(cond, &basic_ind, &limit, &limit_rhs, state); } } @@ -963,7 +952,7 @@ find_trip_count(loop_info_state *state) /* The comparison has to have a basic induction variable for us to be * able to find trip counts. 
*/ - if (!basic_ind) { + if (!basic_ind.def) { trip_count_known = false; continue; } @@ -972,9 +961,8 @@ find_trip_count(loop_info_state *state) /* Attempt to find a constant limit for the loop */ nir_const_value limit_val; - if (is_var_constant(limit)) { - limit_val = - nir_instr_as_load_const(limit->def->parent_instr)->value[0]; + if (nir_ssa_scalar_is_const(limit)) { + limit_val = nir_ssa_scalar_as_const_value(limit); } else { trip_count_known = false; @@ -996,17 +984,38 @@ find_trip_count(loop_info_state *state) * Thats all thats needed to calculate the trip-count */ - nir_const_value *initial_val = - nir_instr_as_load_const(basic_ind->ind->def_outside_loop-> - def->parent_instr)->value; + nir_basic_induction_var *ind_var = + get_loop_var(basic_ind.def, state)->ind; - nir_const_value *step_val = - nir_instr_as_load_const(basic_ind->ind->invariant->def-> - parent_instr)->value; + /* The basic induction var might be a vector but, because we guarantee + * earlier that the phi source has an identity swizzle, we can take the + * component from basic_ind. + */ + nir_ssa_scalar initial_s = { ind_var->def_outside_loop, basic_ind.comp }; + nir_ssa_scalar alu_s = { &ind_var->alu->dest.dest.ssa, basic_ind.comp }; + + nir_const_value initial_val = nir_ssa_scalar_as_const_value(initial_s); + + /* We are guaranteed by earlier code that at least one of these sources + * is a constant but we don't know which. 
+ */ + nir_const_value step_val; + memset(&step_val, 0, sizeof(step_val)); + UNUSED bool found_step_value = false; + assert(nir_op_infos[ind_var->alu->op].num_inputs == 2); + for (unsigned i = 0; i < 2; i++) { + nir_ssa_scalar alu_src = nir_ssa_scalar_chase_alu_src(alu_s, i); + if (nir_ssa_scalar_is_const(alu_src)) { + found_step_value = true; + step_val = nir_ssa_scalar_as_const_value(alu_src); + break; + } + } + assert(found_step_value); - int iterations = calculate_iterations(initial_val, step_val, + int iterations = calculate_iterations(&initial_val, &step_val, &limit_val, - basic_ind->ind->alu_def, alu, + ind_var->alu, cond, alu_op, limit_rhs, terminator->continue_from_then); |