[tint][val] Check that the type for @location is valid

Enforces that it is only attached to numeric scalars or vectors.

Also adds a capability to allow location on structs, matrices, and
arrays when they have only numeric elements, so that the 'from SPIRV'
path can handle inputs with this construction, instead of immediately
bailing.

The duplicate annotation test needed to be updated, since it
previously depended on attaching @location directly to a struct being
valid, which is no longer allowed.

Fixes: 440157916
Change-Id: I73c4e2eeee5d2d93b7d3d6948889966b82f866dd
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/261354
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/tint/lang/core/ir/transform/dead_code_elimination.h b/src/tint/lang/core/ir/transform/dead_code_elimination.h
index 190acec..a471e55 100644
--- a/src/tint/lang/core/ir/transform/dead_code_elimination.h
+++ b/src/tint/lang/core/ir/transform/dead_code_elimination.h
@@ -47,6 +47,7 @@
     core::ir::Capability::kAllowUnannotatedModuleIOVariables,
     core::ir::Capability::kAllowNonCoreTypes,
     core::ir::Capability::kAllowStructMatrixDecorations,
+    core::ir::Capability::kAllowLocationForNumericElements,  // @location on numeric composites
 };
 
 /// DeadCodeElimination is a transform that removes dead code from the given IR module.
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 1791dc0..342c578 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -2911,6 +2911,7 @@
     if (binding_point.has_value()) {
         annotations.Add(IOAnnotation::kBindingPoint);
     }
+
     if (auto* mv = ty->As<core::type::MemoryView>()) {
         if (mv->AddressSpace() == AddressSpace::kWorkgroup) {
             annotations.Add(IOAnnotation::kWorkgroup);
@@ -2924,6 +2925,53 @@
         return Success;
     }
 
+    // Returns true if `t`, after unwrapping any pointer or reference, is a numeric scalar or
+    // vector, or a struct/array/matrix whose elements are all (recursively) numeric. Used for
+    // @location checks when Capability::kAllowLocationForNumericElements is enabled.
+    std::function<bool(const core::type::Type*)> contains_only_numeric =
+        [&contains_only_numeric](const core::type::Type* t) -> bool {
+        t = t->UnwrapPtrOrRef();
+        bool result = false;
+        tint::Switch(
+            t,
+            [&](const core::type::Struct* s) {
+                // A struct qualifies only if every member does.
+                for (auto* m : s->Members()) {
+                    if (!contains_only_numeric(m->Type())) {
+                        return;
+                    }
+                }
+                result = true;
+            },
+            [&](Default) {
+                // Reduce composites to their innermost element type, which may itself be a
+                // struct needing a recursive check.
+                auto* e = t->DeepestElement()->UnwrapPtrOrRef();
+                tint::Switch(
+                    e, [&](const core::type::Struct* s) { result = contains_only_numeric(s); },
+                    [&](Default) { result = e->IsNumericScalarOrVector(); });
+            });
+        return result;
+    };
+
+    if (attr.location.has_value()) {
+        if (capabilities_.Contains(Capability::kAllowLocationForNumericElements)) {
+            if (!contains_only_numeric(ty)) {
+                return ToString(kind) +
+                       " with a location attribute must contain only numeric elements, but has "
+                       "type " +
+                       ty->FriendlyName();
+            }
+        } else {
+            if (!ty->UnwrapPtrOrRef()->IsNumericScalarOrVector()) {
+                return ToString(kind) +
+                       " with a location attribute must be a numeric scalar or vector, but has "
+                       "type " +
+                       ty->FriendlyName();
+            }
+        }
+    }
+
     if (auto* ty_struct = ty->UnwrapPtrOrRef()->As<core::type::Struct>()) {
         for (const auto* mem : ty_struct->Members()) {
             EnumSet<IOAnnotation> mem_annotations = annotations;
@@ -2934,6 +2982,26 @@
                        ToString(add_result.Failure()) + "'";
             }
 
+            if (mem->Attributes().location.has_value()) {
+                if (capabilities_.Contains(Capability::kAllowLocationForNumericElements)) {
+                    // Use the same recursive check as for top-level values: nested structs
+                    // are checked element-wise, and numeric arrays/matrices are accepted.
+                    if (!contains_only_numeric(mem->Type())) {
+                        return ToString(kind) +
+                               " struct member with a location attribute must contain only "
+                               "numeric elements, but has type " +
+                               mem->Type()->FriendlyName();
+                    }
+                } else {
+                    if (!mem->Type()->UnwrapPtrOrRef()->IsNumericScalarOrVector()) {
+                        return ToString(kind) +
+                               " struct member with a location attribute must be a numeric scalar "
+                               "or vector, but has type " +
+                               mem->Type()->FriendlyName();
+                    }
+                }
+            }
+
             if (capabilities_.Contains(Capability::kAllowPointersAndHandlesInStructures)) {
                 if (auto* mv = mem->Type()->As<core::type::MemoryView>()) {
                     if (mv->AddressSpace() == AddressSpace::kWorkgroup) {
diff --git a/src/tint/lang/core/ir/validator.h b/src/tint/lang/core/ir/validator.h
index bed083f..fedfacd 100644
--- a/src/tint/lang/core/ir/validator.h
+++ b/src/tint/lang/core/ir/validator.h
@@ -81,6 +81,8 @@
     kAllowNonCoreTypes,
     /// Allows matrix annotations on structure members
     kAllowStructMatrixDecorations,
+    /// Allows @location on structs, matrices, and arrays made up only of numeric elements
+    kAllowLocationForNumericElements,
 };
 
 /// Capabilities is a set of Capability
diff --git a/src/tint/lang/core/ir/validator_function_test.cc b/src/tint/lang/core/ir/validator_function_test.cc
index d0db3ce..463fec0 100644
--- a/src/tint/lang/core/ir/validator_function_test.cc
+++ b/src/tint/lang/core/ir/validator_function_test.cc
@@ -333,13 +333,13 @@
 TEST_F(IR_ValidatorTest, Function_Param_Struct_DuplicateAnnotations) {
     auto* f = ComputeEntryPoint("my_func");
     IOAttributes attr;
-    attr.location = 0;
+    attr.builtin = BuiltinValue::kPosition;
     auto* str_ty =
         ty.Struct(mod.symbols.New("MyStruct"), {
                                                    {mod.symbols.New("a"), ty.vec4<f32>(), attr},
                                                });
     auto* p = b.FunctionParam("my_param", str_ty);
-    p->SetLocation(0);
+    p->SetBuiltin(BuiltinValue::kPosition);
     f->SetParams({p});
 
     b.Append(f->Block(), [&] { b.Return(f); });
@@ -349,8 +349,8 @@
     EXPECT_THAT(
         res.Failure().reason,
         testing::HasSubstr(
-            R"(:5:54 error: input param struct member has same IO annotation, as top-level struct, '@location'
-%my_func = @compute @workgroup_size(1u, 1u, 1u) func(%my_param:MyStruct [@location(0)]):void {
+            R"(:5:54 error: input param struct member has same IO annotation, as top-level struct, 'built-in'
+%my_func = @compute @workgroup_size(1u, 1u, 1u) func(%my_param:MyStruct [@position]):void {
                                                      ^^^^^^^^^^^^^^^^^^
 )")) << res.Failure();
 }
@@ -398,6 +398,90 @@
 )")) << res.Failure();
 }
 
+TEST_F(IR_ValidatorTest, Function_Param_Location_InvalidType) {
+    auto* f = FragmentEntryPoint("my_func");
+
+    auto* p = b.FunctionParam("my_param", ty.bool_());
+    p->SetLocation(0);
+    f->SetParams({p});
+
+    b.Append(f->Block(), [&] { b.Return(f); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);  // NOTE(review): matched error is the bool check, not @location
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr(
+            R"(:1:27 error: fragment entry point params can only be a bool if decorated with @builtin(front_facing)
+%my_func = @fragment func(%my_param:bool [@location(0)]):void {
+                          ^^^^^^^^^^^^^^
+)")) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Function_Param_Struct_Location_InvalidType) {
+    auto* f = FragmentEntryPoint("my_func");
+
+    IOAttributes attr;
+    attr.location = 0;
+    auto* str_ty =
+        ty.Struct(mod.symbols.New("MyStruct"), {
+                                                   {mod.symbols.New("a"), ty.bool_(), attr},
+                                               });
+    auto* p = b.FunctionParam("my_param", str_ty);
+    f->SetParams({p});
+
+    b.Append(f->Block(), [&] { b.Return(f); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);  // NOTE(review): matched error is the bool check, not @location
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr(
+            R"(:5:27 error: fragment entry point param members can only be a bool if decorated with @builtin(front_facing)
+%my_func = @fragment func(%my_param:MyStruct):void {
+                          ^^^^^^^^^^^^^^^^^^
+)")) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Function_Param_Location_Struct_WithCapability) {
+    auto* f = FragmentEntryPoint("my_func");
+
+    auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"), {
+                                                              {mod.symbols.New("a"), ty.f32()},
+                                                          });
+    auto* p = b.FunctionParam("my_param", str_ty);
+    p->SetLocation(0);
+    f->SetParams({p});
+
+    b.Append(f->Block(), [&] { b.Return(f); });
+
+    auto res = ir::Validate(mod, Capabilities{Capability::kAllowLocationForNumericElements});
+    ASSERT_EQ(res, Success) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Function_Param_Location_Struct_WithoutCapability) {
+    auto* f = FragmentEntryPoint("my_func");
+
+    auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"), {
+                                                              {mod.symbols.New("a"), ty.f32()},
+                                                          });
+    auto* p = b.FunctionParam("my_param", str_ty);
+    p->SetLocation(0);
+    f->SetParams({p});
+
+    b.Append(f->Block(), [&] { b.Return(f); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr(
+            R"(:5:27 error: input param with a location attribute must be a numeric scalar or vector, but has type MyStruct
+%my_func = @fragment func(%my_param:MyStruct [@location(0)]):void {
+                          ^^^^^^^^^^^^^^^^^^
+)")) << res.Failure();
+}
+
 TEST_F(IR_ValidatorTest, Function_ParameterWithConstructibleType) {
     auto* f = b.Function("my_func", ty.void_());
     auto* p = b.FunctionParam("my_param", ty.u32());
@@ -616,6 +700,44 @@
 )")) << res.Failure();
 }
 
+TEST_F(IR_ValidatorTest, Function_Return_Location_InvalidType) {
+    auto* f = FragmentEntryPoint("my_func");
+    f->SetReturnType(ty.bool_());
+    f->SetReturnLocation(0);
+    b.Append(f->Block(), [&] { b.Unreachable(); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr(
+            R"(:1:1 error: return value with a location attribute must be a numeric scalar or vector, but has type bool
+%my_func = @fragment func():bool [@location(0)] {
+^^^^^^^^
+)")) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Function_Return_Struct_Location_InvalidType) {
+    IOAttributes attr;
+    attr.location = 0;
+    auto* str_ty =
+        ty.Struct(mod.symbols.New("MyStruct"), {
+                                                   {mod.symbols.New("a"), ty.bool_(), attr},
+                                               });
+    auto* f = b.Function("my_func", str_ty, Function::PipelineStage::kFragment);
+    b.Append(f->Block(), [&] { b.Unreachable(); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr(
+            R"(:5:1 error: return value struct member with a location attribute must be a numeric scalar or vector, but has type bool
+%my_func = @fragment func():MyStruct {
+^^^^^^^^
+)")) << res.Failure();
+}
+
 TEST_F(IR_ValidatorTest, Function_Return_Void_IOAnnotation) {
     auto* f = FragmentEntryPoint();
     f->SetReturnLocation(0);
diff --git a/src/tint/lang/core/ir/validator_value_test.cc b/src/tint/lang/core/ir/validator_value_test.cc
index f6804e4..e06d84a 100644
--- a/src/tint/lang/core/ir/validator_value_test.cc
+++ b/src/tint/lang/core/ir/validator_value_test.cc
@@ -698,6 +698,75 @@
 )")) << res.Failure();
 }
 
+TEST_F(IR_ValidatorTest, Var_Location_InvalidType) {
+    auto* v = b.Var<AddressSpace::kIn, bool>();  // bool is not a valid @location type
+    v->SetLocation(0);
+    mod.root_block->Append(v);
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr(
+            R"(:2:30 error: var: module scope variable with a location attribute must be a numeric scalar or vector, but has type ptr<__in, bool, read>
+  %1:ptr<__in, bool, read> = var undef @location(0)
+                             ^^^
+)")) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Var_Struct_Location_InvalidType) {
+    IOAttributes attr;
+    attr.location = 0;  // applied to the bool member "a" below
+
+    auto* str_ty =
+        ty.Struct(mod.symbols.New("MyStruct"), {
+                                                   {mod.symbols.New("a"), ty.bool_(), attr},
+                                               });
+    auto* v = b.Var(ty.ptr(AddressSpace::kOut, str_ty, read_write));
+    mod.root_block->Append(v);
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr(
+            R"(:6:41 error: var: module scope variable struct member with a location attribute must be a numeric scalar or vector, but has type bool
+  %1:ptr<__out, MyStruct, read_write> = var undef
+                                        ^^^
+)")) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Var_Location_Struct_WithCapability) {
+    auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"), {
+                                                              {mod.symbols.New("a"), ty.f32()},
+                                                          });
+    auto* v = b.Var(ty.ptr(AddressSpace::kIn, str_ty, read));
+    v->SetLocation(0);
+    mod.root_block->Append(v);
+
+    auto res = ir::Validate(mod, Capabilities{Capability::kAllowLocationForNumericElements});
+    ASSERT_EQ(res, Success) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Var_Location_Struct_WithoutCapability) {
+    auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"), {
+                                                              {mod.symbols.New("a"), ty.f32()},
+                                                          });
+    auto* v = b.Var(ty.ptr(AddressSpace::kIn, str_ty, read));
+    v->SetLocation(0);
+    mod.root_block->Append(v);
+
+    auto res = ir::Validate(mod);  // no kAllowLocationForNumericElements
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr(
+            R"(:6:34 error: var: module scope variable with a location attribute must be a numeric scalar or vector, but has type ptr<__in, MyStruct, read>
+  %1:ptr<__in, MyStruct, read> = var undef @location(0)
+                                 ^^^
+)")) << res.Failure();
+}
+
 TEST_F(IR_ValidatorTest, Var_Sampler_NonHandleAddressSpace) {
     auto* v = b.Var(ty.ptr(AddressSpace::kPrivate, ty.sampler(), read_write));
     mod.root_block->Append(v);
diff --git a/src/tint/lang/spirv/reader/lower/shader_io.cc b/src/tint/lang/spirv/reader/lower/shader_io.cc
index 3057dc1..f198e03 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io.cc
@@ -746,14 +746,16 @@
 }  // namespace
 
 Result<SuccessType> ShaderIO(core::ir::Module& ir) {
-    auto result = ValidateAndDumpIfNeeded(ir, "spirv.ShaderIO",
-                                          core::ir::Capabilities{
-                                              core::ir::Capability::kAllowMultipleEntryPoints,
-                                              core::ir::Capability::kAllowOverrides,
-                                              core::ir::Capability::kAllowPhonyInstructions,
-                                              core::ir::Capability::kAllowNonCoreTypes,
-                                              core::ir::Capability::kAllowStructMatrixDecorations,
-                                          });
+    auto result =
+        ValidateAndDumpIfNeeded(ir, "spirv.ShaderIO",
+                                core::ir::Capabilities{
+                                    core::ir::Capability::kAllowMultipleEntryPoints,
+                                    core::ir::Capability::kAllowOverrides,
+                                    core::ir::Capability::kAllowPhonyInstructions,
+                                    core::ir::Capability::kAllowNonCoreTypes,
+                                    core::ir::Capability::kAllowStructMatrixDecorations,
+                                    core::ir::Capability::kAllowLocationForNumericElements,
+                                });
     if (result != Success) {
         return result.Failure();
     }
diff --git a/src/tint/lang/spirv/reader/lower/shader_io_test.cc b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
index 90b6c81..a57ab3b 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io_test.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
@@ -39,7 +39,10 @@
 
 class SpirvReader_ShaderIOTest : public core::ir::transform::TransformTest {
   public:
-    void SetUp() override { capabilities.Add(core::ir::Capability::kAllowMultipleEntryPoints); }
+    void SetUp() override {  // subset of the capabilities ShaderIO() validates with
+        capabilities.Add(core::ir::Capability::kAllowMultipleEntryPoints);
+        capabilities.Add(core::ir::Capability::kAllowLocationForNumericElements);
+    }
 
   protected:
     core::IOAttributes BuiltinAttrs(core::BuiltinValue builtin) {
diff --git a/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc b/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc
index 1612a04..b183ac8 100644
--- a/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc
+++ b/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc
@@ -154,15 +154,17 @@
 }  // namespace
 
 Result<SuccessType> VectorElementPointer(core::ir::Module& ir) {
-    auto result = ValidateAndDumpIfNeeded(ir, "spirv.VectorElementPointer",
-                                          core::ir::Capabilities{
-                                              core::ir::Capability::kAllowMultipleEntryPoints,
-                                              core::ir::Capability::kAllowOverrides,
-                                              core::ir::Capability::kAllowVectorElementPointer,
-                                              core::ir::Capability::kAllowPhonyInstructions,
-                                              core::ir::Capability::kAllowNonCoreTypes,
-                                              core::ir::Capability::kAllowStructMatrixDecorations,
-                                          });
+    auto result =
+        ValidateAndDumpIfNeeded(ir, "spirv.VectorElementPointer",
+                                core::ir::Capabilities{
+                                    core::ir::Capability::kAllowMultipleEntryPoints,
+                                    core::ir::Capability::kAllowOverrides,
+                                    core::ir::Capability::kAllowVectorElementPointer,
+                                    core::ir::Capability::kAllowPhonyInstructions,
+                                    core::ir::Capability::kAllowNonCoreTypes,
+                                    core::ir::Capability::kAllowStructMatrixDecorations,
+                                    core::ir::Capability::kAllowLocationForNumericElements,
+                                });
     if (result != Success) {
         return result.Failure();
     }