diff options
-rw-r--r-- | src/compiler/nir/nir_opt_algebraic.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index dad0545594f..55e46b04466 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -884,6 +884,48 @@ for x, y in itertools.product(['f', 'u', 'i'], ['f', 'u', 'i']): x2yN = '{}2{}'.format(x, y) optimizations.append(((x2yN, (b2x, a)), (b2y, a))) +# Optimize away x2xN(a@N) +for t in ['int', 'uint', 'float']: + for N in type_sizes(t): + x2xN = '{0}2{0}{1}'.format(t[0], N) + aN = 'a@{0}'.format(N) + optimizations.append(((x2xN, aN), a)) + +# Optimize x2xN(y2yM(a@P)) -> y2yN(a) for integers +# In particular, we can optimize away everything except upcast of downcast and +# upcasts where the type differs from the other cast +for N, M in itertools.product(type_sizes('uint'), type_sizes('uint')): + if N < M: + # The outer cast is a down-cast. It doesn't matter what the size of the + # argument of the inner cast is because we'll never been in the upcast + # of downcast case. Regardless of types, we'll always end up with y2yN + # in the end. + for x, y in itertools.product(['i', 'u'], ['i', 'u']): + x2xN = '{0}2{0}{1}'.format(x, N) + y2yM = '{0}2{0}{1}'.format(y, M) + y2yN = '{0}2{0}{1}'.format(y, N) + optimizations.append(((x2xN, (y2yM, a)), (y2yN, a))) + elif N > M: + # If the outer cast is an up-cast, we have to be more careful about the + # size of the argument of the inner cast and with types. In this case, + # the type is always the type of type up-cast which is given by the + # outer cast. + for P in type_sizes('uint'): + # We can't optimize away up-cast of down-cast. + if M < P: + continue + + # Because we're doing down-cast of down-cast, the types always have + # to match between the two casts + for x in ['i', 'u']: + x2xN = '{0}2{0}{1}'.format(x, N) + x2xM = '{0}2{0}{1}'.format(x, M) + aP = 'a@{0}'.format(P) + optimizations.append(((x2xN, (x2xM, aP)), (x2xN, a))) + else: + # The N == M case is handled by other optimizations + pass + def fexp2i(exp, bits): # We assume that exp is already in the right range. if bits == 16: |