From fc9c4f89b85c0116c0dc22a3eaf25f5df88ad657 Mon Sep 17 00:00:00 2001
From: Jason Ekstrand <jason.ekstrand@intel.com>
Date: Thu, 13 Dec 2018 11:08:13 -0600
Subject: nir: Move propagation of cast derefs to a new nir_opt_deref pass
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

We're going to want to do more deref optimizations going forward and
this gives us a central place to do them.  Also, cast propagation will
get a bit more complicated with the addition of ptr_as_array derefs.

Reviewed-by: Alejandro PiƱeiro <apinheiro@igalia.com>
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
---
 src/amd/vulkan/radv_shader.c              |  2 +-
 src/compiler/nir/nir.h                    |  2 +
 src/compiler/nir/nir_deref.c              | 63 +++++++++++++++++++++++++++++++
 src/compiler/nir/nir_inline_functions.c   | 12 +++---
 src/compiler/nir/nir_opt_copy_propagate.c | 22 -----------
 src/intel/vulkan/anv_pipeline.c           |  2 +-
 src/mesa/main/glspirv.c                   |  2 +-
 7 files changed, 75 insertions(+), 30 deletions(-)

diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 7ad9abe8df8..34bfa447930 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -265,7 +265,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
 		NIR_PASS_V(nir, nir_lower_constant_initializers, nir_var_local);
 		NIR_PASS_V(nir, nir_lower_returns);
 		NIR_PASS_V(nir, nir_inline_functions);
-		NIR_PASS_V(nir, nir_copy_prop);
+		NIR_PASS_V(nir, nir_opt_deref);
 
 		/* Pick off the single entrypoint that we want */
 		foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index e72585000d4..e991c625536 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -3196,6 +3196,8 @@ bool nir_opt_dead_cf(nir_shader *shader);
 
 bool nir_opt_dead_write_vars(nir_shader *shader);
 
+bool nir_opt_deref(nir_shader *shader);
+
 bool nir_opt_find_array_copies(nir_shader *shader);
 
 bool nir_opt_gcm(nir_shader *shader, bool value_number);
diff --git a/src/compiler/nir/nir_deref.c b/src/compiler/nir/nir_deref.c
index 7a625bcf92b..1dffa285037 100644
--- a/src/compiler/nir/nir_deref.c
+++ b/src/compiler/nir/nir_deref.c
@@ -510,3 +510,66 @@ nir_rematerialize_derefs_in_use_blocks_impl(nir_function_impl *impl)
 
    return state.progress;
 }
+
+static bool
+is_trivial_deref_cast(nir_deref_instr *cast)
+{
+   nir_deref_instr *parent = nir_src_as_deref(cast->parent);
+   if (!parent)
+      return false;
+
+   return cast->mode == parent->mode &&
+          cast->type == parent->type &&
+          cast->dest.ssa.num_components == parent->dest.ssa.num_components &&
+          cast->dest.ssa.bit_size == parent->dest.ssa.bit_size;
+}
+
+static bool
+nir_opt_deref_impl(nir_function_impl *impl)
+{
+   bool progress = false;
+
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr_safe(instr, block) {
+         if (instr->type != nir_instr_type_deref)
+            continue;
+
+         nir_deref_instr *deref = nir_instr_as_deref(instr);
+         switch (deref->deref_type) {
+         case nir_deref_type_cast:
+            if (is_trivial_deref_cast(deref)) {
+               assert(deref->parent.is_ssa);
+               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
+                                        nir_src_for_ssa(deref->parent.ssa));
+               nir_instr_remove(&deref->instr);
+               progress = true;
+            }
+            break;
+
+         default:
+            /* Do nothing */
+            break;
+         }
+      }
+   }
+
+   if (progress) {
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
+   }
+
+   return progress;
+}
+
+bool
+nir_opt_deref(nir_shader *shader)
+{
+   bool progress = false;
+
+   nir_foreach_function(func, shader) {
+      if (func->impl && nir_opt_deref_impl(func->impl))
+         progress = true;
+   }
+
+   return progress;
+}
diff --git a/src/compiler/nir/nir_inline_functions.c b/src/compiler/nir/nir_inline_functions.c
index 29474bb417b..864638d2315 100644
--- a/src/compiler/nir/nir_inline_functions.c
+++ b/src/compiler/nir/nir_inline_functions.c
@@ -164,14 +164,16 @@ inline_function_impl(nir_function_impl *impl, struct set *inlined)
  *     This does the actual function inlining and the resulting shader will
  *     contain no call instructions.
  *
- *  4. nir_copy_prop(shader)
+ *  4. nir_opt_deref(shader)
  *
  *     Most functions contain pointer parameters where the result of a deref
  *     instruction is passed in as a parameter, loaded via a load_param
- *     intrinsic, and then turned back into a deref via a cast.  Running copy
- *     propagation gets rid of the intermediate steps and results in a whole
- *     deref chain again.  This is currently required by a number of
- *     optimizations and lowering passes at least for certain variable modes.
+ *     intrinsic, and then turned back into a deref via a cast.  Function
+ *     inlining will get rid of the load_param but we are still left with a
+ *     cast.  Running nir_opt_deref gets rid of the intermediate cast and
+ *     results in a whole deref chain again.  This is currently required by a
+ *     number of optimizations and lowering passes at least for certain
+ *     variable modes.
  *
  *  5. Loop over the functions and delete all but the main entrypoint.
  *
diff --git a/src/compiler/nir/nir_opt_copy_propagate.c b/src/compiler/nir/nir_opt_copy_propagate.c
index 189d544979b..7673e9b62dd 100644
--- a/src/compiler/nir/nir_opt_copy_propagate.c
+++ b/src/compiler/nir/nir_opt_copy_propagate.c
@@ -98,22 +98,6 @@ is_swizzleless_move(nir_alu_instr *instr)
    }
 }
 
-static bool
-is_trivial_deref_cast(nir_deref_instr *cast)
-{
-   if (cast->deref_type != nir_deref_type_cast)
-      return false;
-
-   nir_deref_instr *parent = nir_src_as_deref(cast->parent);
-   if (!parent)
-      return false;
-
-   return cast->mode == parent->mode &&
-          cast->type == parent->type &&
-          cast->dest.ssa.num_components == parent->dest.ssa.num_components &&
-          cast->dest.ssa.bit_size == parent->dest.ssa.bit_size;
-}
-
 static bool
 copy_prop_src(nir_src *src, nir_instr *parent_instr, nir_if *parent_if,
               unsigned num_components)
@@ -135,12 +119,6 @@ copy_prop_src(nir_src *src, nir_instr *parent_instr, nir_if *parent_if,
          return false;
 
       copy_def= alu_instr->src[0].src.ssa;
-   } else if (src_instr->type == nir_instr_type_deref) {
-      nir_deref_instr *deref_instr = nir_instr_as_deref(src_instr);
-      if (!is_trivial_deref_cast(deref_instr))
-         return false;
-
-      copy_def = deref_instr->parent.ssa;
    } else {
       return false;
    }
diff --git a/src/intel/vulkan/anv_pipeline.c b/src/intel/vulkan/anv_pipeline.c
index 6db9945e0d4..10afe0825ab 100644
--- a/src/intel/vulkan/anv_pipeline.c
+++ b/src/intel/vulkan/anv_pipeline.c
@@ -184,7 +184,7 @@ anv_shader_compile_to_nir(struct anv_pipeline *pipeline,
    NIR_PASS_V(nir, nir_lower_constant_initializers, nir_var_local);
    NIR_PASS_V(nir, nir_lower_returns);
    NIR_PASS_V(nir, nir_inline_functions);
-   NIR_PASS_V(nir, nir_copy_prop);
+   NIR_PASS_V(nir, nir_opt_deref);
 
    /* Pick off the single entrypoint that we want */
    foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
diff --git a/src/mesa/main/glspirv.c b/src/mesa/main/glspirv.c
index 04e46ba571e..ec1edf11cfa 100644
--- a/src/mesa/main/glspirv.c
+++ b/src/mesa/main/glspirv.c
@@ -245,7 +245,7 @@ _mesa_spirv_to_nir(struct gl_context *ctx,
    NIR_PASS_V(nir, nir_lower_constant_initializers, nir_var_local);
    NIR_PASS_V(nir, nir_lower_returns);
    NIR_PASS_V(nir, nir_inline_functions);
-   NIR_PASS_V(nir, nir_copy_prop);
+   NIR_PASS_V(nir, nir_opt_deref);
 
    /* Pick off the single entrypoint that we want */
    foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
-- 
cgit v1.2.3