[tint][ir][val] Check rules around bool for shader IO

Change-Id: I335c1e1091f734e96e58116ec33a5ecdfebab6a3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/212739
Reviewed-by: James Price <jrprice@google.com>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 2c3c9dd..e345413 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -257,6 +257,16 @@
     }
 }
 
+// Wrapper for CheckIOAttributesAndType, when the struct and non-struct impl are the same
+/// See @ref IOAttributesAndType for more details
+template <typename MSG_ANCHOR, typename IMPL>
+void CheckIOAttributesAndType(const MSG_ANCHOR* msg_anchor,
+                              const IOAttributes& ty_attr,
+                              const core::type::Type* ty,
+                              IMPL&& impl) {
+    CheckIOAttributesAndType(msg_anchor, ty_attr, ty, impl, impl);
+}
+
 /// Helper for calling IOAttributesAndType on a function param
 /// @param param function param to be tested
 /// See @ref IOAttributesAndType for more details
@@ -269,6 +279,14 @@
                              std::forward<IS_STRUCT>(is_struct_impl));
 }
 
+/// Helper for calling IOAttributesAndType on a function param
+/// @param param function param to be tested
+/// See @ref IOAttributesAndType for more details
+template <typename IMPL>
+void CheckFunctionParamAttributesAndType(const FunctionParam* param, IMPL&& impl) {
+    CheckIOAttributesAndType(param, param->Attributes(), param->Type(), std::forward<IMPL>(impl));
+}
+
 /// Helper for calling IOAttributesAndType on a function return
 /// @param func function's return to be tested
 /// See @ref IOAttributesAndType for more details
@@ -281,6 +299,15 @@
                              std::forward<IS_STRUCT>(is_struct_impl));
 }
 
+/// Helper for calling IOAttributesAndType on a function return
+/// @param func function's return to be tested
+/// See @ref IOAttributesAndType for more details
+template <typename IMPL>
+void CheckFunctionReturnAttributesAndType(const Function* func, IMPL&& impl) {
+    CheckIOAttributesAndType(func, func->ReturnAttributes(), func->ReturnType(),
+                             std::forward<IMPL>(impl));
+}
+
 /// A BuiltinChecker is the interface used to check that a usage of a builtin attribute meets the
 /// basic spec rules, i.e. correct shader stage, data type, and IO direction.
 /// It does not test more sophisticated rules like location and builtins being mutually exclusive or
@@ -918,6 +945,31 @@
         };
     }
 
+    /// @returns a function that validates that type is bool iff decorated with
+    /// @builtin(front_facing)
+    /// @param err error message to log when check fails
+    template <typename MSG_ANCHOR>
+    auto CheckFrontFacingIfBoolFunc(const std::string& err) {
+        return [this, err](const MSG_ANCHOR* msg_anchor, const IOAttributes& attr,
+                           const type::Type* ty) {
+            if (ty->Is<core::type::Bool>() && attr.builtin != BuiltinValue::kFrontFacing) {
+                AddError(msg_anchor) << err;
+            }
+        };
+    }
+
+    /// @returns a function that validates that type is not bool
+    /// @param err error message to log when check fails
+    template <typename MSG_ANCHOR>
+    auto CheckNotBool(const std::string& err) {
+        return [this, err](const MSG_ANCHOR* msg_anchor, [[maybe_unused]] const IOAttributes& attr,
+                           const type::Type* ty) {
+            if (ty->Is<core::type::Bool>()) {
+                AddError(msg_anchor) << err;
+            }
+        };
+    }
+
     /// Validates the given instruction
     /// @param inst the instruction to validate
     void CheckInstruction(const Instruction* inst);
@@ -1740,8 +1792,7 @@
             }
         }
 
-        CheckFunctionParamAttributesAndType(param, CheckBuiltinFunctionParam(""),
-                                            CheckBuiltinFunctionParam(""));
+        CheckFunctionParamAttributesAndType(param, CheckBuiltinFunctionParam(""));
 
         CheckFunctionParamAttributes(
             param,
@@ -1757,6 +1808,21 @@
             CheckDoesNotHaveBothLocationAndBuiltinFunc<FunctionParam>(
                 "a builtin and location cannot be both declared for a struct member"));
 
+        if (func->Stage() == Function::PipelineStage::kFragment) {
+            CheckFunctionParamAttributesAndType(
+                param,
+                CheckFrontFacingIfBoolFunc<FunctionParam>(
+                    "fragment entry point params can only be a bool if decorated with "
+                    "@builtin(front_facing)"),
+                CheckFrontFacingIfBoolFunc<FunctionParam>(
+                    "fragment entry point param memebers can only be a bool if "
+                    "decorated with @builtin(front_facing)"));
+        } else if (func->Stage() != Function::PipelineStage::kUndefined) {
+            CheckFunctionParamAttributesAndType(
+                param, CheckNotBool<FunctionParam>(
+                           "entry point params can only be a bool for fragment shaders"));
+        }
+
         scope_stack_.Add(param);
     }
 
@@ -1765,8 +1831,7 @@
         func->ReturnType(), [&]() -> diag::Diagnostic& { return AddError(func); },
         Capabilities{Capability::kAllowRefTypes});
 
-    CheckFunctionReturnAttributesAndType(func, CheckBuiltinFunctionReturn(""),
-                                         CheckBuiltinFunctionReturn(""));
+    CheckFunctionReturnAttributesAndType(func, CheckBuiltinFunctionReturn(""));
 
     CheckFunctionReturnAttributes(
         func,
@@ -1823,6 +1888,44 @@
         }
     }
 
+    if (func->Stage() != Function::PipelineStage::kUndefined) {
+        CheckFunctionReturnAttributesAndType(
+            func, CheckFrontFacingIfBoolFunc<Function>("entry point returns can not be bool"),
+            CheckFrontFacingIfBoolFunc<Function>("entry point return members can not be bool"));
+
+        for (auto var : referenced_module_vars_.TransitiveReferences(func)) {
+            const auto* mv = var->Result(0)->Type()->As<type::MemoryView>();
+            const auto* ty = var->Result(0)->Type()->UnwrapPtrOrRef();
+            const auto attr = var->Attributes();
+            if (!mv || !ty) {
+                continue;
+            }
+
+            if (mv->AddressSpace() != AddressSpace::kIn &&
+                mv->AddressSpace() != AddressSpace::kOut) {
+                continue;
+            }
+
+            if (func->Stage() == Function::PipelineStage::kFragment &&
+                mv->AddressSpace() == AddressSpace::kIn) {
+                CheckIOAttributesAndType(
+                    func, attr, ty,
+                    CheckFrontFacingIfBoolFunc<Function>("input address space values referenced by "
+                                                         "fragment shaders can only be a bool if "
+                                                         "decorated with @builtin(front_facing)"));
+
+            } else {
+                CheckIOAttributesAndType(
+                    func, attr, ty,
+                    CheckNotBool<Function>(
+                        "IO address space values referenced by shader entry points can only be "
+                        "bool if "
+                        "in the input space, used only by fragment shaders and decorated with "
+                        "@builtin(front_facing)"));
+            }
+        }
+    }
+
     if (func->Stage() == Function::PipelineStage::kVertex) {
         CheckVertexEntryPoint(func);
     }
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index f9fe78d..d650c31 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -84,21 +84,19 @@
         func->AppendParam(p);
     }
 
-    /// Adds to a function an return value of type @p type, and decorated with @p builtin.
+    /// Adds to a function an return value of type @p type with attributes @p attr.
     /// If there is an already existing non-structured return, both values are moved into a
     /// structured return using @p name as the name.
     /// If there is an already existing structured return, then this ICEs, since that is beyond the
     /// scope of this implementation.
-    void AddBuiltinReturn(Function* func,
-                          const std::string& name,
-                          BuiltinValue builtin,
-                          const core::type::Type* type) {
+    void AddReturn(Function* func,
+                   const std::string& name,
+                   const core::type::Type* type,
+                   const IOAttributes& attr = {}) {
         if (func->ReturnType()->Is<core::type::Struct>()) {
-            TINT_ICE() << "AddBuiltinReturn does not support adding to structured returns";
+            TINT_ICE() << "AddReturn does not support adding to structured returns";
         }
 
-        IOAttributes attr;
-        attr.builtin = builtin;
         if (func->ReturnType() == ty.void_()) {
             func->SetReturnAttributes(attr);
             func->SetReturnType(type);
@@ -117,6 +115,17 @@
         func->SetReturnAttributes({});
         func->SetReturnType(str_ty);
     }
+
+    /// Adds to a function an return value of type @p type, and decorated with @p builtin.
+    /// See @ref AddReturn for more details
+    void AddBuiltinReturn(Function* func,
+                          const std::string& name,
+                          BuiltinValue builtin,
+                          const core::type::Type* type) {
+        IOAttributes attr;
+        attr.builtin = builtin;
+        AddReturn(func, name, type, attr);
+    }
 };
 
 TEST_F(IR_ValidatorTest, RootBlock_Var) {
@@ -1058,6 +1067,168 @@
 )");
 }
 
+TEST_F(IR_ValidatorTest, Function_NonFragment_BoolInput) {
+    auto* f = VertexEntryPoint();
+    f->AppendParam(b.FunctionParam("invalid", ty.bool_()));
+    b.Append(f->Block(), [&] { b.Unreachable(); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:1:19 error: entry point params can only be a bool for fragment shaders
+%f = @vertex func(%invalid:bool):vec4<f32> [@position] {
+                  ^^^^^^^^^^^^^
+
+note: # Disassembly
+%f = @vertex func(%invalid:bool):vec4<f32> [@position] {
+  $B1: {
+    unreachable
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_NonFragment_BoolOutput) {
+    auto* f = VertexEntryPoint();
+    AddReturn(f, "invalid", ty.bool_());
+    b.Append(f->Block(), [&] { b.Unreachable(); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:6:1 error: entry point return members can not be bool
+%f = @vertex func():OutputStruct {
+^^
+
+note: # Disassembly
+OutputStruct = struct @align(16) {
+  pos:vec4<f32> @offset(0), @builtin(position)
+  invalid:bool @offset(16)
+}
+
+%f = @vertex func():OutputStruct {
+  $B1: {
+    unreachable
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_Fragment_BoolInputWithoutFrontFacing) {
+    auto* f = FragmentEntryPoint();
+    f->AppendParam(b.FunctionParam("invalid", ty.bool_()));
+    b.Append(f->Block(), [&] { b.Unreachable(); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(
+        res.Failure().reason.Str(),
+        R"(:1:21 error: fragment entry point params can only be a bool if decorated with @builtin(front_facing)
+%f = @fragment func(%invalid:bool):void {
+                    ^^^^^^^^^^^^^
+
+note: # Disassembly
+%f = @fragment func(%invalid:bool):void {
+  $B1: {
+    unreachable
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_Fragment_BoolOutput) {
+    auto* f = FragmentEntryPoint();
+    IOAttributes attr;
+    attr.location = 0;
+    AddReturn(f, "invalid", ty.bool_(), attr);
+    b.Append(f->Block(), [&] { b.Unreachable(); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:1:1 error: entry point returns can not be bool
+%f = @fragment func():bool [@location(0)] {
+^^
+
+note: # Disassembly
+%f = @fragment func():bool [@location(0)] {
+  $B1: {
+    unreachable
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_BoolOutput_via_MSV) {
+    auto* f = ComputeEntryPoint();
+
+    auto* v = b.Var(ty.ptr(AddressSpace::kOut, ty.bool_(), core::Access::kReadWrite));
+    mod.root_block->Append(v);
+
+    b.Append(f->Block(), [&] {
+        b.Append(
+            mod.CreateInstruction<ir::Store>(v->Result(0), b.Constant(b.ConstantValue(false))));
+        b.Unreachable();
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(
+        res.Failure().reason.Str(),
+        R"(:5:1 error: IO address space values referenced by shader entry points can only be bool if in the input space, used only by fragment shaders and decorated with @builtin(front_facing)
+%f = @compute @workgroup_size(1u, 1u, 1u) func():void {
+^^
+
+note: # Disassembly
+$B1: {  # root
+  %1:ptr<__out, bool, read_write> = var
+}
+
+%f = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B2: {
+    store %1, false
+    unreachable
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_BoolInputWithoutFrontFacing_via_MSV) {
+    auto* f = FragmentEntryPoint();
+
+    auto* invalid = b.Var("invalid", AddressSpace::kIn, ty.bool_());
+    mod.root_block->Append(invalid);
+
+    b.Append(f->Block(), [&] {
+        auto* l = b.Load(invalid);
+        auto* v = b.Var("v", AddressSpace::kPrivate, ty.bool_());
+        v->SetInitializer(l->Result(0));
+        b.Unreachable();
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(
+        res.Failure().reason.Str(),
+        R"(:5:1 error: input address space values referenced by fragment shaders can only be a bool if decorated with @builtin(front_facing)
+%f = @fragment func():void {
+^^
+
+note: # Disassembly
+$B1: {  # root
+  %invalid:ptr<__in, bool, read> = var
+}
+
+%f = @fragment func():void {
+  $B2: {
+    %3:bool = load %invalid
+    %v:ptr<private, bool, read_write> = var, %3
+    unreachable
+  }
+}
+)");
+}
+
 TEST_F(IR_ValidatorTest, Builtin_PointSize_WrongStage) {
     auto* f = FragmentEntryPoint();
     AddBuiltinReturn(f, "size", BuiltinValue::kPointSize, ty.f32());
@@ -1284,6 +1455,10 @@
 %f = @vertex func(%facing:bool [@front_facing]):vec4<f32> [@position] {
                   ^^^^^^^^^^^^
 
+:1:19 error: entry point params can only be a bool for fragment shaders
+%f = @vertex func(%facing:bool [@front_facing]):vec4<f32> [@position] {
+                  ^^^^^^^^^^^^
+
 note: # Disassembly
 %f = @vertex func(%facing:bool [@front_facing]):vec4<f32> [@position] {
   $B1: {
diff --git a/src/tint/lang/spirv/writer/raise/merge_return_test.cc b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
index 51e6de0..860f1b9 100644
--- a/src/tint/lang/spirv/writer/raise/merge_return_test.cc
+++ b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
@@ -212,11 +212,11 @@
 }
 
 TEST_F(SpirvWriter_MergeReturnTest, NoModify_EntryPoint_IfElse_OneSideReturns) {
-    auto* cond = b.FunctionParam(ty.bool_());
+    auto* cond = b.FunctionParam(ty.u32());
     auto* func = b.ComputeFunction("entrypointfunction", 2_u, 3_u, 4_u);
     func->SetParams({cond});
     b.Append(func->Block(), [&] {
-        auto* ifelse = b.If(cond);
+        auto* ifelse = b.If(b.Equal(ty.bool_(), cond, 0_u));
         b.Append(ifelse->True(), [&] { b.Return(func); });
         b.Append(ifelse->False(), [&] { b.ExitIf(ifelse); });
 
@@ -224,9 +224,10 @@
     });
 
     auto* src = R"(
-%entrypointfunction = @compute @workgroup_size(2u, 3u, 4u) func(%2:bool):void {
+%entrypointfunction = @compute @workgroup_size(2u, 3u, 4u) func(%2:u32):void {
   $B1: {
-    if %2 [t: $B2, f: $B3] {  # if_1
+    %3:bool = eq %2, 0u
+    if %3 [t: $B2, f: $B3] {  # if_1
       $B2: {  # true
         ret
       }