diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/glsl/nir/nir_algebraic.py | 39 |
1 files changed, 32 insertions, 7 deletions
diff --git a/src/glsl/nir/nir_algebraic.py b/src/glsl/nir/nir_algebraic.py index 4929745dfa5..afab1a0084f 100644 --- a/src/glsl/nir/nir_algebraic.py +++ b/src/glsl/nir/nir_algebraic.py @@ -147,10 +147,23 @@ class Expression(Value): _optimization_ids = itertools.count() +condition_list = ['true'] + class SearchAndReplace(object): - def __init__(self, search, replace): + def __init__(self, transform): self.id = _optimization_ids.next() + search = transform[0] + replace = transform[1] + if len(transform) > 2: + self.condition = transform[2] + else: + self.condition = 'true' + + if self.condition not in condition_list: + condition_list.append(self.condition) + self.condition_index = condition_list.index(self.condition) + varset = VarSet() if isinstance(search, Expression): self.search = search @@ -171,6 +184,7 @@ _algebraic_pass_template = mako.template.Template(""" struct transform { const nir_search_expression *search; const nir_search_value *replace; + unsigned condition_offset; }; % for (opcode, xform_list) in xform_dict.iteritems(): @@ -181,7 +195,7 @@ struct transform { static const struct transform ${pass_name}_${opcode}_xforms[] = { % for xform in xform_list: - { &${xform.search.name}, ${xform.replace.c_ptr} }, + { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} }, % endfor }; % endfor @@ -189,6 +203,7 @@ static const struct transform ${pass_name}_${opcode}_xforms[] = { struct opt_state { void *mem_ctx; bool progress; + const bool *condition_flags; }; static bool @@ -209,7 +224,8 @@ ${pass_name}_block(nir_block *block, void *void_state) case nir_op_${opcode}: for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) { const struct transform *xform = &${pass_name}_${opcode}_xforms[i]; - if (nir_replace_instr(alu, xform->search, xform->replace, + if (state->condition_flags[xform->condition_offset] && + nir_replace_instr(alu, xform->search, xform->replace, state->mem_ctx)) { state->progress = true; break; @@ -226,12 +242,13 @@ ${pass_name}_block(nir_block *block, void *void_state) } static bool -${pass_name}_impl(nir_function_impl *impl) +${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags) { struct opt_state state; state.mem_ctx = ralloc_parent(impl); state.progress = false; + state.condition_flags = condition_flags; nir_foreach_block(impl, ${pass_name}_block, &state); @@ -242,14 +259,21 @@ ${pass_name}_impl(nir_function_impl *impl) return state.progress; } + bool ${pass_name}(nir_shader *shader) { bool progress = false; + bool condition_flags[${len(condition_list)}]; + const nir_shader_compiler_options *options = shader->options; + + % for index, condition in enumerate(condition_list): + condition_flags[${index}] = ${condition}; + % endfor nir_foreach_overload(shader, overload) { if (overload->impl) - progress |= ${pass_name}_impl(overload->impl); + progress |= ${pass_name}_impl(overload->impl, condition_flags); } return progress; @@ -263,7 +287,7 @@ class AlgebraicPass(object): for xform in transforms: if not isinstance(xform, SearchAndReplace): - xform = SearchAndReplace(*xform) + xform = SearchAndReplace(xform) if xform.search.opcode not in self.xform_dict: self.xform_dict[xform.search.opcode] = [] @@ -272,4 +296,5 @@ class AlgebraicPass(object): def render(self): return _algebraic_pass_template.render(pass_name=self.pass_name, - xform_dict=self.xform_dict) + xform_dict=self.xform_dict, + condition_list=condition_list) |