Diffstat (limited to 'src/compiler/nir/spirv/vtn_alu.c')
-rw-r--r-- | src/compiler/nir/spirv/vtn_alu.c | 464
1 files changed, 0 insertions, 464 deletions
diff --git a/src/compiler/nir/spirv/vtn_alu.c b/src/compiler/nir/spirv/vtn_alu.c
deleted file mode 100644
index 8b9a63ce760..00000000000
--- a/src/compiler/nir/spirv/vtn_alu.c
+++ /dev/null
@@ -1,464 +0,0 @@
-/*
- * Copyright © 2016 Intel Corporation
- *
- * Permission is hereby granted, free of charge, to any person obtaining a
- * copy of this software and associated documentation files (the "Software"),
- * to deal in the Software without restriction, including without limitation
- * the rights to use, copy, modify, merge, publish, distribute, sublicense,
- * and/or sell copies of the Software, and to permit persons to whom the
- * Software is furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice (including the next
- * paragraph) shall be included in all copies or substantial portions of the
- * Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
- * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
- * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
- * IN THE SOFTWARE.
- */
-
-#include "vtn_private.h"
-
-/*
- * Normally, column vectors in SPIR-V correspond to a single NIR SSA
- * definition. But for matrix multiplies, we want to do one routine for
- * multiplying a matrix by a matrix and then pretend that vectors are matrices
- * with one column. So we "wrap" these things, and unwrap the result before we
- * send it off.
- */
-
-static struct vtn_ssa_value *
-wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
-{
-   if (val == NULL)
-      return NULL;
-
-   if (glsl_type_is_matrix(val->type))
-      return val;
-
-   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
-   dest->type = val->type;
-   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
-   dest->elems[0] = val;
-
-   return dest;
-}
-
-static struct vtn_ssa_value *
-unwrap_matrix(struct vtn_ssa_value *val)
-{
-   if (glsl_type_is_matrix(val->type))
-      return val;
-
-   return val->elems[0];
-}
-
-static struct vtn_ssa_value *
-matrix_multiply(struct vtn_builder *b,
-                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
-{
-
-   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
-   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
-   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
-   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
-
-   unsigned src0_rows = glsl_get_vector_elements(src0->type);
-   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
-   unsigned src1_columns = glsl_get_matrix_columns(src1->type);
-
-   const struct glsl_type *dest_type;
-   if (src1_columns > 1) {
-      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
-                                   src0_rows, src1_columns);
-   } else {
-      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
-   }
-   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
-
-   dest = wrap_matrix(b, dest);
-
-   bool transpose_result = false;
-   if (src0_transpose && src1_transpose) {
-      /* transpose(A) * transpose(B) = transpose(B * A) */
-      src1 = src0_transpose;
-      src0 = src1_transpose;
-      src0_transpose = NULL;
-      src1_transpose = NULL;
-      transpose_result = true;
-   }
-
-   if (src0_transpose && !src1_transpose &&
-       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
-      /* We already have the rows of src0 and the columns of src1 available,
-       * so we can just take the dot product of each row with each column to
-       * get the result.
-       */
-
-      for (unsigned i = 0; i < src1_columns; i++) {
-         nir_ssa_def *vec_src[4];
-         for (unsigned j = 0; j < src0_rows; j++) {
-            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
-                                          src1->elems[i]->def);
-         }
-         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
-      }
-   } else {
-      /* We don't handle the case where src1 is transposed but not src0, since
-       * the general case only uses individual components of src1 so the
-       * optimizer should chew through the transpose we emitted for src1.
-       */
-
-      for (unsigned i = 0; i < src1_columns; i++) {
-         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
-         dest->elems[i]->def =
-            nir_fmul(&b->nb, src0->elems[0]->def,
-                     nir_channel(&b->nb, src1->elems[i]->def, 0));
-         for (unsigned j = 1; j < src0_columns; j++) {
-            dest->elems[i]->def =
-               nir_fadd(&b->nb, dest->elems[i]->def,
-                        nir_fmul(&b->nb, src0->elems[j]->def,
-                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
-         }
-      }
-   }
-
-   dest = unwrap_matrix(dest);
-
-   if (transpose_result)
-      dest = vtn_ssa_transpose(b, dest);
-
-   return dest;
-}
-
-static struct vtn_ssa_value *
-mat_times_scalar(struct vtn_builder *b,
-                 struct vtn_ssa_value *mat,
-                 nir_ssa_def *scalar)
-{
-   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
-   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
-      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
-         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
-      else
-         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
-   }
-
-   return dest;
-}
-
-static void
-vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
-                      struct vtn_value *dest,
-                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
-{
-   switch (opcode) {
-   case SpvOpFNegate: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
-      unsigned cols = glsl_get_matrix_columns(src0->type);
-      for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
-      break;
-   }
-
-   case SpvOpFAdd: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
-      unsigned cols = glsl_get_matrix_columns(src0->type);
-      for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def =
-            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
-      break;
-   }
-
-   case SpvOpFSub: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
-      unsigned cols = glsl_get_matrix_columns(src0->type);
-      for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def =
-            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
-      break;
-   }
-
-   case SpvOpTranspose:
-      dest->ssa = vtn_ssa_transpose(b, src0);
-      break;
-
-   case SpvOpMatrixTimesScalar:
-      if (src0->transposed) {
-         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
-                                                           src1->def));
-      } else {
-         dest->ssa = mat_times_scalar(b, src0, src1->def);
-      }
-      break;
-
-   case SpvOpVectorTimesMatrix:
-   case SpvOpMatrixTimesVector:
-   case SpvOpMatrixTimesMatrix:
-      if (opcode == SpvOpVectorTimesMatrix) {
-         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
-      } else {
-         dest->ssa = matrix_multiply(b, src0, src1);
-      }
-      break;
-
-   default: unreachable("unknown matrix opcode");
-   }
-}
-
-nir_op
-vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap)
-{
-   /* Indicates that the first two arguments should be swapped.  This is
-    * used for implementing greater-than and less-than-or-equal.
-    */
-   *swap = false;
-
-   switch (opcode) {
-   case SpvOpSNegate:            return nir_op_ineg;
-   case SpvOpFNegate:            return nir_op_fneg;
-   case SpvOpNot:                return nir_op_inot;
-   case SpvOpIAdd:               return nir_op_iadd;
-   case SpvOpFAdd:               return nir_op_fadd;
-   case SpvOpISub:               return nir_op_isub;
-   case SpvOpFSub:               return nir_op_fsub;
-   case SpvOpIMul:               return nir_op_imul;
-   case SpvOpFMul:               return nir_op_fmul;
-   case SpvOpUDiv:               return nir_op_udiv;
-   case SpvOpSDiv:               return nir_op_idiv;
-   case SpvOpFDiv:               return nir_op_fdiv;
-   case SpvOpUMod:               return nir_op_umod;
-   case SpvOpSMod:               return nir_op_imod;
-   case SpvOpFMod:               return nir_op_fmod;
-   case SpvOpSRem:               return nir_op_irem;
-   case SpvOpFRem:               return nir_op_frem;
-
-   case SpvOpShiftRightLogical:     return nir_op_ushr;
-   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
-   case SpvOpShiftLeftLogical:      return nir_op_ishl;
-   case SpvOpLogicalOr:             return nir_op_ior;
-   case SpvOpLogicalEqual:          return nir_op_ieq;
-   case SpvOpLogicalNotEqual:       return nir_op_ine;
-   case SpvOpLogicalAnd:            return nir_op_iand;
-   case SpvOpLogicalNot:            return nir_op_inot;
-   case SpvOpBitwiseOr:             return nir_op_ior;
-   case SpvOpBitwiseXor:            return nir_op_ixor;
-   case SpvOpBitwiseAnd:            return nir_op_iand;
-   case SpvOpSelect:                return nir_op_bcsel;
-   case SpvOpIEqual:                return nir_op_ieq;
-
-   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
-   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
-   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
-   case SpvOpBitReverse:            return nir_op_bitfield_reverse;
-   case SpvOpBitCount:              return nir_op_bit_count;
-
-   /* Comparisons: (TODO: How do we want to handled ordered/unordered?) */
-   case SpvOpFOrdEqual:                            return nir_op_feq;
-   case SpvOpFUnordEqual:                          return nir_op_feq;
-   case SpvOpINotEqual:                            return nir_op_ine;
-   case SpvOpFOrdNotEqual:                         return nir_op_fne;
-   case SpvOpFUnordNotEqual:                       return nir_op_fne;
-   case SpvOpULessThan:                            return nir_op_ult;
-   case SpvOpSLessThan:                            return nir_op_ilt;
-   case SpvOpFOrdLessThan:                         return nir_op_flt;
-   case SpvOpFUnordLessThan:                       return nir_op_flt;
-   case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
-   case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
-   case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
-   case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
-   case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
-   case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
-   case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
-   case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
-   case SpvOpUGreaterThanEqual:                    return nir_op_uge;
-   case SpvOpSGreaterThanEqual:                    return nir_op_ige;
-   case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
-   case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;
-
-   /* Conversions: */
-   case SpvOpConvertFToU:           return nir_op_f2u;
-   case SpvOpConvertFToS:           return nir_op_f2i;
-   case SpvOpConvertSToF:           return nir_op_i2f;
-   case SpvOpConvertUToF:           return nir_op_u2f;
-   case SpvOpBitcast:               return nir_op_imov;
-   case SpvOpUConvert:
-   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
-   /* TODO: NIR is 32-bit only; these are no-ops. */
-   case SpvOpSConvert:              return nir_op_imov;
-   case SpvOpFConvert:              return nir_op_fmov;
-
-   /* Derivatives: */
-   case SpvOpDPdx:         return nir_op_fddx;
-   case SpvOpDPdy:         return nir_op_fddy;
-   case SpvOpDPdxFine:     return nir_op_fddx_fine;
-   case SpvOpDPdyFine:     return nir_op_fddy_fine;
-   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
-   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;
-
-   default:
-      unreachable("No NIR equivalent");
-   }
-}
-
-static void
-handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
-                      const struct vtn_decoration *dec, void *_void)
-{
-   assert(dec->scope == VTN_DEC_DECORATION);
-   if (dec->decoration != SpvDecorationNoContraction)
-      return;
-
-   b->nb.exact = true;
-}
-
-void
-vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
-               const uint32_t *w, unsigned count)
-{
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-   const struct glsl_type *type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
-
-   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
-
-   /* Collect the various SSA sources */
-   const unsigned num_inputs = count - 3;
-   struct vtn_ssa_value *vtn_src[4] = { NULL, };
-   for (unsigned i = 0; i < num_inputs; i++)
-      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
-
-   if (glsl_type_is_matrix(vtn_src[0]->type) ||
-       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
-      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
-      b->nb.exact = false;
-      return;
-   }
-
-   val->ssa = vtn_create_ssa_value(b, type);
-   nir_ssa_def *src[4] = { NULL, };
-   for (unsigned i = 0; i < num_inputs; i++) {
-      assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
-      src[i] = vtn_src[i]->def;
-   }
-
-   switch (opcode) {
-   case SpvOpAny:
-      if (src[0]->num_components == 1) {
-         val->ssa->def = nir_imov(&b->nb, src[0]);
-      } else {
-         nir_op op;
-         switch (src[0]->num_components) {
-         case 2:  op = nir_op_bany_inequal2; break;
-         case 3:  op = nir_op_bany_inequal3; break;
-         case 4:  op = nir_op_bany_inequal4; break;
-         }
-         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
-                                       nir_imm_int(&b->nb, NIR_FALSE),
-                                       NULL, NULL);
-      }
-      break;
-
-   case SpvOpAll:
-      if (src[0]->num_components == 1) {
-         val->ssa->def = nir_imov(&b->nb, src[0]);
-      } else {
-         nir_op op;
-         switch (src[0]->num_components) {
-         case 2:  op = nir_op_ball_iequal2; break;
-         case 3:  op = nir_op_ball_iequal3; break;
-         case 4:  op = nir_op_ball_iequal4; break;
-         }
-         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
-                                       nir_imm_int(&b->nb, NIR_TRUE),
-                                       NULL, NULL);
-      }
-      break;
-
-   case SpvOpOuterProduct: {
-      for (unsigned i = 0; i < src[1]->num_components; i++) {
-         val->ssa->elems[i]->def =
-            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
-      }
-      break;
-   }
-
-   case SpvOpDot:
-      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
-      break;
-
-   case SpvOpIAddCarry:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
-      break;
-
-   case SpvOpISubBorrow:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
-      break;
-
-   case SpvOpUMulExtended:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
-      break;
-
-   case SpvOpSMulExtended:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
-      break;
-
-   case SpvOpFwidth:
-      val->ssa->def = nir_fadd(&b->nb,
-                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
-                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
-      break;
-   case SpvOpFwidthFine:
-      val->ssa->def = nir_fadd(&b->nb,
-                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
-                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
-      break;
-   case SpvOpFwidthCoarse:
-      val->ssa->def = nir_fadd(&b->nb,
-                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
-                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
-      break;
-
-   case SpvOpVectorTimesScalar:
-      /* The builder will take care of splatting for us. */
-      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
-      break;
-
-   case SpvOpIsNan:
-      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
-      break;
-
-   case SpvOpIsInf:
-      val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
-                              nir_imm_float(&b->nb, INFINITY));
-      break;
-
-   default: {
-      bool swap;
-      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap);
-
-      if (swap) {
-         nir_ssa_def *tmp = src[0];
-         src[0] = src[1];
-         src[1] = tmp;
-      }
-
-      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
-      break;
-   } /* default */
-   }
-
-   b->nb.exact = false;
-}