resolver: Clean up workgroup_size validation

Actually call Expression() on the workgroup sizes.
This generates the semantic information for the expressions that would otherwise be missing.

Bug: tint:910
Change-Id: I9d7f9d6b029165dfb3bd1e0bf7ce86c0a71dd4d5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/60205
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 16cbeaf..58aaf57 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1802,33 +1802,51 @@
   if (auto* workgroup =
           ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations())) {
     auto values = workgroup->values();
-    auto is_i32 = false;
-    auto is_less_than_one = true;
+    auto any_i32 = false;
+    auto any_u32 = false;
     for (int i = 0; i < 3; i++) {
       // Each argument to this decoration can either be a literal, an
       // identifier for a module-scope constants, or nullptr if not specified.
 
-      if (!values[i]) {
+      auto* expr = values[i];
+      if (!expr) {
         // Not specified, just use the default.
         continue;
       }
 
-      Mark(values[i]);
+      Mark(expr);
+      if (!Expression(expr)) {
+        return false;
+      }
 
-      uint32_t value = 0;
-      if (auto* ident = values[i]->As<ast::IdentifierExpression>()) {
+      constexpr const char* kErrBadType =
+          "workgroup_size parameter must be either literal or module-scope "
+          "constant of type i32 or u32";
+      constexpr const char* kErrInconsistentType =
+          "workgroup_size parameters must be of the same type, either i32 "
+          "or u32";
+
+      auto* ty = TypeOf(expr);
+      bool is_i32 = ty->UnwrapRef()->Is<sem::I32>();
+      bool is_u32 = ty->UnwrapRef()->Is<sem::U32>();
+      if (!is_i32 && !is_u32) {
+        AddError(kErrBadType, expr->source());
+        return false;
+      }
+
+      any_i32 = any_i32 || is_i32;
+      any_u32 = any_u32 || is_u32;
+      if (any_i32 && any_u32) {
+        AddError(kErrInconsistentType, expr->source());
+        return false;
+      }
+
+      if (auto* ident = expr->As<ast::IdentifierExpression>()) {
         // We have an identifier of a module-scope constant.
-        if (!Identifier(ident)) {
-          return false;
-        }
-
-        VariableInfo* var;
+        VariableInfo* var = nullptr;
         if (!variable_stack_.get(ident->symbol(), &var) ||
-            !(var->declaration->is_const() && var->type->is_integer_scalar())) {
-          AddError(
-              "workgroup_size parameter must be either literal or module-scope "
-              "constant of type i32 or u32",
-              values[i]->source());
+            !(var->declaration->is_const())) {
+          AddError(kErrBadType, expr->source());
           return false;
         }
 
@@ -1838,75 +1856,30 @@
           info->workgroup_size[i].overridable_const = var->declaration;
         }
 
-        auto* constructor = var->declaration->constructor();
-        if (constructor) {
-          // Resolve the constructor expression to use as the default value.
-          auto val = ConstantValueOf(constructor);
-          if (!val.IsValid() || !val.Type()->is_integer_scalar()) {
-            TINT_ICE(Resolver, diagnostics_)
-                << "failed to resolve workgroup_size constant value";
-            return false;
-          }
-
-          if (i == 0) {
-            is_i32 = val.Type()->Is<sem::I32>();
-          } else {
-            if (is_i32 != val.Type()->Is<sem::I32>()) {
-              AddError(
-                  "workgroup_size parameters must be of the same type, "
-                  "either i32 or u32",
-                  values[i]->source());
-              return false;
-            }
-          }
-          is_less_than_one =
-              is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1;
-
-          value = is_i32 ? static_cast<uint32_t>(val.Elements()[0].i32)
-                         : val.Elements()[0].u32;
-        } else {
+        expr = var->declaration->constructor();
+        if (!expr) {
           // No constructor means this value must be overriden by the user.
           info->workgroup_size[i].value = 0;
           continue;
         }
-      } else if (auto* scalar =
-                     values[i]->As<ast::ScalarConstructorExpression>()) {
-        // We have a literal.
-        Mark(scalar->literal());
-        auto* literal = scalar->literal()->As<ast::IntLiteral>();
-        if (!literal) {
-          AddError(
-              "workgroup_size parameter must be either literal or module-scope "
-              "constant of type i32 or u32",
-              values[i]->source());
-          return false;
-        }
-
-        if (i == 0) {
-          is_i32 = literal->Is<ast::SintLiteral>();
-        } else {
-          if (literal->Is<ast::SintLiteral>() != is_i32) {
-            AddError(
-                "workgroup_size parameters must be of the same type, "
-                "either i32 or u32",
-                values[i]->source());
-            return false;
-          }
-        }
-
-        is_less_than_one =
-            is_i32 ? literal->value_as_i32() < 1 : literal->value_as_u32() < 1;
-        value = is_i32 ? static_cast<uint32_t>(literal->value_as_i32())
-                       : literal->value_as_u32();
       }
 
+      auto val = ConstantValueOf(expr);
+      if (!val) {
+        TINT_ICE(Resolver, diagnostics_)
+            << "could not resolve constant workgroup_size constant value";
+        continue;
+      }
       // Validate and set the default value for this dimension.
-      if (is_less_than_one) {
+      if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) {
         AddError("workgroup_size parameter must be at least 1",
                  values[i]->source());
         return false;
       }
-      info->workgroup_size[i].value = value;
+
+      info->workgroup_size[i].value =
+          is_i32 ? static_cast<uint32_t>(val.Elements()[0].i32)
+                 : val.Elements()[0].u32;
     }
   }