perf_tests: add f16 variant of matmul

Change-Id: I41aa1a34521bd13a53f9dda16704b28d19f445b5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/178263
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/dawn/tests/perf_tests/ShaderRobustnessPerf.cpp b/src/dawn/tests/perf_tests/ShaderRobustnessPerf.cpp
index 8057e4f..48f2177 100644
--- a/src/dawn/tests/perf_tests/ShaderRobustnessPerf.cpp
+++ b/src/dawn/tests/perf_tests/ShaderRobustnessPerf.cpp
@@ -49,7 +49,6 @@
     ss << "const kWorkgroupSizeX = " << kWorkgroupSizeX << "u; // 8;\n";
     ss << "const kWorkgroupSizeY = " << kWorkgroupSizeY << "u; // 8;\n";
     ss << R"(
-        alias ElemT = f32;
         struct Uniforms {
             dimAOuter : u32,
             dimInner : u32,
@@ -228,7 +227,6 @@
     ss << "const kWorkgroupSizeX = " << kWorkgroupSizeX << "u; // 8;\n";
     ss << "const kWorkgroupSizeY = " << kWorkgroupSizeY << "u; // 8;\n";
     ss << R"(
-        alias ElemT = f32;
         alias VecT = vec4<ElemT>;
         const VecLen = 4;
         struct Uniforms {
@@ -417,10 +415,32 @@
     return ostream;
 }
 
+enum class ElemType {
+    F32,
+    F16,
+};
+
+std::ostream& operator<<(std::ostream& ostream, const ElemType& elemType) {
+    switch (elemType) {
+        case ElemType::F32:
+            ostream << "f32";
+            break;
+        case ElemType::F16:
+            ostream << "f16";
+            break;
+    }
+    return ostream;
+}
+
 using DimAOuter = uint32_t;
 using DimInner = uint32_t;
 using DimBOuter = uint32_t;
-DAWN_TEST_PARAM_STRUCT(ShaderRobustnessParams, MatMulMethod, DimAOuter, DimInner, DimBOuter);
+DAWN_TEST_PARAM_STRUCT(ShaderRobustnessParams,
+                       MatMulMethod,
+                       ElemType,
+                       DimAOuter,
+                       DimInner,
+                       DimBOuter);
 
 // Test the execution time of matrix multiplication (A [dimAOuter, dimInner] * B [dimInner,
 // dimBOuter]) on the GPU and see the difference between robustness on and off.
@@ -428,6 +448,7 @@
   public:
     ShaderRobustnessPerf()
         : DawnPerfTestWithParams(kNumIterations, 1),
+          mElemType(GetParam().mElemType),
           mDimAOuter(GetParam().mDimAOuter),
           mDimInner(GetParam().mDimInner),
           mDimBOuter(GetParam().mDimBOuter) {}
@@ -435,11 +456,29 @@
 
     void SetUp() override;
 
+  protected:
+    std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+        auto requirements = DawnPerfTestWithParams<ShaderRobustnessParams>::GetRequiredFeatures();
+        if ((GetParam().mElemType == ElemType::F16) &&
+            SupportsFeatures({wgpu::FeatureName::ShaderF16})) {
+            requirements.push_back(wgpu::FeatureName::ShaderF16);
+        }
+        return requirements;
+    }
+
   private:
     void Step() override;
 
+    // Returns the shader prefix required for parameters.
+    std::string GetShaderPreamble();
+    // Returns the shader body.
+    std::string GetShaderBody();
+    // Returns the shader source.
+    std::string GetShader();
+
     wgpu::BindGroup mBindGroup;
     wgpu::ComputePipeline mPipeline;
+    ElemType mElemType;
     uint32_t mDimAOuter;
     uint32_t mDimInner;
     uint32_t mDimBOuter;
@@ -448,6 +487,9 @@
 void ShaderRobustnessPerf::SetUp() {
     DawnPerfTestWithParams<ShaderRobustnessParams>::SetUp();
 
+    DAWN_TEST_UNSUPPORTED_IF((GetParam().mElemType == ElemType::F16) &&
+                             !SupportsFeatures({wgpu::FeatureName::ShaderF16}));
+
     const size_t dataASize = mDimAOuter * mDimInner;
     std::vector<float> dataA(dataASize);
     uint64_t byteASize = sizeof(float) * dataA.size();
@@ -471,29 +513,7 @@
     wgpu::Buffer uniformBuffer = utils::CreateBufferFromData(
         device, uniformData, sizeof(uniformData), wgpu::BufferUsage::Uniform);
 
-    std::string shader;
-    switch (GetParam().mMatMulMethod) {
-        case MatMulMethod::MatMulFloatOneDimSharedArray: {
-            shader = GenMatMulFloatOneDimensionalSharedArray();
-            break;
-        }
-
-        case MatMulMethod::MatMulFloatTwoDimSharedArray: {
-            shader = GenMatMulFloatTwoDimensionalSharedArray();
-            break;
-        }
-
-        case MatMulMethod::MatMulVec4OneDimSharedArray: {
-            shader = GenMatMulVec4OneDimensionalSharedArray();
-            break;
-        }
-
-        case MatMulMethod::MatMulVec4TwoDimSharedArray: {
-            shader = GenMatMulVec4TwoDimensionalSharedArray();
-            break;
-        }
-    }
-    wgpu::ShaderModule module = utils::CreateShaderModule(device, shader.c_str());
+    wgpu::ShaderModule module = utils::CreateShaderModule(device, GetShader().c_str());
 
     wgpu::ComputePipelineDescriptor csDesc;
     csDesc.compute.module = module;
@@ -508,6 +528,34 @@
                                       });
 }
 
+std::string ShaderRobustnessPerf::GetShaderPreamble() {
+    switch (mElemType) {
+        case ElemType::F32:
+            return "alias ElemT = f32;\n";
+        case ElemType::F16:
+            return "enable f16;\nalias ElemT = f16;\n";
+    }
+    DAWN_UNREACHABLE();
+}
+
+std::string ShaderRobustnessPerf::GetShaderBody() {
+    switch (GetParam().mMatMulMethod) {
+        case MatMulMethod::MatMulFloatOneDimSharedArray:
+            return GenMatMulFloatOneDimensionalSharedArray();
+        case MatMulMethod::MatMulFloatTwoDimSharedArray:
+            return GenMatMulFloatTwoDimensionalSharedArray();
+        case MatMulMethod::MatMulVec4OneDimSharedArray:
+            return GenMatMulVec4OneDimensionalSharedArray();
+        case MatMulMethod::MatMulVec4TwoDimSharedArray:
+            return GenMatMulVec4TwoDimensionalSharedArray();
+    }
+    DAWN_UNREACHABLE();
+}
+
+std::string ShaderRobustnessPerf::GetShader() {
+    return GetShaderPreamble() + GetShaderBody();
+}
+
 void ShaderRobustnessPerf::Step() {
     bool useTimestamps = SupportsTimestampQuery();
 
@@ -555,6 +603,7 @@
                          MatMulMethod::MatMulFloatTwoDimSharedArray,
                          MatMulMethod::MatMulVec4OneDimSharedArray,
                          MatMulMethod::MatMulVec4TwoDimSharedArray},
+                        {ElemType::F32, ElemType::F16},
                         {512u},
                         {512u},
                         {512u});