Dawn&Tint: Implement F16 pipeline IO
This CL implements f16 for pipeline IO, i.e. vertex shader input,
interstage variables between vertex and fragment shader, and fragment
shader output (render target). Unit tests and E2E tests for Tint and
Dawn are also implemented.
Bugs: tint:1473, tint:1502
Change-Id: If0d6b2b3171ec8b7e4efc0efd58cc803c6a3d3a8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111160
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index d473d89..b05f1b8 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -157,11 +157,12 @@
ResultOrError<wgpu::TextureComponentType> TintComponentTypeToTextureComponentType(
tint::inspector::ComponentType type) {
switch (type) {
- case tint::inspector::ComponentType::kFloat:
+ case tint::inspector::ComponentType::kF32:
+ case tint::inspector::ComponentType::kF16:
return wgpu::TextureComponentType::Float;
- case tint::inspector::ComponentType::kSInt:
+ case tint::inspector::ComponentType::kI32:
return wgpu::TextureComponentType::Sint;
- case tint::inspector::ComponentType::kUInt:
+ case tint::inspector::ComponentType::kU32:
return wgpu::TextureComponentType::Uint;
case tint::inspector::ComponentType::kUnknown:
return DAWN_VALIDATION_ERROR("Attempted to convert 'Unknown' component type from Tint");
@@ -172,11 +173,12 @@
ResultOrError<VertexFormatBaseType> TintComponentTypeToVertexFormatBaseType(
tint::inspector::ComponentType type) {
switch (type) {
- case tint::inspector::ComponentType::kFloat:
+ case tint::inspector::ComponentType::kF32:
+ case tint::inspector::ComponentType::kF16:
return VertexFormatBaseType::Float;
- case tint::inspector::ComponentType::kSInt:
+ case tint::inspector::ComponentType::kI32:
return VertexFormatBaseType::Sint;
- case tint::inspector::ComponentType::kUInt:
+ case tint::inspector::ComponentType::kU32:
return VertexFormatBaseType::Uint;
case tint::inspector::ComponentType::kUnknown:
return DAWN_VALIDATION_ERROR("Attempted to convert 'Unknown' component type from Tint");
@@ -213,12 +215,14 @@
ResultOrError<InterStageComponentType> TintComponentTypeToInterStageComponentType(
tint::inspector::ComponentType type) {
switch (type) {
- case tint::inspector::ComponentType::kFloat:
- return InterStageComponentType::Float;
- case tint::inspector::ComponentType::kSInt:
- return InterStageComponentType::Sint;
- case tint::inspector::ComponentType::kUInt:
- return InterStageComponentType::Uint;
+ case tint::inspector::ComponentType::kF32:
+ return InterStageComponentType::F32;
+ case tint::inspector::ComponentType::kI32:
+ return InterStageComponentType::I32;
+ case tint::inspector::ComponentType::kU32:
+ return InterStageComponentType::U32;
+ case tint::inspector::ComponentType::kF16:
+ return InterStageComponentType::F16;
case tint::inspector::ComponentType::kUnknown:
return DAWN_VALIDATION_ERROR("Attempted to convert 'Unknown' component type from Tint");
}
@@ -1042,7 +1046,7 @@
continue;
}
- // Uint/sint can't be statically used with a sampler, so they any
+ // Uint/Sint can't be statically used with a sampler, so any
// texture bindings reflected must be float or depth textures. If
// the shader uses a float/depth texture but the bind group layout
// specifies a uint/sint texture binding,
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 0c829f5..a1f34ac 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -60,9 +60,10 @@
// Base component type of an inter-stage variable
enum class InterStageComponentType {
- Sint,
- Uint,
- Float,
+ I32,
+ U32,
+ F32,
+ F16,
};
enum class InterpolationType {
diff --git a/src/dawn/native/d3d12/AdapterD3D12.cpp b/src/dawn/native/d3d12/AdapterD3D12.cpp
index 0b9d284..b3b7bf4 100644
--- a/src/dawn/native/d3d12/AdapterD3D12.cpp
+++ b/src/dawn/native/d3d12/AdapterD3D12.cpp
@@ -381,6 +381,11 @@
// Even this means that no vertex buffer view has been set in D3D12 backend.
// https://crbug.com/dawn/1255
D3D12_MESSAGE_ID_COMMAND_LIST_DRAW_VERTEX_BUFFER_NOT_SET,
+
+ // When using f16 in vertex attributes the debug layer may report float16_t as type
+ // `unknown`, resulting in a CREATEINPUTLAYOUT_TYPE_MISMATCH warning.
+ // https://crbug.com/tint/1473
+ D3D12_MESSAGE_ID_CREATEINPUTLAYOUT_TYPE_MISMATCH,
};
// Create a retrieval filter with a deny list to suppress messages.
diff --git a/src/dawn/native/webgpu_absl_format.cpp b/src/dawn/native/webgpu_absl_format.cpp
index 7df36ed..f313c3b 100644
--- a/src/dawn/native/webgpu_absl_format.cpp
+++ b/src/dawn/native/webgpu_absl_format.cpp
@@ -409,14 +409,17 @@
const absl::FormatConversionSpec& spec,
absl::FormatSink* s) {
switch (value) {
- case InterStageComponentType::Float:
- s->Append("Float");
+ case InterStageComponentType::F32:
+ s->Append("f32");
break;
- case InterStageComponentType::Uint:
- s->Append("Uint");
+ case InterStageComponentType::F16:
+ s->Append("f16");
break;
- case InterStageComponentType::Sint:
- s->Append("Sint");
+ case InterStageComponentType::U32:
+ s->Append("u32");
+ break;
+ case InterStageComponentType::I32:
+ s->Append("i32");
break;
}
return {true};
diff --git a/src/dawn/tests/end2end/ShaderF16Tests.cpp b/src/dawn/tests/end2end/ShaderF16Tests.cpp
index 42c881b..6722df9 100644
--- a/src/dawn/tests/end2end/ShaderF16Tests.cpp
+++ b/src/dawn/tests/end2end/ShaderF16Tests.cpp
@@ -19,12 +19,30 @@
#include "dawn/utils/WGPUHelpers.h"
namespace {
+
+constexpr uint32_t kRTSize = 16;
+constexpr wgpu::TextureFormat kFormat = wgpu::TextureFormat::RGBA8Unorm;
+
using RequireShaderF16Feature = bool;
DAWN_TEST_PARAM_STRUCT(ShaderF16TestsParams, RequireShaderF16Feature);
} // anonymous namespace
class ShaderF16Tests : public DawnTestWithParams<ShaderF16TestsParams> {
+ public:
+ wgpu::Texture CreateDefault2DTexture() {
+ wgpu::TextureDescriptor descriptor;
+ descriptor.dimension = wgpu::TextureDimension::e2D;
+ descriptor.size.width = kRTSize;
+ descriptor.size.height = kRTSize;
+ descriptor.size.depthOrArrayLayers = 1;
+ descriptor.sampleCount = 1;
+ descriptor.format = kFormat;
+ descriptor.mipLevelCount = 1;
+ descriptor.usage = wgpu::TextureUsage::RenderAttachment | wgpu::TextureUsage::CopySrc;
+ return device.CreateTexture(&descriptor);
+ }
+
protected:
std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
mIsShaderF16SupportedOnAdapter = SupportsFeatures({wgpu::FeatureName::ShaderF16});
@@ -58,6 +76,8 @@
bool mUseDxcEnabledOrNonD3D12 = false;
};
+// Test simple f16 arithmetic in a shader with the enable directive. The result should be as
+// expected if the device enables f16; otherwise a shader creation error should be caught.
TEST_P(ShaderF16Tests, BasicShaderF16FeaturesTest) {
const char* computeShader = R"(
enable f16;
@@ -118,6 +138,308 @@
EXPECT_BUFFER_U32_RANGE_EQ(expected, bufferOut, 0, 1);
}
+// Test that a fragment shader can use an f16 vector type as render target output.
+TEST_P(ShaderF16Tests, RenderPipelineIOF16_RenderTarget) {
+    // Skip if the device doesn't support the f16 extension.
+ DAWN_TEST_UNSUPPORTED_IF(!device.HasFeature(wgpu::FeatureName::ShaderF16));
+
+ const char* shader = R"(
+enable f16;
+
+@vertex
+fn VSMain(@builtin(vertex_index) VertexIndex : u32) -> @builtin(position) vec4<f32> {
+ var pos = array<vec2<f32>, 3>(
+ vec2<f32>(-1.0, 1.0),
+ vec2<f32>( 1.0, -1.0),
+ vec2<f32>(-1.0, -1.0));
+
+ return vec4<f32>(pos[VertexIndex], 0.0, 1.0);
+}
+
+@fragment
+fn FSMain() -> @location(0) vec4<f16> {
+ // Paint it blue
+ return vec4<f16>(0.0, 0.0, 1.0, 1.0);
+})";
+
+ wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader);
+
+ // Create render pipeline.
+ wgpu::RenderPipeline pipeline;
+ {
+ utils::ComboRenderPipelineDescriptor descriptor;
+
+ descriptor.vertex.module = shaderModule;
+ descriptor.vertex.entryPoint = "VSMain";
+
+ descriptor.cFragment.module = shaderModule;
+ descriptor.cFragment.entryPoint = "FSMain";
+ descriptor.primitive.topology = wgpu::PrimitiveTopology::TriangleList;
+ descriptor.cTargets[0].format = kFormat;
+
+ pipeline = device.CreateRenderPipeline(&descriptor);
+ }
+
+ wgpu::Texture renderTarget = CreateDefault2DTexture();
+
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+
+ {
+        // In the render pass we clear renderTarget to red and draw a blue triangle in the
+        // bottom left of renderTarget.
+ utils::ComboRenderPassDescriptor renderPass({renderTarget.CreateView()});
+ renderPass.cColorAttachments[0].clearValue = {1.0f, 0.0f, 0.0f, 1.0f};
+
+ wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass);
+ pass.SetPipeline(pipeline);
+ pass.Draw(3);
+ pass.End();
+ }
+
+ wgpu::CommandBuffer commands = encoder.Finish();
+ queue.Submit(1, &commands);
+
+    // Validate that bottom left of render target is drawn blue while upper right is still red
+ EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8::kBlue, renderTarget, 1, kRTSize - 1);
+ EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8::kRed, renderTarget, kRTSize - 1, 1);
+}
+
+// Test using f16 types as vertex shader (user-defined) output and fragment shader
+// (user-defined) input.
+TEST_P(ShaderF16Tests, RenderPipelineIOF16_InterstageVariable) {
+    // Skip if the device doesn't support the f16 extension.
+ DAWN_TEST_UNSUPPORTED_IF(!device.HasFeature(wgpu::FeatureName::ShaderF16));
+
+ const char* shader = R"(
+enable f16;
+
+struct VSOutput{
+ @builtin(position)
+ pos: vec4<f32>,
+ @location(3)
+ color_vsout: vec4<f16>,
+}
+
+@vertex
+fn VSMain(@builtin(vertex_index) VertexIndex : u32) -> VSOutput {
+ var pos = array<vec2<f32>, 3>(
+ vec2<f32>(-1.0, 1.0),
+ vec2<f32>( 1.0, -1.0),
+ vec2<f32>(-1.0, -1.0));
+
+ // Blue
+ var color = vec4<f16>(0.0h, 0.0h, 1.0h, 1.0h);
+
+ var result: VSOutput;
+ result.pos = vec4<f32>(pos[VertexIndex], 0.0, 1.0);
+ result.color_vsout = color;
+
+ return result;
+}
+
+struct FSInput{
+ @location(3)
+ color_fsin: vec4<f16>,
+}
+
+@fragment
+fn FSMain(fsInput: FSInput) -> @location(0) vec4<f32> {
+ // Paint it with given color
+ return vec4<f32>(fsInput.color_fsin);
+})";
+
+ wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader);
+
+ // Create render pipeline.
+ wgpu::RenderPipeline pipeline;
+ {
+ utils::ComboRenderPipelineDescriptor descriptor;
+
+ descriptor.vertex.module = shaderModule;
+ descriptor.vertex.entryPoint = "VSMain";
+
+ descriptor.cFragment.module = shaderModule;
+ descriptor.cFragment.entryPoint = "FSMain";
+ descriptor.primitive.topology = wgpu::PrimitiveTopology::TriangleList;
+ descriptor.cTargets[0].format = kFormat;
+
+ pipeline = device.CreateRenderPipeline(&descriptor);
+ }
+
+ wgpu::Texture renderTarget = CreateDefault2DTexture();
+
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+
+ {
+        // In the render pass we clear renderTarget to red and draw a blue triangle in the
+        // bottom left of renderTarget.
+ utils::ComboRenderPassDescriptor renderPass({renderTarget.CreateView()});
+ renderPass.cColorAttachments[0].clearValue = {1.0f, 0.0f, 0.0f, 1.0f};
+
+ wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass);
+ pass.SetPipeline(pipeline);
+ pass.Draw(3);
+ pass.End();
+ }
+
+ wgpu::CommandBuffer commands = encoder.Finish();
+ queue.Submit(1, &commands);
+
+    // Validate that bottom left of render target is drawn blue while upper right is still red
+ EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8::kBlue, renderTarget, 1, kRTSize - 1);
+ EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8::kRed, renderTarget, kRTSize - 1, 1);
+}
+
+// Test using f16 types as vertex shader user-defined input (vertex attributes), draw points of
+// different color given as vertex attributes.
+TEST_P(ShaderF16Tests, RenderPipelineIOF16_VertexAttribute) {
+    // Skip if the device doesn't support the f16 extension.
+ DAWN_TEST_UNSUPPORTED_IF(!device.HasFeature(wgpu::FeatureName::ShaderF16));
+
+ const char* shader = R"(
+enable f16;
+
+struct VSInput {
+ // position / 2.0
+ @location(0) pos_half : vec2<f16>,
+ // color / 4.0
+ @location(1) color_quarter : vec4<f16>,
+}
+
+struct VSOutput {
+ @builtin(position) pos : vec4<f32>,
+ @location(0) color : vec4<f32>,
+}
+
+@vertex
+fn VSMain(in: VSInput) -> VSOutput {
+ return VSOutput(vec4<f32>(vec2<f32>(in.pos_half * 2.0h), 0.0, 1.0), vec4<f32>(in.color_quarter * 4.0h));
+}
+
+@fragment
+fn FSMain(@location(0) color : vec4<f32>) -> @location(0) vec4<f32> {
+ return color;
+})";
+
+ wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader);
+
+ constexpr uint32_t kPointCount = 8;
+
+ // Position (divided by 2.0) for points on horizontal line
+ std::vector<float> positionData;
+ constexpr float xStep = 2.0 / kPointCount;
+ constexpr float xBias = -1.0 + xStep / 2.0f;
+ for (uint32_t i = 0; i < kPointCount; i++) {
+ // X position, divided by 2.0
+ positionData.push_back((xBias + xStep * i) / 2.0f);
+ // Y position (0.0f) divided by 2.0
+ positionData.push_back(0.0f);
+ }
+
+ // Expected color for each point
+ using RGBA8 = utils::RGBA8;
+ std::vector<RGBA8> colors = {
+ RGBA8::kBlack,
+ RGBA8::kRed,
+ RGBA8::kGreen,
+ RGBA8::kBlue,
+ RGBA8::kYellow,
+ RGBA8::kWhite,
+ RGBA8(96, 192, 176, 255),
+ RGBA8(184, 108, 184, 255),
+ };
+
+ ASSERT(colors.size() == kPointCount);
+ // Color (divided by 4.0) for each point
+ std::vector<float> colorData;
+ for (RGBA8& color : colors) {
+ colorData.push_back(color.r / 255.0 / 4.0);
+ colorData.push_back(color.g / 255.0 / 4.0);
+ colorData.push_back(color.b / 255.0 / 4.0);
+ colorData.push_back(color.a / 255.0 / 4.0);
+ }
+
+    // Store the data as float32x2 and float32x4 in vertex buffer, which should be converted to
+    // the corresponding WGSL types vec2<f16> and vec4<f16> by the driver.
+ // Buffer for pos_half
+ wgpu::Buffer vertexBufferPos = utils::CreateBufferFromData(
+ device, positionData.data(), 2 * kPointCount * sizeof(float), wgpu::BufferUsage::Vertex);
+ // Buffer for color_quarter
+ wgpu::Buffer vertexBufferColor = utils::CreateBufferFromData(
+ device, colorData.data(), 4 * kPointCount * sizeof(float), wgpu::BufferUsage::Vertex);
+
+ // Create render pipeline.
+ wgpu::RenderPipeline pipeline;
+ {
+ utils::ComboRenderPipelineDescriptor descriptor;
+
+ descriptor.vertex.module = shaderModule;
+ descriptor.vertex.entryPoint = "VSMain";
+ descriptor.vertex.bufferCount = 2;
+        // Interpret the vertex buffer data as Float32x2 and Float32x4; the result should be
+        // converted to vec2<f16> and vec4<f16>.
+ descriptor.cAttributes[0].format = wgpu::VertexFormat::Float32x2;
+ descriptor.cAttributes[0].offset = 0;
+ descriptor.cAttributes[0].shaderLocation = 0;
+ descriptor.cBuffers[0].stepMode = wgpu::VertexStepMode::Vertex;
+ descriptor.cBuffers[0].arrayStride = 8;
+ descriptor.cBuffers[0].attributeCount = 1;
+ descriptor.cBuffers[0].attributes = &descriptor.cAttributes[0];
+ descriptor.cAttributes[1].format = wgpu::VertexFormat::Float32x4;
+ descriptor.cAttributes[1].offset = 0;
+ descriptor.cAttributes[1].shaderLocation = 1;
+ descriptor.cBuffers[1].stepMode = wgpu::VertexStepMode::Vertex;
+ descriptor.cBuffers[1].arrayStride = 16;
+ descriptor.cBuffers[1].attributeCount = 1;
+ descriptor.cBuffers[1].attributes = &descriptor.cAttributes[1];
+
+ descriptor.cFragment.module = shaderModule;
+ descriptor.cFragment.entryPoint = "FSMain";
+ descriptor.primitive.topology = wgpu::PrimitiveTopology::PointList;
+ descriptor.cTargets[0].format = kFormat;
+
+ pipeline = device.CreateRenderPipeline(&descriptor);
+ }
+
+ // Create a render target of horizontal line
+ wgpu::Texture renderTarget;
+ {
+ wgpu::TextureDescriptor descriptor;
+ descriptor.dimension = wgpu::TextureDimension::e2D;
+ descriptor.size.width = kPointCount;
+ descriptor.size.height = 1;
+ descriptor.size.depthOrArrayLayers = 1;
+ descriptor.sampleCount = 1;
+ descriptor.format = kFormat;
+ descriptor.mipLevelCount = 1;
+ descriptor.usage = wgpu::TextureUsage::RenderAttachment | wgpu::TextureUsage::CopySrc;
+ renderTarget = device.CreateTexture(&descriptor);
+ }
+
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+
+ {
+ // Clear renderTarget to zero and draw points.
+ utils::ComboRenderPassDescriptor renderPass({renderTarget.CreateView()});
+ renderPass.cColorAttachments[0].clearValue = {0.0f, 0.0f, 0.0f, 0.0f};
+
+ wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass);
+ pass.SetPipeline(pipeline);
+ pass.SetVertexBuffer(0, vertexBufferPos);
+ pass.SetVertexBuffer(1, vertexBufferColor);
+ pass.Draw(kPointCount);
+ pass.End();
+ }
+
+ wgpu::CommandBuffer commands = encoder.Finish();
+ queue.Submit(1, &commands);
+
+ // Validate the color of each point
+ for (uint32_t i = 0; i < kPointCount; i++) {
+ EXPECT_PIXEL_RGBA8_EQ(colors[i], renderTarget, i, 0);
+ }
+}
+
// DawnTestBase::CreateDeviceImpl always disable disallow_unsafe_apis toggle.
DAWN_INSTANTIATE_TEST_P(ShaderF16Tests,
{
diff --git a/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp b/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp
index 927f05b..7e8a19f 100644
--- a/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp
@@ -24,6 +24,26 @@
class RenderPipelineValidationTest : public ValidationTest {
protected:
+ WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override {
+ // Disable the disallow_unsafe_apis toggle so we can test the ShaderF16 feature.
+ const char* forceDisabledToggle[] = {"disallow_unsafe_apis"};
+
+ wgpu::DeviceDescriptor descriptor;
+ wgpu::FeatureName requiredFeatures[1] = {wgpu::FeatureName::ShaderF16};
+ descriptor.requiredFeatures = requiredFeatures;
+ descriptor.requiredFeaturesCount = 1;
+
+ wgpu::DawnTogglesDeviceDescriptor togglesDesc;
+ descriptor.nextInChain = &togglesDesc;
+
+ togglesDesc.forceEnabledToggles = nullptr;
+ togglesDesc.forceEnabledTogglesCount = 0;
+ togglesDesc.forceDisabledToggles = forceDisabledToggle;
+ togglesDesc.forceDisabledTogglesCount = 1;
+
+ return dawnAdapter.CreateDevice(&descriptor);
+ }
+
void SetUp() override {
ValidationTest::SetUp();
@@ -326,30 +346,53 @@
// Tests that the format of the color state descriptor must match the output of the fragment shader.
TEST_F(RenderPipelineValidationTest, FragmentOutputFormatCompatibility) {
- std::array<const char*, 3> kScalarTypes = {{"f32", "i32", "u32"}};
- std::array<wgpu::TextureFormat, 3> kColorFormats = {{wgpu::TextureFormat::RGBA8Unorm,
- wgpu::TextureFormat::RGBA8Sint,
- wgpu::TextureFormat::RGBA8Uint}};
+ std::vector<std::vector<std::string>> kScalarTypeLists = {// Float scalar types
+ {"f32", "f16"},
+ // Sint scalar type
+ {"i32"},
+ // Uint scalar type
+ {"u32"}};
- for (size_t i = 0; i < kScalarTypes.size(); ++i) {
- utils::ComboRenderPipelineDescriptor descriptor;
- descriptor.vertex.module = vsModule;
- std::ostringstream stream;
- stream << R"(
+ std::vector<std::vector<wgpu::TextureFormat>> kColorFormatLists = {
+ // Float color formats
+ {wgpu::TextureFormat::RGBA8Unorm, wgpu::TextureFormat::RGBA16Float,
+ wgpu::TextureFormat::RGBA32Float},
+ // Sint color formats
+ {wgpu::TextureFormat::RGBA8Sint, wgpu::TextureFormat::RGBA16Sint,
+ wgpu::TextureFormat::RGBA32Sint},
+ // Uint color formats
+ {wgpu::TextureFormat::RGBA8Uint, wgpu::TextureFormat::RGBA16Uint,
+ wgpu::TextureFormat::RGBA32Uint}};
+
+ for (size_t i = 0; i < kScalarTypeLists.size(); ++i) {
+ for (const std::string& scalarType : kScalarTypeLists[i]) {
+ utils::ComboRenderPipelineDescriptor descriptor;
+ descriptor.vertex.module = vsModule;
+ std::ostringstream stream;
+
+ // Enable f16 extension if needed.
+ if (scalarType == "f16") {
+ stream << "enable f16;\n\n";
+ }
+ stream << R"(
@fragment fn main() -> @location(0) vec4<)"
- << kScalarTypes[i] << R"(> {
+ << scalarType << R"(> {
var result : vec4<)"
- << kScalarTypes[i] << R"(>;
+ << scalarType << R"(>;
return result;
})";
- descriptor.cFragment.module = utils::CreateShaderModule(device, stream.str().c_str());
- for (size_t j = 0; j < kColorFormats.size(); ++j) {
- descriptor.cTargets[0].format = kColorFormats[j];
- if (i == j) {
- device.CreateRenderPipeline(&descriptor);
- } else {
- ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+ descriptor.cFragment.module = utils::CreateShaderModule(device, stream.str().c_str());
+
+ for (size_t j = 0; j < kColorFormatLists.size(); ++j) {
+ for (wgpu::TextureFormat textureFormat : kColorFormatLists[j]) {
+ descriptor.cTargets[0].format = textureFormat;
+ if (i == j) {
+ device.CreateRenderPipeline(&descriptor);
+ } else {
+ ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+ }
+ }
}
}
}
@@ -1446,12 +1489,13 @@
// Tests that creating render pipeline should fail when the type of a vertex stage output variable
// doesn't match the type of the fragment stage input variable at the same location.
TEST_F(InterStageVariableMatchingValidationTest, DifferentTypeAtSameLocation) {
- constexpr std::array<const char*, 12> kTypes = {{"f32", "vec2<f32>", "vec3<f32>", "vec4<f32>",
+ constexpr std::array<const char*, 16> kTypes = {{"f32", "vec2<f32>", "vec3<f32>", "vec4<f32>",
+ "f16", "vec2<f16>", "vec3<f16>", "vec4<f16>",
"i32", "vec2<i32>", "vec3<i32>", "vec4<i32>",
"u32", "vec2<u32>", "vec3<u32>", "vec4<u32>"}};
- std::array<wgpu::ShaderModule, 12> vertexModules;
- std::array<wgpu::ShaderModule, 12> fragmentModules;
+ std::array<wgpu::ShaderModule, 16> vertexModules;
+ std::array<wgpu::ShaderModule, 16> fragmentModules;
for (uint32_t i = 0; i < kTypes.size(); ++i) {
std::string interfaceDeclaration;
{
@@ -1460,9 +1504,12 @@
<< std::endl;
interfaceDeclaration = sstream.str();
}
+
+ std::string extensionDeclaration = "enable f16;\n\n";
+
{
std::ostringstream vertexStream;
- vertexStream << interfaceDeclaration << R"(
+ vertexStream << extensionDeclaration << interfaceDeclaration << R"(
@builtin(position) pos: vec4<f32>,
}
@vertex fn main() -> A {
@@ -1474,7 +1521,7 @@
}
{
std::ostringstream fragmentStream;
- fragmentStream << interfaceDeclaration << R"(
+ fragmentStream << extensionDeclaration << interfaceDeclaration << R"(
}
@fragment fn main(fragmentIn: A) -> @location(0) vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 1.0);
diff --git a/src/tint/inspector/entry_point.h b/src/tint/inspector/entry_point.h
index 4a4706b..eabe601 100644
--- a/src/tint/inspector/entry_point.h
+++ b/src/tint/inspector/entry_point.h
@@ -30,9 +30,10 @@
/// Base component type of a stage variable.
enum class ComponentType {
kUnknown = -1,
- kFloat,
- kUInt,
- kSInt,
+ kF32,
+ kU32,
+ kI32,
+ kF16,
};
/// Composition of components of a stage variable.
diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc
index ea5b3ed..b893ed1 100644
--- a/src/tint/inspector/inspector.cc
+++ b/src/tint/inspector/inspector.cc
@@ -69,41 +69,48 @@
}
std::tuple<ComponentType, CompositionType> CalculateComponentAndComposition(const sem::Type* type) {
- if (type->is_float_scalar()) {
- return {ComponentType::kFloat, CompositionType::kScalar};
- } else if (type->is_float_vector()) {
- auto* vec = type->As<sem::Vector>();
- if (vec->Width() == 2) {
- return {ComponentType::kFloat, CompositionType::kVec2};
- } else if (vec->Width() == 3) {
- return {ComponentType::kFloat, CompositionType::kVec3};
- } else if (vec->Width() == 4) {
- return {ComponentType::kFloat, CompositionType::kVec4};
+ // Entry point in/out variables must be of numeric scalar or vector types.
+ TINT_ASSERT(Inspector, type->is_numeric_scalar_or_vector());
+
+ ComponentType componentType = Switch(
+ sem::Type::DeepestElementOf(type), //
+ [&](const sem::F32*) { return ComponentType::kF32; },
+ [&](const sem::F16*) { return ComponentType::kF16; },
+ [&](const sem::I32*) { return ComponentType::kI32; },
+ [&](const sem::U32*) { return ComponentType::kU32; },
+ [&](Default) {
+ tint::diag::List diagnostics;
+ TINT_UNREACHABLE(Inspector, diagnostics) << "unhandled component type";
+ return ComponentType::kUnknown;
+ });
+
+ CompositionType compositionType;
+ if (auto* vec = type->As<sem::Vector>()) {
+ switch (vec->Width()) {
+ case 2: {
+ compositionType = CompositionType::kVec2;
+ break;
+ }
+ case 3: {
+ compositionType = CompositionType::kVec3;
+ break;
+ }
+ case 4: {
+ compositionType = CompositionType::kVec4;
+ break;
+ }
+ default: {
+ tint::diag::List diagnostics;
+ TINT_UNREACHABLE(Inspector, diagnostics) << "unhandled composition type";
+ compositionType = CompositionType::kUnknown;
+ break;
+ }
}
- } else if (type->is_unsigned_integer_scalar()) {
- return {ComponentType::kUInt, CompositionType::kScalar};
- } else if (type->is_unsigned_integer_vector()) {
- auto* vec = type->As<sem::Vector>();
- if (vec->Width() == 2) {
- return {ComponentType::kUInt, CompositionType::kVec2};
- } else if (vec->Width() == 3) {
- return {ComponentType::kUInt, CompositionType::kVec3};
- } else if (vec->Width() == 4) {
- return {ComponentType::kUInt, CompositionType::kVec4};
- }
- } else if (type->is_signed_integer_scalar()) {
- return {ComponentType::kSInt, CompositionType::kScalar};
- } else if (type->is_signed_integer_vector()) {
- auto* vec = type->As<sem::Vector>();
- if (vec->Width() == 2) {
- return {ComponentType::kSInt, CompositionType::kVec2};
- } else if (vec->Width() == 3) {
- return {ComponentType::kSInt, CompositionType::kVec3};
- } else if (vec->Width() == 4) {
- return {ComponentType::kSInt, CompositionType::kVec4};
- }
+ } else {
+ compositionType = CompositionType::kScalar;
}
- return {ComponentType::kUnknown, CompositionType::kUnknown};
+
+ return {componentType, compositionType};
}
std::tuple<InterpolationType, InterpolationSampling> CalculateInterpolationData(
diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc
index 83302e2..78cfb5a 100644
--- a/src/tint/inspector/inspector_test.cc
+++ b/src/tint/inspector/inspector_test.cc
@@ -287,6 +287,10 @@
std::tie(component, composition) = GetParam();
std::function<const ast::Type*()> tint_type = GetTypeFunction(component, composition);
+ if (component == ComponentType::kF16) {
+ Enable(ast::Extension::kF16);
+ }
+
auto* in_var = Param("in_var", tint_type(),
utils::Vector{
Location(0_u),
@@ -323,9 +327,10 @@
}
INSTANTIATE_TEST_SUITE_P(InspectorGetEntryPointTest,
InspectorGetEntryPointComponentAndCompositionTest,
- testing::Combine(testing::Values(ComponentType::kFloat,
- ComponentType::kSInt,
- ComponentType::kUInt),
+ testing::Combine(testing::Values(ComponentType::kF32,
+ ComponentType::kI32,
+ ComponentType::kU32,
+ ComponentType::kF16),
testing::Values(CompositionType::kScalar,
CompositionType::kVec2,
CompositionType::kVec3,
@@ -369,23 +374,23 @@
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[0].input_variables[0].interpolation_type);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].input_variables[0].component_type);
EXPECT_EQ("in_var1", result[0].input_variables[1].name);
EXPECT_TRUE(result[0].input_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].input_variables[1].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[0].input_variables[1].interpolation_type);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[1].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].input_variables[1].component_type);
EXPECT_EQ("in_var4", result[0].input_variables[2].name);
EXPECT_TRUE(result[0].input_variables[2].has_location_attribute);
EXPECT_EQ(4u, result[0].input_variables[2].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[0].input_variables[2].interpolation_type);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[2].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].input_variables[2].component_type);
ASSERT_EQ(1u, result[0].output_variables.size());
EXPECT_EQ("<retval>", result[0].output_variables[0].name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
}
TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) {
@@ -433,26 +438,26 @@
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[0].input_variables[0].interpolation_type);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].input_variables[0].component_type);
ASSERT_EQ(1u, result[0].output_variables.size());
EXPECT_EQ("<retval>", result[0].output_variables[0].name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
ASSERT_EQ(1u, result[1].input_variables.size());
EXPECT_EQ("in_var_bar", result[1].input_variables[0].name);
EXPECT_TRUE(result[1].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[1].input_variables[0].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[1].input_variables[0].interpolation_type);
- EXPECT_EQ(ComponentType::kUInt, result[1].input_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[1].input_variables[0].component_type);
ASSERT_EQ(1u, result[1].output_variables.size());
EXPECT_EQ("<retval>", result[1].output_variables[0].name);
EXPECT_TRUE(result[1].output_variables[0].has_location_attribute);
EXPECT_EQ(1u, result[1].output_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[1].output_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[1].output_variables[0].component_type);
}
TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) {
@@ -485,7 +490,7 @@
EXPECT_EQ("in_var1", result[0].input_variables[0].name);
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kFloat, result[0].input_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kF32, result[0].input_variables[0].component_type);
ASSERT_EQ(0u, result[0].output_variables.size());
}
@@ -517,21 +522,21 @@
EXPECT_EQ("param.a", result[0].input_variables[0].name);
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].input_variables[0].component_type);
EXPECT_EQ("param.b", result[0].input_variables[1].name);
EXPECT_TRUE(result[0].input_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].input_variables[1].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[1].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].input_variables[1].component_type);
ASSERT_EQ(2u, result[0].output_variables.size());
EXPECT_EQ("<retval>.a", result[0].output_variables[0].name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
EXPECT_EQ("<retval>.b", result[0].output_variables[1].name);
EXPECT_TRUE(result[0].output_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].output_variables[1].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[1].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].output_variables[1].component_type);
}
TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutSharedStruct) {
@@ -563,21 +568,21 @@
EXPECT_EQ("<retval>.a", result[0].output_variables[0].name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
EXPECT_EQ("<retval>.b", result[0].output_variables[1].name);
EXPECT_TRUE(result[0].output_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].output_variables[1].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[1].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].output_variables[1].component_type);
ASSERT_EQ(2u, result[1].input_variables.size());
EXPECT_EQ("param.a", result[1].input_variables[0].name);
EXPECT_TRUE(result[1].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[1].input_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[1].input_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[1].input_variables[0].component_type);
EXPECT_EQ("param.b", result[1].input_variables[1].name);
EXPECT_TRUE(result[1].input_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[1].input_variables[1].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[1].input_variables[1].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[1].input_variables[1].component_type);
ASSERT_EQ(0u, result[1].output_variables.size());
}
@@ -615,33 +620,33 @@
EXPECT_EQ("param_a.a", result[0].input_variables[0].name);
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].input_variables[0].component_type);
EXPECT_EQ("param_a.b", result[0].input_variables[1].name);
EXPECT_TRUE(result[0].input_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].input_variables[1].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[1].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].input_variables[1].component_type);
EXPECT_EQ("param_b.a", result[0].input_variables[2].name);
EXPECT_TRUE(result[0].input_variables[2].has_location_attribute);
EXPECT_EQ(2u, result[0].input_variables[2].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[2].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].input_variables[2].component_type);
EXPECT_EQ("param_c", result[0].input_variables[3].name);
EXPECT_TRUE(result[0].input_variables[3].has_location_attribute);
EXPECT_EQ(3u, result[0].input_variables[3].location_attribute);
- EXPECT_EQ(ComponentType::kFloat, result[0].input_variables[3].component_type);
+ EXPECT_EQ(ComponentType::kF32, result[0].input_variables[3].component_type);
EXPECT_EQ("param_d", result[0].input_variables[4].name);
EXPECT_TRUE(result[0].input_variables[4].has_location_attribute);
EXPECT_EQ(4u, result[0].input_variables[4].location_attribute);
- EXPECT_EQ(ComponentType::kFloat, result[0].input_variables[4].component_type);
+ EXPECT_EQ(ComponentType::kF32, result[0].input_variables[4].component_type);
ASSERT_EQ(2u, result[0].output_variables.size());
EXPECT_EQ("<retval>.a", result[0].output_variables[0].name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
EXPECT_EQ("<retval>.b", result[0].output_variables[1].name);
EXPECT_TRUE(result[0].output_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].output_variables[1].location_attribute);
- EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[1].component_type);
+ EXPECT_EQ(ComponentType::kU32, result[0].output_variables[1].component_type);
}
TEST_F(InspectorGetEntryPointTest, OverrideUnreferenced) {
diff --git a/src/tint/inspector/test_inspector_builder.cc b/src/tint/inspector/test_inspector_builder.cc
index 79122dc..97ae097 100644
--- a/src/tint/inspector/test_inspector_builder.cc
+++ b/src/tint/inspector/test_inspector_builder.cc
@@ -307,15 +307,18 @@
CompositionType composition) {
std::function<const ast::Type*()> func;
switch (component) {
- case ComponentType::kFloat:
+ case ComponentType::kF32:
func = [this]() -> const ast::Type* { return ty.f32(); };
break;
- case ComponentType::kSInt:
+ case ComponentType::kI32:
func = [this]() -> const ast::Type* { return ty.i32(); };
break;
- case ComponentType::kUInt:
+ case ComponentType::kU32:
func = [this]() -> const ast::Type* { return ty.u32(); };
break;
+ case ComponentType::kF16:
+ func = [this]() -> const ast::Type* { return ty.f16(); };
+ break;
case ComponentType::kUnknown:
return []() -> const ast::Type* { return nullptr; };
}
diff --git a/src/tint/resolver/entry_point_validation_test.cc b/src/tint/resolver/entry_point_validation_test.cc
index 79b41d7..cd49ded 100644
--- a/src/tint/resolver/entry_point_validation_test.cc
+++ b/src/tint/resolver/entry_point_validation_test.cc
@@ -606,17 +606,14 @@
ParamsFor<alias<i32>>(true), //
ParamsFor<alias<u32>>(true), //
ParamsFor<alias<bool>>(false), //
- // Currently entry point IO of f16 types are not implemented yet.
- // TODO(tint:1473, tint:1502): Change f16 and vecN<f16> cases to valid after f16 is supported in
- // entry point IO.
- ParamsFor<f16>(false), //
- ParamsFor<vec2<f16>>(false), //
- ParamsFor<vec3<f16>>(false), //
- ParamsFor<vec4<f16>>(false), //
+ ParamsFor<f16>(true), //
+ ParamsFor<vec2<f16>>(true), //
+ ParamsFor<vec3<f16>>(true), //
+ ParamsFor<vec4<f16>>(true), //
ParamsFor<mat2x2<f16>>(false), //
ParamsFor<mat3x3<f16>>(false), //
ParamsFor<mat4x4<f16>>(false), //
- ParamsFor<alias<f16>>(false), //
+ ParamsFor<alias<f16>>(true), //
};
TEST_P(TypeValidationTest, BareInputs) {
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 8da1726..558217d 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1068,13 +1068,6 @@
ParamOrRetType param_or_ret,
bool is_struct_member,
std::optional<uint32_t> location) {
- // Temporally forbid using f16 types in entry point IO.
- // TODO(tint:1473, tint:1502): Remove this error after f16 is supported in entry point IO.
- if (Is<sem::F16>(sem::Type::DeepestElementOf(ty))) {
- AddError("entry point IO of f16 types is not implemented yet", source);
- return false;
- }
-
// Scan attributes for pipeline IO attributes.
// Check for overlap with attributes that have been seen previously.
const ast::Attribute* pipeline_io_attribute = nullptr;
diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc
index 0196002..e213ac6 100644
--- a/src/tint/transform/vertex_pulling.cc
+++ b/src/tint/transform/vertex_pulling.cc
@@ -41,6 +41,7 @@
kU32,
kI32,
kF32,
+ kF16,
};
/// The data type of a vertex format.
@@ -138,6 +139,7 @@
bool IsTypeCompatible(AttributeWGSLType wgslType, VertexFormatType vertexFormatType) {
switch (wgslType.base_type) {
case BaseWGSLType::kF32:
+ case BaseWGSLType::kF16:
return (vertexFormatType.base_type == VertexDataType::kFloat);
case BaseWGSLType::kU32:
return (vertexFormatType.base_type == VertexDataType::kUInt);
@@ -149,19 +151,26 @@
}
AttributeWGSLType WGSLTypeOf(const sem::Type* ty) {
- if (ty->Is<sem::I32>()) {
- return {BaseWGSLType::kI32, 1};
- }
- if (ty->Is<sem::U32>()) {
- return {BaseWGSLType::kU32, 1};
- }
- if (ty->Is<sem::F32>()) {
- return {BaseWGSLType::kF32, 1};
- }
- if (auto* vec = ty->As<sem::Vector>()) {
- return {WGSLTypeOf(vec->type()).base_type, vec->Width()};
- }
- return {BaseWGSLType::kInvalid, 0};
+ return Switch(
+ ty,
+ [](const sem::I32*) -> AttributeWGSLType {
+ return {BaseWGSLType::kI32, 1};
+ },
+ [](const sem::U32*) -> AttributeWGSLType {
+ return {BaseWGSLType::kU32, 1};
+ },
+ [](const sem::F32*) -> AttributeWGSLType {
+ return {BaseWGSLType::kF32, 1};
+ },
+ [](const sem::F16*) -> AttributeWGSLType {
+ return {BaseWGSLType::kF16, 1};
+ },
+ [](const sem::Vector* vec) -> AttributeWGSLType {
+ return {WGSLTypeOf(vec->type()).base_type, vec->Width()};
+ },
+ [](Default) -> AttributeWGSLType {
+ return {BaseWGSLType::kInvalid, 0};
+ });
}
VertexFormatType VertexFormatTypeOf(VertexFormat format) {
@@ -378,9 +387,22 @@
// Load the attribute value according to vertex format and convert the element type
// of result to match target WGSL variable. The result of `Fetch` should be of WGSL
- // types `f32`, `i32`, `u32`, and their vectors.
+ // types `f32`, `i32`, `u32`, and their vectors, while the WGSL variable may be of
+ // an `f16` type.
auto* fetch = Fetch(buffer_array_base, attribute_desc.offset, buffer_idx,
attribute_desc.format);
+ // Convert the fetched scalar/vector if the WGSL variable is of an `f16` type.
+ if (var_dt.base_type == BaseWGSLType::kF16) {
+ // The `f16` scalar/vector type with the same element count as the fetched data
+ const ast::Type* loaded_data_target_type;
+ if (fmt_dt.width == 1) {
+ loaded_data_target_type = b.ty.f16();
+ } else {
+ loaded_data_target_type = b.ty.vec(b.ty.f16(), fmt_dt.width);
+ }
+
+ fetch = b.Construct(loaded_data_target_type, fetch);
+ }
// The attribute value may not be of the desired vector width. If it is not, we'll
// need to either reduce the width with a swizzle, or append 0's and / or a 1.
diff --git a/src/tint/transform/vertex_pulling_test.cc b/src/tint/transform/vertex_pulling_test.cc
index 54c348e..a6dc2d2 100644
--- a/src/tint/transform/vertex_pulling_test.cc
+++ b/src/tint/transform/vertex_pulling_test.cc
@@ -736,6 +736,63 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(VertexPullingTest, FloatVectorAttributes_F16) {
+ auto* src = R"(
+enable f16;
+
+@vertex
+fn main(@location(0) var_a : vec2<f16>,
+ @location(1) var_b : vec3<f16>,
+ @location(2) var_c : vec4<f16>
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
+
+@binding(2) @group(4) var<storage, read> tint_pulling_vertex_buffer_2 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var var_a : vec2<f16>;
+ var var_b : vec3<f16>;
+ var var_c : vec4<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 2u);
+ var_a = vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 1u)])));
+ let buffer_array_base_1 = (tint_pulling_vertex_index * 3u);
+ var_b = vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[buffer_array_base_1]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(buffer_array_base_1 + 1u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(buffer_array_base_1 + 2u)])));
+ let buffer_array_base_2 = (tint_pulling_vertex_index * 4u);
+ var_c = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_2.tint_vertex_data[buffer_array_base_2]), unpack2x16float(tint_pulling_vertex_buffer_2.tint_vertex_data[(buffer_array_base_2 + 1u)])));
+ }
+ return vec4<f32>();
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {8, VertexStepMode::kVertex, {{VertexFormat::kFloat32x2, 0, 0}}},
+ {12, VertexStepMode::kVertex, {{VertexFormat::kFloat32x3, 0, 1}}},
+ {16, VertexStepMode::kVertex, {{VertexFormat::kFloat16x4, 0, 2}}},
+ }};
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(VertexPullingTest, AttemptSymbolCollision) {
auto* src = R"(
@vertex
@@ -1019,6 +1076,104 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(VertexPullingTest, FormatsAligned_Float_F16) {
+ auto* src = R"(
+enable f16;
+
+@vertex
+fn main(
+ @location(0) unorm8x2 : vec2<f16>,
+ @location(1) unorm8x4 : vec4<f16>,
+ @location(2) snorm8x2 : vec2<f16>,
+ @location(3) snorm8x4 : vec4<f16>,
+ @location(4) unorm16x2 : vec2<f16>,
+ @location(5) unorm16x4 : vec4<f16>,
+ @location(6) snorm16x2 : vec2<f16>,
+ @location(7) snorm16x4 : vec4<f16>,
+ @location(8) float16x2 : vec2<f16>,
+ @location(9) float16x4 : vec4<f16>,
+ @location(10) float32 : f16,
+ @location(11) float32x2 : vec2<f16>,
+ @location(12) float32x3 : vec3<f16>,
+ @location(13) float32x4 : vec4<f16>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var unorm8x2 : vec2<f16>;
+ var unorm8x4 : vec4<f16>;
+ var snorm8x2 : vec2<f16>;
+ var snorm8x4 : vec4<f16>;
+ var unorm16x2 : vec2<f16>;
+ var unorm16x4 : vec4<f16>;
+ var snorm16x2 : vec2<f16>;
+ var snorm16x4 : vec4<f16>;
+ var float16x2 : vec2<f16>;
+ var float16x4 : vec4<f16>;
+ var float32 : f16;
+ var float32x2 : vec2<f16>;
+ var float32x3 : vec3<f16>;
+ var float32x4 : vec4<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ unorm8x2 = vec2<f16>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy);
+ unorm8x4 = vec4<f16>(unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ snorm8x2 = vec2<f16>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy);
+ snorm8x4 = vec4<f16>(unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ unorm16x2 = vec2<f16>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])));
+ snorm16x2 = vec2<f16>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])));
+ float16x2 = vec2<f16>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])));
+ float32 = f16(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ float32x2 = vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])));
+ float32x3 = vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)])));
+ float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)])));
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 64, 0},
+ {VertexFormat::kUnorm8x4, 64, 1},
+ {VertexFormat::kSnorm8x2, 64, 2},
+ {VertexFormat::kSnorm8x4, 64, 3},
+ {VertexFormat::kUnorm16x2, 64, 4},
+ {VertexFormat::kUnorm16x4, 64, 5},
+ {VertexFormat::kSnorm16x2, 64, 6},
+ {VertexFormat::kSnorm16x4, 64, 7},
+ {VertexFormat::kFloat16x2, 64, 8},
+ {VertexFormat::kFloat16x4, 64, 9},
+ {VertexFormat::kFloat32, 64, 10},
+ {VertexFormat::kFloat32x2, 64, 11},
+ {VertexFormat::kFloat32x3, 64, 12},
+ {VertexFormat::kFloat32x4, 64, 13},
+ }}}};
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(VertexPullingTest, FormatsUnaligned_SInt) {
auto* src = R"(
@vertex
@@ -1253,6 +1408,104 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(VertexPullingTest, FormatsUnaligned_Float_F16) {
+ auto* src = R"(
+enable f16;
+
+@vertex
+fn main(
+ @location(0) unorm8x2 : vec2<f16>,
+ @location(1) unorm8x4 : vec4<f16>,
+ @location(2) snorm8x2 : vec2<f16>,
+ @location(3) snorm8x4 : vec4<f16>,
+ @location(4) unorm16x2 : vec2<f16>,
+ @location(5) unorm16x4 : vec4<f16>,
+ @location(6) snorm16x2 : vec2<f16>,
+ @location(7) snorm16x4 : vec4<f16>,
+ @location(8) float16x2 : vec2<f16>,
+ @location(9) float16x4 : vec4<f16>,
+ @location(10) float32 : f16,
+ @location(11) float32x2 : vec2<f16>,
+ @location(12) float32x3 : vec3<f16>,
+ @location(13) float32x4 : vec4<f16>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var unorm8x2 : vec2<f16>;
+ var unorm8x4 : vec4<f16>;
+ var snorm8x2 : vec2<f16>;
+ var snorm8x4 : vec4<f16>;
+ var unorm16x2 : vec2<f16>;
+ var unorm16x4 : vec4<f16>;
+ var snorm16x2 : vec2<f16>;
+ var snorm16x4 : vec4<f16>;
+ var float16x2 : vec2<f16>;
+ var float16x4 : vec4<f16>;
+ var float32 : f16;
+ var float32x2 : vec2<f16>;
+ var float32x3 : vec3<f16>;
+ var float32x4 : vec4<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ unorm8x2 = vec2<f16>(unpack4x8unorm((((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u)) & 65535u)).xy);
+ unorm8x4 = vec4<f16>(unpack4x8unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ snorm8x2 = vec2<f16>(unpack4x8snorm((((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u)) & 65535u)).xy);
+ snorm8x4 = vec4<f16>(unpack4x8snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ unorm16x2 = vec2<f16>(unpack2x16unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), unpack2x16unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)))));
+ snorm16x2 = vec2<f16>(unpack2x16snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), unpack2x16snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)))));
+ float16x2 = vec2<f16>(unpack2x16float(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), unpack2x16float(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)))));
+ float32 = f16(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ float32x2 = vec2<f16>(vec2<f32>(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)))));
+ float32x3 = vec3<f16>(vec3<f32>(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u)))));
+ float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)] << 8u)))));
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 63, 0},
+ {VertexFormat::kUnorm8x4, 63, 1},
+ {VertexFormat::kSnorm8x2, 63, 2},
+ {VertexFormat::kSnorm8x4, 63, 3},
+ {VertexFormat::kUnorm16x2, 63, 4},
+ {VertexFormat::kUnorm16x4, 63, 5},
+ {VertexFormat::kSnorm16x2, 63, 6},
+ {VertexFormat::kSnorm16x4, 63, 7},
+ {VertexFormat::kFloat16x2, 63, 8},
+ {VertexFormat::kFloat16x4, 63, 9},
+ {VertexFormat::kFloat32, 63, 10},
+ {VertexFormat::kFloat32x2, 63, 11},
+ {VertexFormat::kFloat32x3, 63, 12},
+ {VertexFormat::kFloat32x4, 63, 13},
+ }}}};
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(VertexPullingTest, FormatsWithVectorsResized_Padding_SInt) {
auto* src = R"(
@vertex
@@ -1511,6 +1764,112 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Padding_Float_F16) {
+ auto* src = R"(
+enable f16;
+
+@vertex
+fn main(
+ @location(0) vec3_unorm8x2 : vec3<f16>,
+ @location(1) vec4_unorm8x2 : vec4<f16>,
+ @location(2) vec3_snorm8x2 : vec3<f16>,
+ @location(3) vec4_snorm8x2 : vec4<f16>,
+ @location(4) vec3_unorm16x2 : vec3<f16>,
+ @location(5) vec4_unorm16x2 : vec4<f16>,
+ @location(6) vec3_snorm16x2 : vec3<f16>,
+ @location(7) vec4_snorm16x2 : vec4<f16>,
+ @location(8) vec3_float16x2 : vec3<f16>,
+ @location(9) vec4_float16x2 : vec4<f16>,
+ @location(10) vec2_float32 : vec2<f16>,
+ @location(11) vec3_float32 : vec3<f16>,
+ @location(12) vec4_float32 : vec4<f16>,
+ @location(13) vec3_float32x2 : vec3<f16>,
+ @location(14) vec4_float32x2 : vec4<f16>,
+ @location(15) vec4_float32x3 : vec4<f16>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var vec3_unorm8x2 : vec3<f16>;
+ var vec4_unorm8x2 : vec4<f16>;
+ var vec3_snorm8x2 : vec3<f16>;
+ var vec4_snorm8x2 : vec4<f16>;
+ var vec3_unorm16x2 : vec3<f16>;
+ var vec4_unorm16x2 : vec4<f16>;
+ var vec3_snorm16x2 : vec3<f16>;
+ var vec4_snorm16x2 : vec4<f16>;
+ var vec3_float16x2 : vec3<f16>;
+ var vec4_float16x2 : vec4<f16>;
+ var vec2_float32 : vec2<f16>;
+ var vec3_float32 : vec3<f16>;
+ var vec4_float32 : vec4<f16>;
+ var vec3_float32x2 : vec3<f16>;
+ var vec4_float32x2 : vec4<f16>;
+ var vec4_float32x3 : vec4<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ vec3_unorm8x2 = vec3<f16>(vec2<f16>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy), 0.0);
+ vec4_unorm8x2 = vec4<f16>(vec2<f16>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy), 0.0, 1.0);
+ vec3_snorm8x2 = vec3<f16>(vec2<f16>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy), 0.0);
+ vec4_snorm8x2 = vec4<f16>(vec2<f16>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy), 0.0, 1.0);
+ vec3_unorm16x2 = vec3<f16>(vec2<f16>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0);
+ vec4_unorm16x2 = vec4<f16>(vec2<f16>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 1.0);
+ vec3_snorm16x2 = vec3<f16>(vec2<f16>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0);
+ vec4_snorm16x2 = vec4<f16>(vec2<f16>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 1.0);
+ vec3_float16x2 = vec3<f16>(vec2<f16>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0);
+ vec4_float16x2 = vec4<f16>(vec2<f16>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 1.0);
+ vec2_float32 = vec2<f16>(f16(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0);
+ vec3_float32 = vec3<f16>(f16(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 0.0);
+ vec4_float32 = vec4<f16>(f16(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 0.0, 1.0);
+ vec3_float32x2 = vec3<f16>(vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))), 0.0);
+ vec4_float32x2 = vec4<f16>(vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))), 0.0, 1.0);
+ vec4_float32x3 = vec4<f16>(vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]))), 1.0);
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 64, 0},
+ {VertexFormat::kUnorm8x2, 64, 1},
+ {VertexFormat::kSnorm8x2, 64, 2},
+ {VertexFormat::kSnorm8x2, 64, 3},
+ {VertexFormat::kUnorm16x2, 64, 4},
+ {VertexFormat::kUnorm16x2, 64, 5},
+ {VertexFormat::kSnorm16x2, 64, 6},
+ {VertexFormat::kSnorm16x2, 64, 7},
+ {VertexFormat::kFloat16x2, 64, 8},
+ {VertexFormat::kFloat16x2, 64, 9},
+ {VertexFormat::kFloat32, 64, 10},
+ {VertexFormat::kFloat32, 64, 11},
+ {VertexFormat::kFloat32, 64, 12},
+ {VertexFormat::kFloat32x2, 64, 13},
+ {VertexFormat::kFloat32x2, 64, 14},
+ {VertexFormat::kFloat32x3, 64, 15},
+ }}}};
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(VertexPullingTest, FormatsWithVectorsResized_Shrinking_SInt) {
auto* src = R"(
@vertex
@@ -1829,5 +2188,139 @@
EXPECT_EQ(expect, str(got));
}
+// VertexPulling with f16 shader inputs that are narrower than the vertex
+// format supplying them. Each of the 26 inputs is expected to become a
+// function-scope `var` that is loaded from the pulling storage buffer,
+// converted to f16 and then swizzled down (.x / .xy / .xyz) to the declared
+// input width.
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Shrinking_Float_F16) {
+    // Vertex entry point taking scalar / vec2 / vec3 f16 inputs at locations 0-25.
+    const char* input_wgsl = R"(
+enable f16;
+
+@vertex
+fn main(
+ @location(0) sclr_unorm8x2 : f16 ,
+ @location(1) sclr_unorm8x4 : f16 ,
+ @location(2) vec2_unorm8x4 : vec2<f16>,
+ @location(3) vec3_unorm8x4 : vec3<f16>,
+ @location(4) sclr_snorm8x2 : f16 ,
+ @location(5) sclr_snorm8x4 : f16 ,
+ @location(6) vec2_snorm8x4 : vec2<f16>,
+ @location(7) vec3_snorm8x4 : vec3<f16>,
+ @location(8) sclr_unorm16x2 : f16 ,
+ @location(9) sclr_unorm16x4 : f16 ,
+ @location(10) vec2_unorm16x4 : vec2<f16>,
+ @location(11) vec3_unorm16x4 : vec3<f16>,
+ @location(12) sclr_snorm16x2 : f16 ,
+ @location(13) sclr_snorm16x4 : f16 ,
+ @location(14) vec2_snorm16x4 : vec2<f16>,
+ @location(15) vec3_snorm16x4 : vec3<f16>,
+ @location(16) sclr_float16x2 : f16 ,
+ @location(17) sclr_float16x4 : f16 ,
+ @location(18) vec2_float16x4 : vec2<f16>,
+ @location(19) vec3_float16x4 : vec3<f16>,
+ @location(20) sclr_float32x2 : f16 ,
+ @location(21) sclr_float32x3 : f16 ,
+ @location(22) vec2_float32x3 : vec2<f16>,
+ @location(23) sclr_float32x4 : f16 ,
+ @location(24) vec2_float32x4 : vec2<f16>,
+ @location(25) vec3_float32x4 : vec3<f16>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+    // Text the transform must produce: the inputs are replaced by a
+    // vertex_index builtin plus explicit loads from the u32 storage buffer.
+    const char* expected_wgsl = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var sclr_unorm8x2 : f16;
+ var sclr_unorm8x4 : f16;
+ var vec2_unorm8x4 : vec2<f16>;
+ var vec3_unorm8x4 : vec3<f16>;
+ var sclr_snorm8x2 : f16;
+ var sclr_snorm8x4 : f16;
+ var vec2_snorm8x4 : vec2<f16>;
+ var vec3_snorm8x4 : vec3<f16>;
+ var sclr_unorm16x2 : f16;
+ var sclr_unorm16x4 : f16;
+ var vec2_unorm16x4 : vec2<f16>;
+ var vec3_unorm16x4 : vec3<f16>;
+ var sclr_snorm16x2 : f16;
+ var sclr_snorm16x4 : f16;
+ var vec2_snorm16x4 : vec2<f16>;
+ var vec3_snorm16x4 : vec3<f16>;
+ var sclr_float16x2 : f16;
+ var sclr_float16x4 : f16;
+ var vec2_float16x4 : vec2<f16>;
+ var vec3_float16x4 : vec3<f16>;
+ var sclr_float32x2 : f16;
+ var sclr_float32x3 : f16;
+ var vec2_float32x3 : vec2<f16>;
+ var sclr_float32x4 : f16;
+ var vec2_float32x4 : vec2<f16>;
+ var vec3_float32x4 : vec3<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ sclr_unorm8x2 = vec2<f16>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy).x;
+ sclr_unorm8x4 = vec4<f16>(unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ vec2_unorm8x4 = vec4<f16>(unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).xy;
+ vec3_unorm8x4 = vec4<f16>(unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).xyz;
+ sclr_snorm8x2 = vec2<f16>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy).x;
+ sclr_snorm8x4 = vec4<f16>(unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ vec2_snorm8x4 = vec4<f16>(unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).xy;
+ vec3_snorm8x4 = vec4<f16>(unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).xyz;
+ sclr_unorm16x2 = vec2<f16>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ sclr_unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).x;
+ vec2_unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xy;
+ vec3_unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xyz;
+ sclr_snorm16x2 = vec2<f16>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ sclr_snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).x;
+ vec2_snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xy;
+ vec3_snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xyz;
+ sclr_float16x2 = vec2<f16>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ sclr_float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).x;
+ vec2_float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xy;
+ vec3_float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xyz;
+ sclr_float32x2 = vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).x;
+ sclr_float32x3 = vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]))).x;
+ vec2_float32x3 = vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]))).xy;
+ sclr_float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]))).x;
+ vec2_float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]))).xy;
+ vec3_float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]))).xyz;
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+    // A single per-vertex buffer with array stride 256 (the `* 64u` in the
+    // expected text, which indexes u32 words). All 26 attributes sit at
+    // offset 64 (the `+ 16u` word index), one per shader location, each with
+    // more components than the shader input it feeds.
+    VertexPulling::Config pulling_cfg;
+    pulling_cfg.vertex_state = {{{256,
+                                  VertexStepMode::kVertex,
+                                  {
+                                      {VertexFormat::kUnorm8x2, 64, 0},
+                                      {VertexFormat::kUnorm8x4, 64, 1},
+                                      {VertexFormat::kUnorm8x4, 64, 2},
+                                      {VertexFormat::kUnorm8x4, 64, 3},
+                                      {VertexFormat::kSnorm8x2, 64, 4},
+                                      {VertexFormat::kSnorm8x4, 64, 5},
+                                      {VertexFormat::kSnorm8x4, 64, 6},
+                                      {VertexFormat::kSnorm8x4, 64, 7},
+                                      {VertexFormat::kUnorm16x2, 64, 8},
+                                      {VertexFormat::kUnorm16x4, 64, 9},
+                                      {VertexFormat::kUnorm16x4, 64, 10},
+                                      {VertexFormat::kUnorm16x4, 64, 11},
+                                      {VertexFormat::kSnorm16x2, 64, 12},
+                                      {VertexFormat::kSnorm16x4, 64, 13},
+                                      {VertexFormat::kSnorm16x4, 64, 14},
+                                      {VertexFormat::kSnorm16x4, 64, 15},
+                                      {VertexFormat::kFloat16x2, 64, 16},
+                                      {VertexFormat::kFloat16x4, 64, 17},
+                                      {VertexFormat::kFloat16x4, 64, 18},
+                                      {VertexFormat::kFloat16x4, 64, 19},
+                                      {VertexFormat::kFloat32x2, 64, 20},
+                                      {VertexFormat::kFloat32x3, 64, 21},
+                                      {VertexFormat::kFloat32x3, 64, 22},
+                                      {VertexFormat::kFloat32x4, 64, 23},
+                                      {VertexFormat::kFloat32x4, 64, 24},
+                                      {VertexFormat::kFloat32x4, 64, 25},
+                                  }}}};
+
+    // Hand the config to the transform and compare its output text verbatim.
+    DataMap transform_inputs;
+    transform_inputs.Add<VertexPulling::Config>(pulling_cfg);
+    auto result = Run<VertexPulling>(input_wgsl, transform_inputs);
+
+    EXPECT_EQ(expected_wgsl, str(result));
+}
+
} // namespace
} // namespace tint::transform