summaryrefslogtreecommitdiffstats
path: root/src/compiler/spirv
diff options
context:
space:
mode:
authorJason Ekstrand <[email protected]>2017-08-29 20:10:35 -0700
committerJason Ekstrand <[email protected]>2018-03-07 12:13:47 -0800
commit57bff0a546c8ebe9a09335200719cb9e13d6aea9 (patch)
treefad82c7e452f7f94e8103002cc5ce1b70f00e84f /src/compiler/spirv
parent789221dcfa5df3c88e28978c90ccfb9eafb30e10 (diff)
spirv: Add support for subgroup arithmetic
Reviewed-by: Lionel Landwerlin <[email protected]> Reviewed-by: Iago Toral Quiroga <[email protected]>
Diffstat (limited to 'src/compiler/spirv')
-rw-r--r--src/compiler/spirv/spirv_to_nir.c4
-rw-r--r--src/compiler/spirv/vtn_subgroup.c97
2 files changed, 93 insertions, 8 deletions
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index a8b545ec866..19862ab612f 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -3313,6 +3313,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvCapabilityGroupNonUniformQuad:
spv_check_supported(subgroup_quad, cap);
+ case SpvCapabilityGroupNonUniformArithmetic:
+ case SpvCapabilityGroupNonUniformClustered:
+ spv_check_supported(subgroup_arithmetic, cap);
+
case SpvCapabilityVariablePointersStorageBuffer:
case SpvCapabilityVariablePointers:
spv_check_supported(variable_pointers, cap);
diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c
index 1204c5945c8..bd3143962be 100644
--- a/src/compiler/spirv/vtn_subgroup.c
+++ b/src/compiler/spirv/vtn_subgroup.c
@@ -28,7 +28,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
nir_intrinsic_op nir_op,
struct vtn_ssa_value *dst,
struct vtn_ssa_value *src0,
- nir_ssa_def *index)
+ nir_ssa_def *index,
+ unsigned const_idx0,
+ unsigned const_idx1)
{
/* Some of the subgroup operations take an index. SPIR-V allows this to be
* any integer type. To make things simpler for drivers, we only support
@@ -41,7 +43,8 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
if (!glsl_type_is_vector_or_scalar(dst->type)) {
for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
- src0->elems[i], index);
+ src0->elems[i], index,
+ const_idx0, const_idx1);
}
return;
}
@@ -56,6 +59,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
if (index)
intrin->src[1] = nir_src_for_ssa(index);
+ intrin->const_index[0] = const_idx0;
+ intrin->const_index[1] = const_idx1;
+
nir_builder_instr_insert(&b->nb, &intrin->instr);
dst->def = &intrin->dest.ssa;
@@ -169,13 +175,13 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
case SpvOpGroupNonUniformBroadcastFirst:
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
- val->ssa, vtn_ssa_value(b, w[4]), NULL);
+ val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
break;
case SpvOpGroupNonUniformBroadcast:
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
val->ssa, vtn_ssa_value(b, w[4]),
- vtn_ssa_value(b, w[5])->def);
+ vtn_ssa_value(b, w[5])->def, 0, 0);
break;
case SpvOpGroupNonUniformAll:
@@ -248,14 +254,14 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
unreachable("Invalid opcode");
}
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
- vtn_ssa_value(b, w[5])->def);
+ vtn_ssa_value(b, w[5])->def, 0, 0);
break;
}
case SpvOpGroupNonUniformQuadBroadcast:
vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
val->ssa, vtn_ssa_value(b, w[4]),
- vtn_ssa_value(b, w[5])->def);
+ vtn_ssa_value(b, w[5])->def, 0, 0);
break;
case SpvOpGroupNonUniformQuadSwap: {
@@ -272,7 +278,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
op = nir_intrinsic_quad_swap_diagonal;
break;
}
- vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), NULL);
+ vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+ NULL, 0, 0);
break;
}
@@ -291,7 +298,81 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
case SpvOpGroupNonUniformBitwiseXor:
case SpvOpGroupNonUniformLogicalAnd:
case SpvOpGroupNonUniformLogicalOr:
- case SpvOpGroupNonUniformLogicalXor:
+ case SpvOpGroupNonUniformLogicalXor: {
+ nir_op reduction_op;
+ switch (opcode) {
+ case SpvOpGroupNonUniformIAdd:
+ reduction_op = nir_op_iadd;
+ break;
+ case SpvOpGroupNonUniformFAdd:
+ reduction_op = nir_op_fadd;
+ break;
+ case SpvOpGroupNonUniformIMul:
+ reduction_op = nir_op_imul;
+ break;
+ case SpvOpGroupNonUniformFMul:
+ reduction_op = nir_op_fmul;
+ break;
+ case SpvOpGroupNonUniformSMin:
+ reduction_op = nir_op_imin;
+ break;
+ case SpvOpGroupNonUniformUMin:
+ reduction_op = nir_op_umin;
+ break;
+ case SpvOpGroupNonUniformFMin:
+ reduction_op = nir_op_fmin;
+ break;
+ case SpvOpGroupNonUniformSMax:
+ reduction_op = nir_op_imax;
+ break;
+ case SpvOpGroupNonUniformUMax:
+ reduction_op = nir_op_umax;
+ break;
+ case SpvOpGroupNonUniformFMax:
+ reduction_op = nir_op_fmax;
+ break;
+ case SpvOpGroupNonUniformBitwiseAnd:
+ case SpvOpGroupNonUniformLogicalAnd:
+ reduction_op = nir_op_iand;
+ break;
+ case SpvOpGroupNonUniformBitwiseOr:
+ case SpvOpGroupNonUniformLogicalOr:
+ reduction_op = nir_op_ior;
+ break;
+ case SpvOpGroupNonUniformBitwiseXor:
+ case SpvOpGroupNonUniformLogicalXor:
+ reduction_op = nir_op_ixor;
+ break;
+ default:
+ unreachable("Invalid reduction operation");
+ }
+
+ nir_intrinsic_op op;
+ unsigned cluster_size = 0;
+ switch ((SpvGroupOperation)w[4]) {
+ case SpvGroupOperationReduce:
+ op = nir_intrinsic_reduce;
+ break;
+ case SpvGroupOperationInclusiveScan:
+ op = nir_intrinsic_inclusive_scan;
+ break;
+ case SpvGroupOperationExclusiveScan:
+ op = nir_intrinsic_exclusive_scan;
+ break;
+ case SpvGroupOperationClusteredReduce:
+ op = nir_intrinsic_reduce;
+ assert(count == 7);
+ cluster_size = vtn_constant_value(b, w[6])->values[0].u32[0];
+ break;
+ default:
+ unreachable("Invalid group operation");
+ }
+
+ vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
+ NULL, reduction_op, cluster_size);
+ break;
+ }
+
default:
unreachable("Invalid SPIR-V opcode");
}