-rw-r--r-- | src/compiler/nir/nir_algebraic.py | 396
-rw-r--r-- | src/compiler/nir/nir_search.c     |  45
-rw-r--r-- | src/compiler/nir/nir_search.h     |   3
3 files changed, 425 insertions, 19 deletions
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index 4779507fada..6db749e9248 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -51,6 +51,13 @@ conv_opcode_types = { 'f2b' : 'bool', } +def get_c_opcode(op): + if op in conv_opcode_types: + return 'nir_search_op_' + op + else: + return 'nir_op_' + op + + if sys.version_info < (3, 0): integer_types = (int, long) string_type = unicode @@ -347,10 +354,7 @@ class Expression(Value): return self.comm_exprs def c_opcode(self): - if self.opcode in conv_opcode_types: - return 'nir_search_op_' + self.opcode - else: - return 'nir_op_' + self.opcode + return get_c_opcode(self.opcode) def render(self, cache): srcs = "\n".join(src.render(cache) for src in self.sources) @@ -692,6 +696,266 @@ class SearchAndReplace(object): BitSizeValidator(varset).validate(self.search, self.replace) +class TreeAutomaton(object): + """This class calculates a bottom-up tree automaton to quickly search for + the left-hand sides of tranforms. Tree automatons are a generalization of + classical NFA's and DFA's, where the transition function determines the + state of the parent node based on the state of its children. We construct a + deterministic automaton to match patterns, using a similar algorithm to the + classical NFA to DFA construction. At the moment, it only matches opcodes + and constants (without checking the actual value), leaving more detailed + checking to the search function which actually checks the leaves. The + automaton acts as a quick filter for the search function, requiring only n + + 1 table lookups for each n-source operation. The implementation is based + on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit." + In the language of that reference, this is a frontier-to-root deterministic + automaton using only symbol filtering. The filtering is crucial to reduce + both the time taken to generate the tables and the size of the tables. + """ + def __init__(self, transforms): + self.patterns = [t.search for t in transforms] + self._compute_items() + self._build_table() + #print('num items: {}'.format(len(set(self.items.values())))) + #print('num states: {}'.format(len(self.states))) + #for state, patterns in zip(self.states, self.patterns): + # print('{}: num patterns: {}'.format(state, len(patterns))) + + class IndexMap(object): + """An indexed list of objects, where one can either lookup an object by + index or find the index associated to an object quickly using a hash + table. Compared to a list, it has a constant time index(). Compared to a + set, it provides a stable iteration order. + """ + def __init__(self, iterable=()): + self.objects = [] + self.map = {} + for obj in iterable: + self.add(obj) + + def __getitem__(self, i): + return self.objects[i] + + def __contains__(self, obj): + return obj in self.map + + def __len__(self): + return len(self.objects) + + def __iter__(self): + return iter(self.objects) + + def clear(self): + self.objects = [] + self.map.clear() + + def index(self, obj): + return self.map[obj] + + def add(self, obj): + if obj in self.map: + return self.map[obj] + else: + index = len(self.objects) + self.objects.append(obj) + self.map[obj] = index + return index + + def __repr__(self): + return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])' + + class Item(object): + """This represents an "item" in the language of "Tree Automatons." 
This + is just a subtree of some pattern, which represents a potential partial + match at runtime. We deduplicate them, so that identical subtrees of + different patterns share the same object, and store some extra + information needed for the main algorithm as well. + """ + def __init__(self, opcode, children): + self.opcode = opcode + self.children = children + # These are the indices of patterns for which this item is the root node. + self.patterns = [] + # This the set of opcodes for parents of this item. Used to speed up + # filtering. + self.parent_ops = set() + + def __str__(self): + return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')' + + def __repr__(self): + return str(self) + + def _compute_items(self): + """Build a set of all possible items, deduplicating them.""" + # This is a map from (opcode, sources) to item. + self.items = {} + + # The set of all opcodes used by the patterns. Used later to avoid + # building and emitting all the tables for opcodes that aren't used. + self.opcodes = self.IndexMap() + + def get_item(opcode, children, pattern=None): + commutative = len(children) == 2 \ + and "commutative" in opcodes[opcode].algebraic_properties + item = self.items.setdefault((opcode, children), + self.Item(opcode, children)) + if commutative: + self.items[opcode, (children[1], children[0])] = item + if pattern is not None: + item.patterns.append(pattern) + return item + + self.wildcard = get_item("__wildcard", ()) + self.const = get_item("__const", ()) + + def process_subpattern(src, pattern=None): + if isinstance(src, Constant): + # Note: we throw away the actual constant value! + return self.const + elif isinstance(src, Variable): + if src.is_constant: + return self.const + else: + # Note: we throw away which variable it is here! This special + # item is equivalent to nu in "Tree Automatons." + return self.wildcard + else: + assert isinstance(src, Expression) + opcode = src.opcode + stripped = opcode.rstrip('0123456789') + if stripped in conv_opcode_types: + # Matches that use conversion opcodes with a specific type, + # like f2b1, are tricky. Either we construct the automaton to + # match specific NIR opcodes like nir_op_f2b1, in which case we + # need to create separate items for each possible NIR opcode + # for patterns that have a generic opcode like f2b, or we + # construct it to match the search opcode, in which case we + # need to map f2b1 to f2b when constructing the automaton. Here + # we do the latter. + opcode = stripped + self.opcodes.add(opcode) + children = tuple(process_subpattern(c) for c in src.sources) + item = get_item(opcode, children, pattern) + for i, child in enumerate(children): + child.parent_ops.add(opcode) + return item + + for i, pattern in enumerate(self.patterns): + process_subpattern(pattern, i) + + def _build_table(self): + """This is the core algorithm which builds up the transition table. It + is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl . + Comp_a and Filt_{a,i} using integers to identify match sets." It + simultaneously builds up a list of all possible "match sets" or + "states", where each match set represents the set of Item's that match a + given instruction, and builds up the transition table between states. + """ + # Map from opcode + filtered state indices to transitioned state. + self.table = defaultdict(dict) + # Bijection from state to index. q in the original algorithm is + # len(self.states) + self.states = self.IndexMap() + # List of pattern matches for each state index. 
+ self.state_patterns = [] + # Map from state index to filtered state index for each opcode. + self.filter = defaultdict(list) + # Bijections from filtered state to filtered state index for each + # opcode, called the "representor sets" in the original algorithm. + # q_{a,j} in the original algorithm is len(self.rep[op]). + self.rep = defaultdict(self.IndexMap) + + # Everything in self.states with a index at least worklist_index is part + # of the worklist of newly created states. There is also a worklist of + # newly fitered states for each opcode, for which worklist_indices + # serves a similar purpose. worklist_index corresponds to p in the + # original algorithm, while worklist_indices is p_{a,j} (although since + # we only filter by opcode/symbol, it's really just p_a). + self.worklist_index = 0 + worklist_indices = defaultdict(lambda: 0) + + # This is the set of opcodes for which the filtered worklist is non-empty. + # It's used to avoid scanning opcodes for which there is nothing to + # process when building the transition table. It corresponds to new_a in + # the original algorithm. + new_opcodes = self.IndexMap() + + # Process states on the global worklist, filtering them for each opcode, + # updating the filter tables, and updating the filtered worklists if any + # new filtered states are found. Similar to ComputeRepresenterSets() in + # the original algorithm, although that only processes a single state. + def process_new_states(): + while self.worklist_index < len(self.states): + state = self.states[self.worklist_index] + + # Calculate pattern matches for this state. Each pattern is + # assigned to a unique item, so we don't have to worry about + # deduplicating them here. However, we do have to sort them so + # that they're visited at runtime in the order they're specified + # in the source. + patterns = list(sorted(p for item in state for p in item.patterns)) + assert len(self.state_patterns) == self.worklist_index + self.state_patterns.append(patterns) + + # calculate filter table for this state, and update filtered + # worklists. + for op in self.opcodes: + filt = self.filter[op] + rep = self.rep[op] + filtered = frozenset(item for item in state if \ + op in item.parent_ops) + if filtered in rep: + rep_index = rep.index(filtered) + else: + rep_index = rep.add(filtered) + new_opcodes.add(op) + assert len(filt) == self.worklist_index + filt.append(rep_index) + self.worklist_index += 1 + + # There are two start states: one which can only match as a wildcard, + # and one which can match as a wildcard or constant. These will be the + # states of intrinsics/other instructions and load_const instructions, + # respectively. The indices of these must match the definitions of + # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can + # initialize things correctly. + self.states.add(frozenset((self.wildcard,))) + self.states.add(frozenset((self.const,self.wildcard))) + process_new_states() + + while len(new_opcodes) > 0: + for op in new_opcodes: + rep = self.rep[op] + table = self.table[op] + op_worklist_index = worklist_indices[op] + if op in conv_opcode_types: + num_srcs = 1 + else: + num_srcs = opcodes[op].num_inputs + + # Iterate over all possible source combinations where at least one + # is on the worklist. 
+ for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): + if all(src_idx < op_worklist_index for src_idx in src_indices): + continue + + srcs = tuple(rep[src_idx] for src_idx in src_indices) + + # Try all possible pairings of source items and add the + # corresponding parent items. This is Comp_a from the paper. + parent = set(self.items[op, item_srcs] for item_srcs in + itertools.product(*srcs) if (op, item_srcs) in self.items) + + # We could always start matching something else with a + # wildcard. This is Cl from the paper. + parent.add(self.wildcard) + + table[src_indices] = self.states.add(frozenset(parent)) + worklist_indices[op] = len(rep) + new_opcodes.clear() + process_new_states() + _algebraic_pass_template = mako.template.Template(""" #include "nir.h" #include "nir_builder.h" @@ -707,6 +971,19 @@ struct transform { unsigned condition_offset; }; +struct per_op_table { + const uint16_t *filter; + unsigned num_filtered_states; + const uint16_t *table; +}; + +/* Note: these must match the start states created in + * TreeAutomaton._build_table() + */ + +/* WILDCARD_STATE = 0 is set by zeroing the state array */ +static const uint16_t CONST_STATE = 1; + #endif <% cache = {} %> @@ -715,17 +992,80 @@ struct transform { ${xform.replace.render(cache)} % endfor -% for (opcode, xform_list) in sorted(opcode_xforms.items()): -static const struct transform ${pass_name}_${opcode}_xforms[] = { -% for xform in xform_list: - { ${xform.search.c_ptr(cache)}, ${xform.replace.c_value_ptr(cache)}, ${xform.condition_index} }, +% for state_id, state_xforms in enumerate(automaton.state_patterns): +static const struct transform ${pass_name}_state${state_id}_xforms[] = { +% for i in state_xforms: + { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} }, % endfor }; % endfor +static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = { +% for op in automaton.opcodes: + [${get_c_opcode(op)}] = { + .filter = (uint16_t []) { + % for e in automaton.filter[op]: + ${e}, + % endfor + }, + <% + num_filtered = len(automaton.rep[op]) + %> + .num_filtered_states = ${num_filtered}, + .table = (uint16_t []) { + <% + num_srcs = len(next(iter(automaton.table[op]))) + %> + % for indices in itertools.product(range(num_filtered), repeat=num_srcs): + ${automaton.table[op][indices]}, + % endfor + }, + }, +% endfor +}; + +static void +${pass_name}_pre_block(nir_block *block, uint16_t *states) +{ + nir_foreach_instr(instr, block) { + switch (instr->type) { + case nir_instr_type_alu: { + nir_alu_instr *alu = nir_instr_as_alu(instr); + nir_op op = alu->op; + uint16_t search_op = nir_search_op_for_nir_op(op); + const struct per_op_table *tbl = &${pass_name}_table[search_op]; + if (tbl->num_filtered_states == 0) + continue; + + /* Calculate the index into the transition table. Note the index + * calculated must match the iteration order of Python's + * itertools.product(), which was used to emit the transition + * table. 
+ */ + uint16_t index = 0; + for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) { + index *= tbl->num_filtered_states; + index += tbl->filter[states[alu->src[i].src.ssa->index]]; + } + states[alu->dest.dest.ssa.index] = tbl->table[index]; + break; + } + + case nir_instr_type_load_const: { + nir_load_const_instr *load_const = nir_instr_as_load_const(instr); + states[load_const->def.index] = CONST_STATE; + break; + } + + default: + break; + } + } +} + static bool ${pass_name}_block(nir_builder *build, nir_block *block, - const bool *condition_flags) + const uint16_t *states, const bool *condition_flags) { bool progress = false; @@ -737,11 +1077,11 @@ ${pass_name}_block(nir_builder *build, nir_block *block, if (!alu->dest.dest.is_ssa) continue; - switch (alu->op) { - % for opcode in sorted(opcode_xforms.keys()): - case nir_op_${opcode}: - for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) { - const struct transform *xform = &${pass_name}_${opcode}_xforms[i]; + switch (states[alu->dest.dest.ssa.index]) { +% for i in range(len(automaton.state_patterns)): + case ${i}: + for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) { + const struct transform *xform = &${pass_name}_state${i}_xforms[i]; if (condition_flags[xform->condition_offset] && nir_replace_instr(build, alu, xform->search, xform->replace)) { progress = true; @@ -749,9 +1089,8 @@ ${pass_name}_block(nir_builder *build, nir_block *block, } } break; - % endfor - default: - break; +% endfor + default: assert(0); } } @@ -766,10 +1105,22 @@ ${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags) nir_builder build; nir_builder_init(&build, impl); + /* Note: it's important here that we're allocating a zeroed array, since + * state 0 is the default state, which means we don't have to visit + * anything other than constants and ALU instructions. 
+ */ + uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states)); + + nir_foreach_block(block, impl) { + ${pass_name}_pre_block(block, states); + } + nir_foreach_block_reverse(block, impl) { - progress |= ${pass_name}_block(&build, block, condition_flags); + progress |= ${pass_name}_block(&build, block, states, condition_flags); } + free(states); + if (progress) { nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance); @@ -806,6 +1157,8 @@ ${pass_name}(nir_shader *shader) } """) + + class AlgebraicPass(object): def __init__(self, pass_name, transforms): self.xforms = [] @@ -835,6 +1188,8 @@ class AlgebraicPass(object): else: self.opcode_xforms[xform.search.opcode].append(xform) + self.automaton = TreeAutomaton(self.xforms) + if error: sys.exit(1) @@ -843,4 +1198,7 @@ class AlgebraicPass(object): return _algebraic_pass_template.render(pass_name=self.pass_name, xforms=self.xforms, opcode_xforms=self.opcode_xforms, - condition_list=condition_list) + condition_list=condition_list, + automaton=self.automaton, + get_c_opcode=get_c_opcode, + itertools=itertools) diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index df27a2473ee..c8acdfb46b4 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -134,6 +134,50 @@ nir_op_matches_search_op(nir_op nop, uint16_t sop) #undef MATCH_FCONV_CASE #undef MATCH_ICONV_CASE +#undef MATCH_BCONV_CASE +} + +uint16_t +nir_search_op_for_nir_op(nir_op nop) +{ +#define MATCH_FCONV_CASE(op) \ + case nir_op_##op##16: \ + case nir_op_##op##32: \ + case nir_op_##op##64: \ + return nir_search_op_##op; + +#define MATCH_ICONV_CASE(op) \ + case nir_op_##op##8: \ + case nir_op_##op##16: \ + case nir_op_##op##32: \ + case nir_op_##op##64: \ + return nir_search_op_##op; + +#define MATCH_BCONV_CASE(op) \ + case nir_op_##op##1: \ + case nir_op_##op##32: \ + return nir_search_op_##op; + + + switch (nop) { + MATCH_FCONV_CASE(i2f) + MATCH_FCONV_CASE(u2f) + MATCH_FCONV_CASE(f2f) + MATCH_ICONV_CASE(f2u) + MATCH_ICONV_CASE(f2i) + MATCH_ICONV_CASE(u2u) + MATCH_ICONV_CASE(i2i) + MATCH_FCONV_CASE(b2f) + MATCH_ICONV_CASE(b2i) + MATCH_BCONV_CASE(i2b) + MATCH_BCONV_CASE(f2b) + default: + return nop; + } + +#undef MATCH_FCONV_CASE +#undef MATCH_ICONV_CASE +#undef MATCH_BCONV_CASE } static nir_op @@ -187,6 +231,7 @@ nir_op_for_search_op(uint16_t sop, unsigned bit_size) #undef RET_FCONV_CASE #undef RET_ICONV_CASE +#undef RET_BCONV_CASE } static bool diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h index 9dc09d2361c..526a498cd47 100644 --- a/src/compiler/nir/nir_search.h +++ b/src/compiler/nir/nir_search.h @@ -121,8 +121,11 @@ enum nir_search_op { nir_search_op_b2i, nir_search_op_i2b, nir_search_op_f2b, + nir_num_search_ops, }; +uint16_t nir_search_op_for_nir_op(nir_op op); + typedef struct { nir_search_value value; |
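
The TreeAutomaton docstring above describes the overall scheme: the generated pass first walks every block bottom-up, assigning a state to each SSA def with n + 1 table lookups per n-source ALU instruction, and the per-state transform arrays then act as a quick filter before nir_replace_instr() does the detailed match. Below is a minimal Python sketch of that state-assignment step only; the instruction encoding and table layout are invented for illustration, and the generated C code in _pre_block() flattens the per-source lookups into a single integer index rather than using a tuple key.

    # Minimal sketch of the bottom-up (frontier-to-root) state assignment done
    # by the generated _pre_block().  Names and data layout are illustrative.

    WILDCARD_STATE = 0   # default state: can only match as a wildcard
    CONST_STATE = 1      # load_const: can match as a wildcard or a constant

    def assign_states(instrs, filters, tables):
        """instrs: list of (dest, op, srcs) in program order; filters[op] maps
        a child state to a filtered state; tables[op] maps a tuple of filtered
        source states to the parent state."""
        states = {}
        for dest, op, srcs in instrs:
            if op == 'load_const':
                states[dest] = CONST_STATE
            elif op not in tables:
                # Opcode never appears in any pattern: stay in the wildcard
                # state (the C code gets this for free from the zeroed array).
                states[dest] = WILDCARD_STATE
            else:
                # One filter lookup per source plus one table lookup: n + 1.
                key = tuple(filters[op][states.get(s, WILDCARD_STATE)]
                            for s in srcs)
                states[dest] = tables[op][key]
        return states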
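IndexMap is the small helper the automaton construction leans on throughout: a list paired with a reverse hash map, so index() is constant time (unlike list.index()) and iteration order is the stable insertion order (unlike a set). A quick usage example, assuming nir_algebraic is importable (IndexMap is defined as a nested class of TreeAutomaton):

    from nir_algebraic import TreeAutomaton   # assumes the module is on sys.path

    im = TreeAutomaton.IndexMap(['fadd', 'fmul'])
    assert im.index('fmul') == 1            # O(1) reverse lookup via the hash map
    assert im.add('fadd') == 0              # re-adding returns the existing index
    assert im.add('ffma') == 2              # new objects get the next free index
    assert list(im) == ['fadd', 'fmul', 'ffma']   # stable insertion order
    assert 'ffma' in im and len(im) == 3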
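One wrinkle called out in _compute_items() is sized conversion opcodes: a pattern may use either a generic opcode like f2b or a sized one like f2b1, while at runtime the instruction always carries a sized nir_op. The automaton is built in terms of the generic search opcodes, so the Python side strips the trailing size digits and the new nir_search_op_for_nir_op() performs the equivalent mapping in C. A small sketch of that normalization (the opcode set is taken from the switch in nir_search.c; the helper name is made up):

    # Normalize sized conversion opcodes to the generic search opcodes,
    # mirroring the rstrip() in _compute_items() and nir_search_op_for_nir_op().
    CONV_OPCODES = {'i2f', 'u2f', 'f2f', 'f2u', 'f2i', 'u2u',
                    'i2i', 'b2f', 'b2i', 'i2b', 'f2b'}

    def search_opcode(opcode):
        stripped = opcode.rstrip('0123456789')
        return stripped if stripped in CONV_OPCODES else opcode

    assert search_opcode('f2b1') == 'f2b'
    assert search_opcode('f2f32') == 'f2f'
    assert search_opcode('fadd') == 'fadd'   # non-conversion opcodes are untouched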
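Finally, the comment in the generated _pre_block() stresses that the index computed in C must match the order in which the Python template flattened the transition table with itertools.product(). That holds because product() varies the last source fastest, which is exactly base-num_filtered_states positional notation; a quick self-contained Python check of the equivalence:

    import itertools

    # itertools.product() enumerates tuples with the last element varying
    # fastest, so the flattened table position equals the mixed-radix index
    # computed in the generated C loop
    # (index = index * num_filtered_states + filtered_src_state).
    num_filtered_states = 3
    num_srcs = 2

    for flat, srcs in enumerate(
            itertools.product(range(num_filtered_states), repeat=num_srcs)):
        index = 0
        for s in srcs:
            index = index * num_filtered_states + s
        assert index == flat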