Compat: Pipeline error for textureLoad with depth texture

The spec changed so using a depth texture with textureLoad
should generate an error at pipeline creation time instead
of shader module creation time.

Bug: 357042305
Fixes: 357042305
Change-Id: I1ec148ebdd0efccf0f87840219ef6c3d7fa7a926
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/202255
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Gregg Tavares <gman@chromium.org>
Reviewed-by: Gregg Tavares <gman@chromium.org>
diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp
index 51adb45..b7b855b 100644
--- a/src/dawn/native/Pipeline.cpp
+++ b/src/dawn/native/Pipeline.cpp
@@ -94,6 +94,11 @@
         DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout));
     }
 
+    DAWN_INVALID_IF(device->IsCompatibilityMode() && metadata.usesTextureLoadWithDepthTexture,
+                    "textureLoad can not be used with depth textures in compatibility mode in "
+                    "stage (%s), entry point \"%s\"",
+                    metadata.stage, entryPoint.name);
+
     // Validate if overridable constants exist in shader module
     // pipelineBase is not yet constructed at this moment so iterate constants from descriptor
     size_t numUninitializedConstants = metadata.uninitializedOverrides.size();
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index f779fd8..7f96311 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -670,6 +670,8 @@
         metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
     }
 
+    metadata->usesTextureLoadWithDepthTexture = entryPoint.has_texture_load_with_depth_texture;
+
     const CombinedLimits& limits = device->GetLimits();
     const uint32_t maxVertexAttributes = limits.v1.maxVertexAttributes;
     const uint32_t maxInterStageShaderVariables = limits.v1.maxInterStageShaderVariables;
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index f62158f..8c4e797 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -279,6 +279,7 @@
     bool usesSampleMaskOutput = false;
     bool usesSampleIndex = false;
     bool usesVertexIndex = false;
+    bool usesTextureLoadWithDepthTexture = false;
 };
 
 class ShaderModuleBase : public RefCountedWithExternalCount<ApiObjectBase>,
diff --git a/src/dawn/tests/unittests/validation/CompatValidationTests.cpp b/src/dawn/tests/unittests/validation/CompatValidationTests.cpp
index ac6ed9e..f33574a 100644
--- a/src/dawn/tests/unittests/validation/CompatValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/CompatValidationTests.cpp
@@ -213,7 +213,48 @@
     ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&testDescriptor));
 }
 
-TEST_F(CompatValidationTest, CanNotUseFragmentShaderWithSampleMask) {
+TEST_F(CompatValidationTest, CanNotCreatePipelineWithTextureLoadOfDepthTexture) {
+    wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+        @group(0) @binding(0) var<storage, read_write> dstBuf : array<vec4f>;
+        @group(0) @binding(1) var tex1 : texture_2d<f32>;
+        @group(0) @binding(2) var tex2 : texture_depth_2d;
+        @group(0) @binding(3) var tex3 : texture_depth_2d_array;
+        @group(0) @binding(4) var tex4 : texture_depth_multisampled_2d;
+
+        @compute @workgroup_size(1) fn main1() {
+            dstBuf[0] = textureLoad(tex1, vec2(0), 0);
+        }
+
+        @compute @workgroup_size(1) fn main2() {
+            dstBuf[0] = vec4f(textureLoad(tex2, vec2(0), 0));
+        }
+
+        @compute @workgroup_size(1) fn main3() {
+            dstBuf[0] = vec4f(textureLoad(tex3, vec2(0), 0, 0));
+        }
+
+        @compute @workgroup_size(1) fn main4() {
+            dstBuf[4] = vec4f(textureLoad(tex4, vec2(0), 0));
+        }
+    )");
+
+    const char* entryPoints[] = {"main1", "main2", "main3", "main4"};
+    for (auto entryPoint : entryPoints) {
+        wgpu::ComputePipelineDescriptor pDesc;
+        pDesc.compute.module = module;
+        pDesc.compute.entryPoint = entryPoint;
+        if (entryPoint == entryPoints[0]) {
+            device.CreateComputePipeline(&pDesc);
+        } else {
+            ASSERT_DEVICE_ERROR(
+                device.CreateComputePipeline(&pDesc),
+                testing::HasSubstr(
+                    "textureLoad can not be used with depth textures in compatibility mode"));
+        }
+    }
+}
+
+TEST_F(CompatValidationTest, CanNotUseSampleMask) {
     wgpu::ShaderModule moduleSampleMaskOutput = utils::CreateShaderModule(device, R"(
         @vertex fn vs() -> @builtin(position) vec4f {
             return vec4f(1);
diff --git a/src/tint/lang/wgsl/inspector/entry_point.h b/src/tint/lang/wgsl/inspector/entry_point.h
index f64d55c..4f7a605 100644
--- a/src/tint/lang/wgsl/inspector/entry_point.h
+++ b/src/tint/lang/wgsl/inspector/entry_point.h
@@ -207,6 +207,8 @@
     bool vertex_index_used = false;
     /// Does the entry point use the instance_index builtin
     bool instance_index_used = false;
+    /// Does the entry point have a textureLoad call with a texture_depth??? texture
+    bool has_texture_load_with_depth_texture = false;
     /// The array length of the clip_distances builtin. Holding no value means the clip_distances
     /// is not used.
     std::optional<uint32_t> clip_distances_size;
diff --git a/src/tint/lang/wgsl/inspector/inspector.cc b/src/tint/lang/wgsl/inspector/inspector.cc
index 83d045a..79e01a8 100644
--- a/src/tint/lang/wgsl/inspector/inspector.cc
+++ b/src/tint/lang/wgsl/inspector/inspector.cc
@@ -27,6 +27,7 @@
 
 #include "src/tint/lang/wgsl/inspector/inspector.h"
 
+#include <algorithm>
 #include <unordered_set>
 #include <utility>
 
@@ -246,6 +247,23 @@
         }
     }
 
+    {
+        auto filter = [](const tint::sem::Call* call,
+                         tint::wgsl::BuiltinFn builtin_fn) -> std::optional<TextureUsageType> {
+            if (builtin_fn == wgsl::BuiltinFn::kTextureLoad) {
+                if (call->Arguments()[0]
+                        ->Type()
+                        ->IsAnyOf<core::type::DepthTexture,
+                                  core::type::DepthMultisampledTexture>()) {
+                    return TextureUsageType::kTextureLoad;
+                }
+            }
+            return {};
+        };
+        entry_point.has_texture_load_with_depth_texture =
+            !GetTextureUsagesForEntryPoint(*func, filter).empty();
+    }
+
     return entry_point;
 }
 
@@ -1081,32 +1099,87 @@
 }
 
 std::vector<Inspector::LevelSampleInfo> Inspector::GetTextureQueries(const std::string& ep_name) {
-    std::vector<LevelSampleInfo> res;
-
-    std::unordered_set<BindingPoint> seen = {};
-
-    Hashmap<const sem::Function*, Hashmap<const ast::Parameter*, TextureQueryType, 4>, 8>
-        fn_to_data;
-
-    auto record_function_param = [&fn_to_data](const sem::Function* func,
-                                               const ast::Parameter* param, TextureQueryType type) {
-        fn_to_data.GetOrAddZero(func).Add(param, type);
-    };
-
-    auto save_if_needed = [&res, &seen](const sem::GlobalVariable* global, TextureQueryType type) {
-        auto binding = global->Attributes().binding_point.value();
-        if (seen.insert(binding).second) {
-            res.emplace_back(LevelSampleInfo{type, binding.group, binding.binding});
-        }
-    };
-
-    auto& sem = program_.Sem();
-
     const auto* ep = FindEntryPointByName(ep_name);
     if (!ep) {
         return {};
     }
 
+    auto filter = [&](const tint::sem::Call* call,
+                      tint::wgsl::BuiltinFn builtin_fn) -> std::optional<TextureUsageType> {
+        switch (builtin_fn) {
+            case wgsl::BuiltinFn::kTextureNumLevels: {
+                return TextureUsageType::kTextureNumLevels;
+            }
+            case wgsl::BuiltinFn::kTextureDimensions: {
+                if (call->Declaration()->args.Length() <= 1) {
+                    // When textureDimension only takes a texture as the input,
+                    // it doesn't require calls to textureNumLevels to clamp mip levels.
+                    return {};
+                }
+                return TextureUsageType::kTextureNumLevels;
+            }
+            case wgsl::BuiltinFn::kTextureLoad: {
+                if (call->Arguments()[0]
+                        ->Type()
+                        ->IsAnyOf<core::type::MultisampledTexture,
+                                  core::type::DepthMultisampledTexture>()) {
+                    // When textureLoad takes a multisampled texture as the input,
+                    // it doesn't require to query the mip level.
+                    return {};
+                }
+                return TextureUsageType::kTextureNumLevels;
+            }
+            case wgsl::BuiltinFn::kTextureNumSamples: {
+                return TextureUsageType::kTextureNumSamples;
+            }
+            default:
+                return {};
+        }
+    };
+
+    auto usages = GetTextureUsagesForEntryPoint(*ep, filter);
+
+    auto t = [](const TextureUsageInfo& info) -> LevelSampleInfo {
+        return {
+            info.type == TextureUsageType::kTextureNumSamples ? TextureQueryType::kTextureNumSamples
+                                                              : TextureQueryType::kTextureNumLevels,
+            info.group,
+            info.binding,
+        };
+    };
+
+    std::vector<LevelSampleInfo> res;
+    std::transform(usages.begin(), usages.end(), std::back_inserter(res), t);
+    return res;
+}
+
+std::vector<Inspector::TextureUsageInfo> Inspector::GetTextureUsagesForEntryPoint(
+    const tint::ast::Function& ep,
+    std::function<std::optional<TextureUsageType>(const tint::sem::Call* call,
+                                                  tint::wgsl::BuiltinFn builtin_fn)> filter) {
+    TINT_ASSERT(ep.IsEntryPoint());
+
+    std::vector<TextureUsageInfo> res;
+
+    std::unordered_set<BindingPoint> seen = {};
+
+    Hashmap<const sem::Function*, Hashmap<const ast::Parameter*, TextureUsageType, 4>, 8>
+        fn_to_data;
+
+    auto record_function_param = [&fn_to_data](const sem::Function* func,
+                                               const ast::Parameter* param, TextureUsageType type) {
+        fn_to_data.GetOrAddZero(func).Add(param, type);
+    };
+
+    auto save_if_needed = [&res, &seen](const sem::GlobalVariable* global, TextureUsageType type) {
+        auto binding = global->Attributes().binding_point.value();
+        if (seen.insert(binding).second) {
+            res.emplace_back(TextureUsageInfo{type, binding.group, binding.binding});
+        }
+    };
+
+    auto& sem = program_.Sem();
+
     // This works in dependency order such that we'll see the texture call first and can record
     // any function parameter information and then as we walk up the function chain we can look
     // the call data.
@@ -1118,16 +1191,32 @@
 
         // This is an entrypoint, make sure it's the requested entry point
         if (fn->Declaration()->IsEntryPoint()) {
-            if (fn->Declaration() != ep) {
+            if (fn->Declaration() != &ep) {
                 continue;
             }
         } else {
             // Not an entry point, make sure it was called from the requested entry point
-            if (!fn->HasAncestorEntryPoint(ep->name->symbol)) {
+            if (!fn->HasAncestorEntryPoint(ep.name->symbol)) {
                 continue;
             }
         }
 
+        auto queryTextureBuiltin = [&](TextureUsageType type, const sem::Call* builtin_call,
+                                       const sem::Variable* texture_sem = nullptr) {
+            TINT_ASSERT(builtin_call);
+            if (!texture_sem) {
+                auto* texture_expr = builtin_call->Declaration()->args[0];
+                texture_sem = sem.GetVal(texture_expr)->RootIdentifier();
+            }
+            tint::Switch(
+                texture_sem,  //
+                [&](const sem::GlobalVariable* global) { save_if_needed(global, type); },
+                [&](const sem::Parameter* param) {
+                    record_function_param(fn, param->Declaration(), type);
+                },
+                TINT_ICE_ON_NO_MATCH);
+        };
+
         for (auto* call : fn->DirectCalls()) {
             // Builtin function call, record the texture information. If the used texture maps
             // back up to a function parameter just store the type of the call and we'll track the
@@ -1135,61 +1224,9 @@
             tint::Switch(
                 call->Target(),
                 [&](const sem::BuiltinFn* builtin) {
-                    auto queryTextureBuiltin = [&](TextureQueryType type,
-                                                   const sem::Call* builtin_call,
-                                                   const sem::Variable* texture_sem = nullptr) {
-                        TINT_ASSERT(builtin_call);
-                        if (!texture_sem) {
-                            auto* texture_expr = builtin_call->Declaration()->args[0];
-                            texture_sem = sem.GetVal(texture_expr)->RootIdentifier();
-                        }
-                        tint::Switch(
-                            texture_sem,  //
-                            [&](const sem::GlobalVariable* global) {
-                                save_if_needed(global, type);
-                            },
-                            [&](const sem::Parameter* param) {
-                                record_function_param(fn, param->Declaration(), type);
-                            },
-                            TINT_ICE_ON_NO_MATCH);
-                    };
-
-                    switch (builtin->Fn()) {
-                        case wgsl::BuiltinFn::kTextureNumLevels: {
-                            queryTextureBuiltin(TextureQueryType::kTextureNumLevels, call);
-                            break;
-                        }
-                        case wgsl::BuiltinFn::kTextureDimensions: {
-                            if (call->Declaration()->args.Length() <= 1) {
-                                // When textureDimension only takes a texture as the input,
-                                // it doesn't require calls to textureNumLevels to clamp mip levels.
-                                return;
-                            }
-                            queryTextureBuiltin(TextureQueryType::kTextureNumLevels, call);
-                            break;
-                        }
-                        case wgsl::BuiltinFn::kTextureLoad: {
-                            auto* texture_expr = call->Declaration()->args[0];
-                            auto* texture_sem = sem.GetVal(texture_expr)->RootIdentifier();
-                            TINT_ASSERT(texture_sem);
-                            if (texture_sem->Type()
-                                    ->UnwrapRef()
-                                    ->IsAnyOf<core::type::MultisampledTexture,
-                                              core::type::DepthMultisampledTexture>()) {
-                                // When textureLoad takes a multisampled texture as the input,
-                                // it doesn't require to query the mip level.
-                                return;
-                            }
-                            queryTextureBuiltin(TextureQueryType::kTextureNumLevels, call,
-                                                texture_sem);
-                            break;
-                        }
-                        case wgsl::BuiltinFn::kTextureNumSamples: {
-                            queryTextureBuiltin(TextureQueryType::kTextureNumSamples, call);
-                            break;
-                        }
-                        default:
-                            return;
+                    auto type = filter(call, builtin->Fn());
+                    if (type) {
+                        queryTextureBuiltin(*type, call);
                     }
                 },
                 [&](const sem::Function* func) {
diff --git a/src/tint/lang/wgsl/inspector/inspector.h b/src/tint/lang/wgsl/inspector/inspector.h
index a55e1a7..5b90b17 100644
--- a/src/tint/lang/wgsl/inspector/inspector.h
+++ b/src/tint/lang/wgsl/inspector/inspector.h
@@ -43,6 +43,7 @@
 #include "src/tint/lang/wgsl/inspector/resource_binding.h"
 #include "src/tint/lang/wgsl/inspector/scalar.h"
 #include "src/tint/lang/wgsl/program/program.h"
+#include "src/tint/lang/wgsl/sem/call.h"
 #include "src/tint/lang/wgsl/sem/sampler_texture_pair.h"
 #include "src/tint/utils/containers/unique_vector.h"
 
@@ -180,7 +181,7 @@
         uint32_t binding = 0;
     };
 
-    /// @param ep the entry point ot get the information for
+    /// @param ep the entry point to get the information for
     /// @returns a vector of information for textures which call textureNumLevels and
     /// textureNumSamples for backends which require additional support for those methods. Each
     /// binding point will only be returned once regardless of the number of calls made. The
@@ -299,6 +300,30 @@
     /// @param func the function of the entry point. Must be non-nullptr and true for IsEntryPoint()
     /// @returns the entry point information
     EntryPoint GetEntryPoint(const tint::ast::Function* func);
+
+    /// The information needed to be supplied.
+    enum class TextureUsageType : uint8_t {
+        /// textureLoad
+        kTextureLoad,
+        /// textureNumLevels
+        kTextureNumLevels,
+        /// textureNumSamples
+        kTextureNumSamples,
+    };
+    /// Information on level and sample calls by a given texture binding point
+    struct TextureUsageInfo {
+        /// The type of function
+        TextureUsageType type = TextureUsageType::kTextureNumLevels;
+        /// The group number
+        uint32_t group = 0;
+        /// The binding number
+        uint32_t binding = 0;
+    };
+
+    std::vector<Inspector::TextureUsageInfo> GetTextureUsagesForEntryPoint(
+        const tint::ast::Function& ep,
+        std::function<std::optional<TextureUsageType>(const tint::sem::Call* call,
+                                                      tint::wgsl::BuiltinFn builtin_fn)> filter);
 };
 
 }  // namespace tint::inspector
diff --git a/src/tint/lang/wgsl/resolver/compatibility_mode_test.cc b/src/tint/lang/wgsl/resolver/compatibility_mode_test.cc
index d3644e8..34ed0fe 100644
--- a/src/tint/lang/wgsl/resolver/compatibility_mode_test.cc
+++ b/src/tint/lang/wgsl/resolver/compatibility_mode_test.cc
@@ -246,79 +246,5 @@
         R"(12:34 error: flat interpolation must use 'either' sampling parameter in compatibility mode)");
 }
 
-class ResolverCompatibilityModeTest_TextureLoad : public ResolverCompatibilityModeTest {
-  protected:
-    void add_call_param(std::string name, ast::Type type, ExpressionList* call_params) {
-        const std::string type_name = type->identifier->symbol.Name();
-        if (tint::HasPrefix(type_name, "texture")) {
-            GlobalVar(name, type, Binding(0_a), Group(0_a));
-        } else {
-            GlobalVar(name, type, core::AddressSpace::kPrivate);
-        }
-        call_params->Push(Expr(Source{{12, 34}}, name));
-    }
-};
-
-TEST_F(ResolverCompatibilityModeTest_TextureLoad, TextureDepth2D) {
-    // textureLoad(someDepthTexture2D, coords, level)
-    const ast::Type coords_type = ty.vec2(ty.i32());
-    auto texture_type = ty.depth_texture(core::type::TextureDimension::k2d);
-
-    ExpressionList call_params;
-
-    add_call_param("texture", texture_type, &call_params);
-    add_call_param("coords", coords_type, &call_params);
-    add_call_param("level", ty.i32(), &call_params);
-
-    auto* expr = Call("textureLoad", call_params);
-    WrapInFunction(expr);
-
-    EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(
-        r()->error(),
-        R"(12:34 error: use of texture_depth_2d with textureLoad is not allowed in compatibility mode)");
-}
-
-TEST_F(ResolverCompatibilityModeTest_TextureLoad, TextureDepth2DArray) {
-    // textureLoad(someDepthTexture2DArray, coords, layer, level)
-    const ast::Type coords_type = ty.vec2(ty.i32());
-    auto texture_type = ty.depth_texture(core::type::TextureDimension::k2dArray);
-
-    ExpressionList call_params;
-
-    add_call_param("texture", texture_type, &call_params);
-    add_call_param("coords", coords_type, &call_params);
-    add_call_param("array_index", ty.i32(), &call_params);
-    add_call_param("level", ty.i32(), &call_params);
-
-    auto* expr = Call("textureLoad", call_params);
-    WrapInFunction(expr);
-
-    EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(
-        r()->error(),
-        R"(12:34 error: use of texture_depth_2d_array with textureLoad is not allowed in compatibility mode)");
-}
-
-TEST_F(ResolverCompatibilityModeTest_TextureLoad, TextureDepthMultisampled2D) {
-    // textureLoad(someDepthTextureMultisampled2D, coords, sample_index)
-    const ast::Type coords_type = ty.vec2(ty.i32());
-    auto texture_type = ty.depth_multisampled_texture(core::type::TextureDimension::k2d);
-
-    ExpressionList call_params;
-
-    add_call_param("texture", texture_type, &call_params);
-    add_call_param("coords", coords_type, &call_params);
-    add_call_param("sample_index", ty.i32(), &call_params);
-
-    auto* expr = Call("textureLoad", call_params);
-    WrapInFunction(expr);
-
-    EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(
-        r()->error(),
-        R"(12:34 error: use of texture_depth_multisampled_2d with textureLoad is not allowed in compatibility mode)");
-}
-
 }  // namespace
 }  // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index 92004bb..6ec74c9 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -35,7 +35,6 @@
 #include <utility>
 
 #include "src/tint/lang/core/fluent_types.h"
-#include "src/tint/lang/core/parameter_usage.h"
 #include "src/tint/lang/core/type/abstract_numeric.h"
 #include "src/tint/lang/core/type/atomic.h"
 #include "src/tint/lang/core/type/depth_multisampled_texture.h"
@@ -68,7 +67,6 @@
 #include "src/tint/lang/wgsl/ast/unary_op_expression.h"
 #include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
 #include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
-#include "src/tint/lang/wgsl/builtin_fn.h"
 #include "src/tint/lang/wgsl/sem/array.h"
 #include "src/tint/lang/wgsl/sem/break_if_statement.h"
 #include "src/tint/lang/wgsl/sem/call.h"
@@ -1912,19 +1910,6 @@
     std::string func_name = builtin->str();
     auto& signature = builtin->Signature();
 
-    if (mode_ == wgsl::ValidationMode::kCompat) {
-        if (builtin->Fn() == wgsl::BuiltinFn::kTextureLoad) {
-            auto* arg = call->Arguments()[0];
-            if (arg->Type()
-                    ->IsAnyOf<core::type::DepthTexture, core::type::DepthMultisampledTexture>()) {
-                AddError(arg->Declaration()->source)
-                    << "use of " << arg->Type()->FriendlyName()
-                    << " with textureLoad is not allowed in compatibility mode";
-                return false;
-            }
-        }
-    }
-
     auto check_arg_is_constexpr = [&](core::ParameterUsage usage, int min, int max) {
         auto signed_index = signature.IndexOf(usage);
         if (signed_index < 0) {