[spirv-reader][ir] Split ShaderIO GetParameter apart.
Split the entry point specific code for `GetParameter` out of the SPIR-V
ShaderIO lowering code into its own method.
Bug: 42250952
Change-Id: Id4f1840d775dc85e27711714f2d803603b673ea6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/247154
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/reader/lower/shader_io.cc b/src/tint/lang/spirv/reader/lower/shader_io.cc
index dbe23a5..a675fb6 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io.cc
@@ -409,60 +409,148 @@
});
}
+ core::ir::Value* GetEntryPointParameter(core::ir::Function* func, core::ir::Var* var) {
+ auto* var_type = var->Result()->Type()->UnwrapPtr();
+
+ // The SPIR_V type may not match the required WGSL entry point type, swap them as
+ // needed.
+ if (var->Attributes().builtin.has_value()) {
+ switch (var->Attributes().builtin.value()) {
+ case core::BuiltinValue::kSampleMask: {
+ TINT_ASSERT(var_type->Is<core::type::Array>());
+ TINT_ASSERT(var_type->As<core::type::Array>()->ConstantCount() == 1u);
+ var_type = ty.u32();
+ break;
+ }
+ case core::BuiltinValue::kInstanceIndex:
+ case core::BuiltinValue::kVertexIndex:
+ case core::BuiltinValue::kLocalInvocationIndex:
+ case core::BuiltinValue::kSubgroupInvocationId:
+ case core::BuiltinValue::kSubgroupSize:
+ case core::BuiltinValue::kSampleIndex: {
+ var_type = ty.u32();
+ break;
+ }
+ case core::BuiltinValue::kLocalInvocationId:
+ case core::BuiltinValue::kGlobalInvocationId:
+ case core::BuiltinValue::kWorkgroupId:
+ case core::BuiltinValue::kNumWorkgroups: {
+ var_type = ty.vec3<u32>();
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ }
+
+ // Create a new function parameter for the input.
+ auto* param = b.FunctionParam(var_type);
+ func->AppendParam(param);
+ if (auto name = ir.NameOf(var)) {
+ ir.SetName(param, name);
+ }
+
+ // Add attributes to the parameter
+ if (auto* str = param->Type()->UnwrapPtr()->As<core::type::Struct>()) {
+ for (const auto* const_member : str->Members()) {
+ // TODO(crbug.com/tint/745): Remove the const_cast.
+ auto* member = const_cast<core::type::StructMember*>(const_member);
+
+ // Use the base variable attributes if not specified directly on the member.
+ auto member_attributes = member->Attributes();
+ if (auto base_loc = var->Attributes().location) {
+ // Location values increment from the base location value on the variable.
+ member->SetLocation(base_loc.value() + member->Index());
+ }
+ if (!member_attributes.interpolation) {
+ member->SetInterpolation(var->Attributes().interpolation);
+ }
+ }
+ } else {
+ // Set attributes directly on the function parameter.
+ param->SetAttributes(var->Attributes());
+ }
+
+ core::ir::Value* result = param;
+ if (var->Attributes().builtin.has_value()) {
+ switch (var->Attributes().builtin.value()) {
+ case core::BuiltinValue::kSampleMask: {
+ // Construct an array from the scalar sample_mask builtin value for entry
+ // points.
+
+ auto* mask_ty = var->Result()->Type()->UnwrapPtr()->As<core::type::Array>();
+ TINT_ASSERT(mask_ty);
+
+ // If the SPIR-V mask was an i32, need to convert from the u32 provided by
+ // WGSL.
+ if (mask_ty->ElemType()->IsSignedIntegerScalar()) {
+ auto* conv = b.Convert(ty.i32(), result);
+ func->Block()->Prepend(conv);
+
+ auto* construct = b.Construct(mask_ty, conv);
+ construct->InsertAfter(conv);
+ result = construct->Result();
+ } else {
+ auto* construct = b.Construct(mask_ty, result);
+ func->Block()->Prepend(construct);
+ result = construct->Result();
+ }
+ break;
+ }
+ case core::BuiltinValue::kInstanceIndex:
+ case core::BuiltinValue::kVertexIndex:
+ case core::BuiltinValue::kLocalInvocationIndex:
+ case core::BuiltinValue::kSubgroupInvocationId:
+ case core::BuiltinValue::kSubgroupSize:
+ case core::BuiltinValue::kSampleIndex: {
+ auto* idx_ty = var->Result()->Type()->UnwrapPtr();
+ if (idx_ty->IsSignedIntegerScalar()) {
+ auto* conv = b.Convert(ty.i32(), result);
+ func->Block()->Prepend(conv);
+ result = conv->Result();
+ }
+ break;
+ }
+ case core::BuiltinValue::kLocalInvocationId:
+ case core::BuiltinValue::kGlobalInvocationId:
+ case core::BuiltinValue::kWorkgroupId:
+ case core::BuiltinValue::kNumWorkgroups: {
+ auto* idx_ty = var->Result()->Type()->UnwrapPtr();
+ auto* elem_ty = idx_ty->DeepestElement();
+ if (elem_ty->IsSignedIntegerScalar()) {
+ auto* conv = b.Convert(ty.MatchWidth(ty.i32(), idx_ty), result);
+ func->Block()->Prepend(conv);
+ result = conv->Result();
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ }
+ return result;
+ }
+
/// Get or create a function parameter to replace a module-scope variable.
/// @param func the function
/// @param var the module-scope variable
/// @returns the function parameter
core::ir::Value* GetParameter(core::ir::Function* func, core::ir::Var* var) {
- return function_parameter_map.GetOrAddZero(func).GetOrAdd(var, [&] {
- const bool entry_point = func->IsEntryPoint();
- auto* var_type = var->Result()->Type()->UnwrapPtr();
-
- // The SPIR_V type may not match the required WGSL entry point type, swap them as
- // needed.
- if (entry_point && var->Attributes().builtin.has_value()) {
- switch (var->Attributes().builtin.value()) {
- case core::BuiltinValue::kSampleMask: {
- TINT_ASSERT(var_type->Is<core::type::Array>());
- TINT_ASSERT(var_type->As<core::type::Array>()->ConstantCount() == 1u);
- var_type = ty.u32();
- break;
- }
- case core::BuiltinValue::kInstanceIndex:
- case core::BuiltinValue::kVertexIndex:
- case core::BuiltinValue::kLocalInvocationIndex:
- case core::BuiltinValue::kSubgroupInvocationId:
- case core::BuiltinValue::kSubgroupSize:
- case core::BuiltinValue::kSampleIndex: {
- var_type = ty.u32();
- break;
- }
- case core::BuiltinValue::kLocalInvocationId:
- case core::BuiltinValue::kGlobalInvocationId:
- case core::BuiltinValue::kWorkgroupId:
- case core::BuiltinValue::kNumWorkgroups: {
- var_type = ty.vec3<u32>();
- break;
- }
- default: {
- break;
- }
- }
+ return function_parameter_map.GetOrAddZero(func).GetOrAdd(var, [&]() -> core::ir::Value* {
+ if (func->IsEntryPoint()) {
+ return GetEntryPointParameter(func, var);
}
// Create a new function parameter for the input.
- auto* param = b.FunctionParam(var_type);
+ auto* param = b.FunctionParam(var->Result()->Type()->UnwrapPtr());
func->AppendParam(param);
if (auto name = ir.NameOf(var)) {
ir.SetName(param, name);
}
- // Add attributes to the parameter if this is an entry point function.
- if (entry_point) {
- AddEntryPointParameterAttributes(param, var->Attributes());
- }
-
- // Update the callsites of this function.
+ // Update the call sites of this function.
func->ForEachUseUnsorted([&](core::ir::Usage use) {
if (auto* call = use.instruction->As<core::ir::UserCall>()) {
// Recurse into the calling function.
@@ -474,93 +562,9 @@
}
});
- core::ir::Value* result = param;
- if (entry_point && var->Attributes().builtin.has_value()) {
- switch (var->Attributes().builtin.value()) {
- case core::BuiltinValue::kSampleMask: {
- // Construct an array from the scalar sample_mask builtin value for entry
- // points.
-
- auto* mask_ty = var->Result()->Type()->UnwrapPtr()->As<core::type::Array>();
- TINT_ASSERT(mask_ty);
-
- // If the SPIR-V mask was an i32, need to convert from the u32 provided by
- // WGSL.
- if (mask_ty->ElemType()->IsSignedIntegerScalar()) {
- auto* conv = b.Convert(ty.i32(), result);
- func->Block()->Prepend(conv);
-
- auto* construct = b.Construct(mask_ty, conv);
- construct->InsertAfter(conv);
- result = construct->Result();
- } else {
- auto* construct = b.Construct(mask_ty, result);
- func->Block()->Prepend(construct);
- result = construct->Result();
- }
- break;
- }
- case core::BuiltinValue::kInstanceIndex:
- case core::BuiltinValue::kVertexIndex:
- case core::BuiltinValue::kLocalInvocationIndex:
- case core::BuiltinValue::kSubgroupInvocationId:
- case core::BuiltinValue::kSubgroupSize:
- case core::BuiltinValue::kSampleIndex: {
- auto* idx_ty = var->Result()->Type()->UnwrapPtr();
- if (idx_ty->IsSignedIntegerScalar()) {
- auto* conv = b.Convert(ty.i32(), result);
- func->Block()->Prepend(conv);
- result = conv->Result();
- }
- break;
- }
- case core::BuiltinValue::kLocalInvocationId:
- case core::BuiltinValue::kGlobalInvocationId:
- case core::BuiltinValue::kWorkgroupId:
- case core::BuiltinValue::kNumWorkgroups: {
- auto* idx_ty = var->Result()->Type()->UnwrapPtr();
- auto* elem_ty = idx_ty->DeepestElement();
- if (elem_ty->IsSignedIntegerScalar()) {
- auto* conv = b.Convert(ty.MatchWidth(ty.i32(), idx_ty), result);
- func->Block()->Prepend(conv);
- result = conv->Result();
- }
- break;
- }
- default: {
- break;
- }
- }
- }
- return result;
+ return param;
});
}
-
- /// Add attributes to an entry point function parameter.
- /// @param param the parameter
- /// @param attributes the attributes
- void AddEntryPointParameterAttributes(core::ir::FunctionParam* param,
- const core::IOAttributes& attributes) {
- if (auto* str = param->Type()->UnwrapPtr()->As<core::type::Struct>()) {
- for (const auto* const_member : str->Members()) {
- // TODO(crbug.com/tint/745): Remove the const_cast.
- auto* member = const_cast<core::type::StructMember*>(const_member);
-
- // Use the base variable attributes if not specified directly on the member.
- auto member_attributes = member->Attributes();
- if (auto base_loc = attributes.location) {
- // Location values increment from the base location value on the variable.
- member->SetLocation(base_loc.value() + member->Index());
- }
- if (!member_attributes.interpolation) {
- member->SetInterpolation(attributes.interpolation);
- }
- }
- } else {
- // Set attributes directly on the function parameter.
- param->SetAttributes(attributes);
- }
- }
};
} // namespace