diff options
-rw-r--r-- | src/compiler/spirv/spirv_to_nir.c | 4 | ||||
-rw-r--r-- | src/compiler/spirv/vtn_alu.c | 29 | ||||
-rw-r--r-- | src/compiler/spirv/vtn_private.h | 3 |
3 files changed, 23 insertions, 13 deletions
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index a305858c928..f203ebc7ef7 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -1213,7 +1213,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, default: { bool swap; - nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap); + nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(val->const_type); + nir_alu_type src_alu_type = dst_alu_type; + nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type); unsigned num_components = glsl_get_vector_elements(val->const_type); unsigned bit_size = diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 95ff2b1aafe..55f7f2ea42f 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -211,7 +211,8 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode, } nir_op -vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap) +vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap, + nir_alu_type src, nir_alu_type dst) { /* Indicates that the first two arguments should be swapped. This is * used for implementing greater-than and less-than-or-equal. @@ -284,16 +285,16 @@ vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap) case SpvOpFUnordGreaterThanEqual: return nir_op_fge; /* Conversions: */ - case SpvOpConvertFToU: return nir_op_f2u; - case SpvOpConvertFToS: return nir_op_f2i; - case SpvOpConvertSToF: return nir_op_i2f; - case SpvOpConvertUToF: return nir_op_u2f; case SpvOpBitcast: return nir_op_imov; case SpvOpUConvert: case SpvOpQuantizeToF16: return nir_op_fquantize2f16; - /* TODO: NIR is 32-bit only; these are no-ops. */ - case SpvOpSConvert: return nir_op_imov; - case SpvOpFConvert: return nir_op_fmov; + case SpvOpConvertFToU: + case SpvOpConvertFToS: + case SpvOpConvertSToF: + case SpvOpConvertUToF: + case SpvOpSConvert: + case SpvOpFConvert: + return nir_type_conversion_op(src, dst); /* Derivatives: */ case SpvOpDPdx: return nir_op_fddx; @@ -457,7 +458,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, case SpvOpFUnordLessThanEqual: case SpvOpFUnordGreaterThanEqual: { bool swap; - nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap); + nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type); + nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type); + nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type); if (swap) { nir_ssa_def *tmp = src[0]; @@ -481,7 +484,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, case SpvOpFOrdLessThanEqual: case SpvOpFOrdGreaterThanEqual: { bool swap; - nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap); + nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type); + nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type); + nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type); if (swap) { nir_ssa_def *tmp = src[0]; @@ -500,7 +505,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, default: { bool swap; - nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap); + nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type); + nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type); + nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type); if (swap) { nir_ssa_def *tmp = src[0]; diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 9302611803f..ffa00d7f68a 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -481,7 +481,8 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct vtn_builder *, void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value *value, vtn_execution_mode_foreach_cb cb, void *data); -nir_op vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap); +nir_op vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap, + nir_alu_type src, nir_alu_type dst); void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count); |