spirv-reader: polyfill scalar faceForward

Bug: tint:1018
Change-Id: I912c6deaed4e3c3f4c5dfb76e7ed7e917b4c6498
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58820
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 819b9a3..43401f5 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -3988,43 +3988,65 @@
 
   auto* result_type = parser_impl_.ConvertType(inst.type_id());
 
-  if ((ext_opcode == GLSLstd450Normalize) && result_type->IsScalar()) {
-    // WGSL does not have scalar form of the normalize builtin.
-    // The answer would be 1 anyway, so return that directly.
-    return {ty_.F32(),
-            create<ast::ScalarConstructorExpression>(
-                Source{}, create<ast::FloatLiteral>(Source{}, 1.0f))};
-  }
-  if ((ext_opcode == GLSLstd450Refract) && result_type->IsScalar()) {
-    // WGSL does not have scalar form of the refract builtin.
-    // It's a complicated expression.  Implement it by /computing it in two
-    // dimensions, but with a 0-valued y component in both the incident and
-    // normal vectors, then take the x component of that result.
-    auto incident = MakeOperand(inst, 2);
-    auto normal = MakeOperand(inst, 3);
-    auto eta = MakeOperand(inst, 4);
-    TINT_ASSERT(Reader, incident.type->Is<F32>());
-    TINT_ASSERT(Reader, normal.type->Is<F32>());
-    TINT_ASSERT(Reader, eta.type->Is<F32>());
-    if (!success()) {
-      return {};
+  if (result_type->IsScalar()) {
+    // Some GLSLstd450 builtins have scalar forms not supported by WGSL.
+    // Emulate them.
+    switch (ext_opcode) {
+      case GLSLstd450Normalize:
+        // WGSL does not have scalar form of the normalize builtin.
+        // The answer would be 1 anyway, so return that directly.
+        return {ty_.F32(), builder_.Expr(1.0f)};
+      case GLSLstd450FaceForward: {
+        // If dot(Nref, Incident) < 0, the result is Normal, otherwise -Normal.
+        // Also: select(-normal,normal, Incident*Nref < 0)
+        // (The dot product of scalars is their product.)
+        // Use a multiply instead of comparing floating point signs. It should
+        // be among the fastest operations on a GPU.
+        auto normal = MakeOperand(inst, 2);
+        auto incident = MakeOperand(inst, 3);
+        auto nref = MakeOperand(inst, 4);
+        TINT_ASSERT(Reader, normal.type->Is<F32>());
+        TINT_ASSERT(Reader, incident.type->Is<F32>());
+        TINT_ASSERT(Reader, nref.type->Is<F32>());
+        return {ty_.F32(),
+                builder_.Call(
+                    Source{}, "select",
+                    ast::ExpressionList{
+                        create<ast::UnaryOpExpression>(
+                            Source{}, ast::UnaryOp::kNegation, normal.expr),
+                        normal.expr,
+                        create<ast::BinaryExpression>(
+                            Source{}, ast::BinaryOp::kLessThan,
+                            builder_.Mul({}, incident.expr, nref.expr),
+                            builder_.Expr(0.0f))})};
+      }
+
+      case GLSLstd450Refract: {
+        // It's a complicated expression. Compute it in two dimensions, but
+        // with a 0-valued y component in both the incident and normal vectors,
+        // then take the x component of that result.
+        auto incident = MakeOperand(inst, 2);
+        auto normal = MakeOperand(inst, 3);
+        auto eta = MakeOperand(inst, 4);
+        TINT_ASSERT(Reader, incident.type->Is<F32>());
+        TINT_ASSERT(Reader, normal.type->Is<F32>());
+        TINT_ASSERT(Reader, eta.type->Is<F32>());
+        if (!success()) {
+          return {};
+        }
+        const Type* f32 = eta.type;
+        return {f32,
+                builder_.MemberAccessor(
+                    builder_.Call(
+                        Source{}, "refract",
+                        ast::ExpressionList{
+                            builder_.vec2<float>(incident.expr, 0.0f),
+                            builder_.vec2<float>(normal.expr, 0.0f), eta.expr}),
+                    "x")};
+      }
+      default:
+        break;
     }
-    const Type* f32 = eta.type;
-    const Type* vec2 = ty_.Vector(f32, 2);
-    return {
-        f32,
-        builder_.MemberAccessor(
-            builder_.Call(
-                Source{}, "refract",
-                ast::ExpressionList{
-                    builder_.Construct(vec2->Build(builder_),
-                                       ast::ExpressionList{
-                                           incident.expr, builder_.Expr(0.0f)}),
-                    builder_.Construct(
-                        vec2->Build(builder_),
-                        ast::ExpressionList{normal.expr, builder_.Expr(0.0f)}),
-                    eta.expr}),
-            "x")};
   }
 
   const auto name = GetGlslStd450FuncName(ext_opcode);
@@ -4654,18 +4676,33 @@
 void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
   // Mark vector operands of OpVectorShuffle as needing a named definition,
   // but only if they are defined in this function as well.
+  auto require_named_const_def = [&](const spvtools::opt::Instruction& inst,
+                                     int in_operand_index) {
+    const auto id = inst.GetSingleWordInOperand(in_operand_index);
+    auto* const operand_def = GetDefInfo(id);
+    if (operand_def) {
+      operand_def->requires_named_const_def = true;
+    }
+  };
   for (auto& id_def_info_pair : def_info_) {
     const auto& inst = id_def_info_pair.second->inst;
     const auto opcode = inst.opcode();
     if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
       // We might access the vector operands multiple times. Make sure they
       // are evaluated only once.
-      for (auto vector_arg : std::array<uint32_t, 2>{0, 1}) {
-        auto id = inst.GetSingleWordInOperand(vector_arg);
-        auto* operand_def = GetDefInfo(id);
-        if (operand_def) {
-          operand_def->requires_named_const_def = true;
-        }
+      require_named_const_def(inst, 0);
+      require_named_const_def(inst, 1);
+    }
+    if (parser_impl_.IsGlslExtendedInstruction(inst)) {
+      // Some emulations of GLSLstd450 instructions evaluate certain operands
+      // multiple times. Ensure their expressions are evaluated only once.
+      switch (inst.GetSingleWordInOperand(1)) {
+        case GLSLstd450FaceForward:
+          // The "normal" operand expression is used twice in code generation.
+          require_named_const_def(inst, 2);
+          break;
+        default:
+          break;
       }
     }
   }
diff --git a/src/reader/spirv/function_glsl_std_450_test.cc b/src/reader/spirv/function_glsl_std_450_test.cc
index 599a016..0699005 100644
--- a/src/reader/spirv/function_glsl_std_450_test.cc
+++ b/src/reader/spirv/function_glsl_std_450_test.cc
@@ -739,7 +739,6 @@
     ::testing::ValuesIn(std::vector<GlslStd450Case>{
         {"NClamp", "clamp"},
         {"FClamp", "clamp"},  // WGSL FClamp promises more for NaN
-        {"FaceForward", "faceForward"},
         {"Fma", "fma"},
         {"FMix", "mix"},
         {"SmoothStep", "smoothStep"}}));
@@ -1820,6 +1819,100 @@
   EXPECT_THAT(body, HasSubstr(expected));
 }
 
+TEST_F(SpvParserTest, GlslStd450_FaceForward_Scalar) {
+  const auto assembly = Preamble() + R"(
+     %99 = OpFAdd %float %f1 %f1 ; normal operand has only one use
+     %1 = OpExtInst %float %glsl FaceForward %99 %f2 %f3
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  auto fe = p->function_emitter(100);
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(p->builder(), fe.ast_body());
+  // The %99 sum only has one use.  Ensure it is evaluated only once by
+  // making a let-declaration for it, since it is the normal operand to
+  // the builtin function, and code generation uses it twice.
+  const auto* expected = R"(VariableDeclStatement{
+  VariableConst{
+    x_99
+    none
+    undefined
+    __f32
+    {
+      Binary[not set]{
+        Identifier[not set]{f1}
+        add
+        Identifier[not set]{f1}
+      }
+    }
+  }
+}
+VariableDeclStatement{
+  VariableConst{
+    x_1
+    none
+    undefined
+    __f32
+    {
+      Call[not set]{
+        Identifier[not set]{select}
+        (
+          UnaryOp[not set]{
+            negation
+            Identifier[not set]{x_99}
+          }
+          Identifier[not set]{x_99}
+          Binary[not set]{
+            Binary[not set]{
+              Identifier[not set]{f2}
+              multiply
+              Identifier[not set]{f3}
+            }
+            less_than
+            ScalarConstructor[not set]{0.000000}
+          }
+        )
+      }
+    }
+  }
+})";
+
+  EXPECT_THAT(body, HasSubstr(expected)) << body;
+}
+
+TEST_F(SpvParserTest, GlslStd450_FaceForward_Vector) {
+  const auto assembly = Preamble() + R"(
+     %99 = OpFAdd %v2float %v2f1 %v2f1
+     %1 = OpExtInst %v2float %glsl FaceForward %v2f1 %v2f2 %v2f3
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  auto fe = p->function_emitter(100);
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(p->builder(), fe.ast_body());
+  const auto* expected = R"(VariableConst{
+    x_1
+    none
+    undefined
+    __vec_2__f32
+    {
+      Call[not set]{
+        Identifier[not set]{faceForward}
+        (
+          Identifier[not set]{v2f1}
+          Identifier[not set]{v2f2}
+          Identifier[not set]{v2f3}
+        )
+      }
+    })";
+
+  EXPECT_THAT(body, HasSubstr(expected));
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader