summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/glsl/nir/nir.c54
-rw-r--r--src/glsl/nir/nir.h1
2 files changed, 55 insertions, 0 deletions
diff --git a/src/glsl/nir/nir.c b/src/glsl/nir/nir.c
index 3b74e424bbe..4100f9770cc 100644
--- a/src/glsl/nir/nir.c
+++ b/src/glsl/nir/nir.c
@@ -1605,6 +1605,60 @@ nir_ssa_def_init(nir_function_impl *impl, nir_instr *instr, nir_ssa_def *def,
def->num_components = num_components;
}
+struct ssa_def_rewrite_state {
+ void *mem_ctx;
+ nir_ssa_def *old;
+ nir_src new_src;
+};
+
+static bool
+ssa_def_rewrite_uses_src(nir_src *src, void *void_state)
+{
+ struct ssa_def_rewrite_state *state = void_state;
+
+ if (src->is_ssa && src->ssa == state->old)
+ *src = nir_src_copy(state->new_src, state->mem_ctx);
+
+ return true;
+}
+
+void
+nir_ssa_def_rewrite_uses(nir_ssa_def *def, nir_src new_src, void *mem_ctx)
+{
+ struct ssa_def_rewrite_state state;
+ state.mem_ctx = mem_ctx;
+ state.old = def;
+ state.new_src = new_src;
+
+ assert(!new_src.is_ssa || def != new_src.ssa);
+
+ struct set *new_uses, *new_if_uses;
+ if (new_src.is_ssa) {
+ new_uses = new_src.ssa->uses;
+ new_if_uses = new_src.ssa->if_uses;
+ } else {
+ new_uses = new_src.reg.reg->uses;
+ new_if_uses = new_src.reg.reg->if_uses;
+ }
+
+ struct set_entry *entry;
+ set_foreach(def->uses, entry) {
+ nir_instr *instr = (nir_instr *)entry->key;
+
+ _mesa_set_remove(def->uses, entry);
+ nir_foreach_src(instr, ssa_def_rewrite_uses_src, &state);
+ _mesa_set_add(new_uses, _mesa_hash_pointer(instr), instr);
+ }
+
+ set_foreach(def->if_uses, entry) {
+ nir_if *if_use = (nir_if *)entry->key;
+
+ _mesa_set_remove(def->if_uses, entry);
+ if_use->condition = nir_src_copy(new_src, mem_ctx);
+ _mesa_set_add(new_if_uses, _mesa_hash_pointer(if_use), if_use);
+ }
+}
+
static bool foreach_cf_node(nir_cf_node *node, nir_foreach_block_cb cb,
bool reverse, void *state);
diff --git a/src/glsl/nir/nir.h b/src/glsl/nir/nir.h
index 558ec914044..5933b5dc447 100644
--- a/src/glsl/nir/nir.h
+++ b/src/glsl/nir/nir.h
@@ -1281,6 +1281,7 @@ bool nir_foreach_src(nir_instr *instr, nir_foreach_src_cb cb, void *state);
void nir_ssa_def_init(nir_function_impl *impl, nir_instr *instr,
nir_ssa_def *def, unsigned num_components,
const char *name);
+void nir_ssa_def_rewrite_uses(nir_ssa_def *def, nir_src new_src, void *mem_ctx);
/* visits basic blocks in source-code order */
typedef bool (*nir_foreach_block_cb)(nir_block *block, void *state);