[ir] Add function attributes

This CL adds the pipeline_stage and workgroup_size attributes into the
IR function.

Bug: tint:1915
Change-Id: I245dbf0104a1784cff364535106b3e520322ac73
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/130920
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc
index ff19814..eedf344 100644
--- a/src/tint/ir/builder_impl.cc
+++ b/src/tint/ir/builder_impl.cc
@@ -65,6 +65,7 @@
 #include "src/tint/program.h"
 #include "src/tint/sem/builtin.h"
 #include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
 #include "src/tint/sem/materialize.h"
 #include "src/tint/sem/module.h"
 #include "src/tint/sem/switch_statement.h"
@@ -212,6 +213,39 @@
 
     if (ast_func->IsEntryPoint()) {
         builder.ir.entry_points.Push(ir_func);
+
+        switch (ast_func->PipelineStage()) {
+            case ast::PipelineStage::kVertex:
+                ir_func->pipeline_stage = Function::PipelineStage::kVertex;
+                break;
+            case ast::PipelineStage::kFragment:
+                ir_func->pipeline_stage = Function::PipelineStage::kFragment;
+                break;
+            case ast::PipelineStage::kCompute: {
+                ir_func->pipeline_stage = Function::PipelineStage::kCompute;
+
+                const auto* sem = program_->Sem().Get(ast_func);
+                auto wg_size = sem->WorkgroupSize();
+
+                uint32_t x = wg_size[0].value();
+                uint32_t y = 1;
+                uint32_t z = 1;
+                if (wg_size[1].has_value()) {
+                    y = wg_size[1].value();
+
+                    if (wg_size[2].has_value()) {
+                        z = wg_size[2].value();
+                    }
+                }
+
+                ir_func->workgroup_size = {x, y, z};
+                break;
+            }
+            default: {
+                TINT_ICE(IR, diagnostics_) << "Invalid pipeline stage";
+                return;
+            }
+        }
     }
 
     {
@@ -222,7 +256,6 @@
 
         // TODO(dsinclair): Store return type and attributes
         // TODO(dsinclair): Store parameters
-        // TODO(dsinclair): Store attributes
 
         // If the branch target has already been set then a `return` was called. Only set in the
         // case where `return` wasn't called.
diff --git a/src/tint/ir/builder_impl_test.cc b/src/tint/ir/builder_impl_test.cc
index f9de91e..431423f 100644
--- a/src/tint/ir/builder_impl_test.cc
+++ b/src/tint/ir/builder_impl_test.cc
@@ -88,7 +88,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -136,7 +136,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(2u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -183,7 +183,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(2u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -230,7 +230,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(2u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -273,7 +273,7 @@
     ASSERT_NE(loop_flow->continuing.target, nullptr);
     ASSERT_NE(loop_flow->merge.target, nullptr);
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -330,7 +330,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -388,7 +388,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -463,7 +463,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -538,7 +538,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -595,7 +595,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -657,7 +657,7 @@
     // This is 1 because only the loop branch happens. The subsequent if return is dead code.
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -711,7 +711,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -857,7 +857,7 @@
     EXPECT_EQ(1u, func->start_target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -1001,7 +1001,7 @@
     EXPECT_EQ(1u, if_flow->false_.target->inbound_branches.Length());
     EXPECT_EQ(1u, if_flow->merge.target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -1071,7 +1071,7 @@
     EXPECT_EQ(1u, if_flow->false_.target->inbound_branches.Length());
     EXPECT_EQ(1u, if_flow->merge.target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -1178,7 +1178,7 @@
     EXPECT_EQ(1u, flow->merge.target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -1237,7 +1237,7 @@
     EXPECT_EQ(3u, flow->merge.target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -1301,7 +1301,7 @@
     EXPECT_EQ(1u, flow->merge.target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -1345,7 +1345,7 @@
     EXPECT_EQ(1u, flow->merge.target->inbound_branches.Length());
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -1398,7 +1398,7 @@
     // This is 1 because the if is dead-code eliminated and the return doesn't happen.
     EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -1457,7 +1457,7 @@
     EXPECT_EQ(0u, flow->merge.target->inbound_branches.Length());
     EXPECT_EQ(2u, func->end_target->inbound_branches.Length());
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   branch %fn2
 
@@ -1596,7 +1596,7 @@
     ASSERT_TRUE(r) << Error();
     auto m = r.Move();
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   %1(ref<function, u32, read_write>) = var function read_write
   ret
@@ -1614,7 +1614,7 @@
     ASSERT_TRUE(r) << Error();
     auto m = r.Move();
 
-    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn1 = block
   %1(ref<function, u32, read_write>) = var function read_write
   store %1(ref<function, u32, read_write>), 2u
@@ -1782,7 +1782,7 @@
   ret true
 func_end
 
-%fn2 = func test_function
+%fn2 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn3 = block
   %1(bool) = call my_func
   %2(bool) = var function read_write
@@ -1817,7 +1817,7 @@
   ret true
 func_end
 
-%fn2 = func test_function
+%fn2 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn3 = block
   %1(bool) = call my_func
   %2(bool) = var function read_write
@@ -1999,7 +1999,7 @@
   ret 0.0f
 func_end
 
-%fn2 = func test_function
+%fn2 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn3 = block
   %1(f32) = call my_func
   %2(bool) = lt %1(f32), 2.0f
@@ -2041,7 +2041,7 @@
   ret true
 func_end
 
-%fn2 = func test_function
+%fn2 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn3 = block
   %1(bool) = call my_func, false
   ret
@@ -2137,7 +2137,7 @@
 store %1(ref<private, f32, read_write>), 1.0f
 ret
 
-%fn1 = func test_function
+%fn1 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn2 = block
   %2(vec3<f32>) = construct 2.0f, 3.0f, %1(ref<private, f32, read_write>)
   ret
@@ -2161,7 +2161,7 @@
 store %1(ref<private, i32, read_write>), 1i
 ret
 
-%fn1 = func test_function
+%fn1 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn2 = block
   %2(f32) = convert i32, %1(ref<private, i32, read_write>)
   ret
@@ -2201,7 +2201,7 @@
 store %1(ref<private, f32, read_write>), 1.0f
 ret
 
-%fn1 = func test_function
+%fn1 = func test_function [@compute @workgroup_size(1, 1, 1)]
   %fn2 = block
   %2(f32) = asin %1(ref<private, f32, read_write>)
   ret
@@ -2210,5 +2210,54 @@
 )");
 }
 
+TEST_F(IR_BuilderImplTest, EmitFunction_Vertex) {
+    Func("test", utils::Empty, ty.vec4<f32>(), utils::Vector{Return(vec4<f32>(0_f, 0_f, 0_f, 0_f))},
+         utils::Vector{Stage(ast::PipelineStage::kVertex)},
+         utils::Vector{Builtin(builtin::BuiltinValue::kPosition)});
+
+    auto r = Build();
+    ASSERT_TRUE(r) << Error();
+    auto m = r.Move();
+
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test [@vertex]
+  %fn1 = block
+  ret vec4<f32> 0.0f
+func_end
+
+)");
+}
+
+TEST_F(IR_BuilderImplTest, EmitFunction_Fragment) {
+    Func("test", utils::Empty, ty.void_(), utils::Empty,
+         utils::Vector{Stage(ast::PipelineStage::kFragment)});
+
+    auto r = Build();
+    ASSERT_TRUE(r) << Error();
+    auto m = r.Move();
+
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test [@fragment]
+  %fn1 = block
+  ret
+func_end
+
+)");
+}
+
+TEST_F(IR_BuilderImplTest, EmitFunction_Compute) {
+    Func("test", utils::Empty, ty.void_(), utils::Empty,
+         utils::Vector{Stage(ast::PipelineStage::kCompute), WorkgroupSize(8_i, 4_i, 2_i)});
+
+    auto r = Build();
+    ASSERT_TRUE(r) << Error();
+    auto m = r.Move();
+
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func test [@compute @workgroup_size(8, 4, 2)]
+  %fn1 = block
+  ret
+func_end
+
+)");
+}
+
 }  // namespace
 }  // namespace tint::ir
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 30da559..9211da9 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -92,7 +92,18 @@
         [&](const ir::Function* f) {
             TINT_SCOPED_ASSIGNMENT(in_function_, true);
 
-            Indent() << "%fn" << GetIdForNode(f) << " = func " << f->name.Name() << std::endl;
+            Indent() << "%fn" << GetIdForNode(f) << " = func " << f->name.Name();
+            if (f->pipeline_stage != Function::PipelineStage::kUndefined) {
+                out_ << " [@" << f->pipeline_stage;
+
+                if (f->workgroup_size) {
+                    auto arr = f->workgroup_size.value();
+                    out_ << " @workgroup_size(" << arr[0] << ", " << arr[1] << ", " << arr[2]
+                         << ")";
+                }
+                out_ << "]";
+            }
+            out_ << std::endl;
 
             {
                 ScopedIndent func_indent(&indent_size_);
diff --git a/src/tint/ir/function.cc b/src/tint/ir/function.cc
index 167664a..7f73c89 100644
--- a/src/tint/ir/function.cc
+++ b/src/tint/ir/function.cc
@@ -22,4 +22,18 @@
 
 Function::~Function() = default;
 
+utils::StringStream& operator<<(utils::StringStream& out, Function::PipelineStage value) {
+    switch (value) {
+        case Function::PipelineStage::kVertex:
+            return out << "vertex";
+        case Function::PipelineStage::kFragment:
+            return out << "fragment";
+        case Function::PipelineStage::kCompute:
+            return out << "compute";
+        default:
+            break;
+    }
+    return out << "<unknown>";
+}
+
 }  // namespace tint::ir
diff --git a/src/tint/ir/function.h b/src/tint/ir/function.h
index ab3c6eb..bbbd893 100644
--- a/src/tint/ir/function.h
+++ b/src/tint/ir/function.h
@@ -15,6 +15,8 @@
 #ifndef SRC_TINT_IR_FUNCTION_H_
 #define SRC_TINT_IR_FUNCTION_H_
 
+#include <optional>
+
 #include "src/tint/ir/flow_node.h"
 #include "src/tint/symbol.h"
 
@@ -29,6 +31,18 @@
 /// An IR representation of a function
 class Function : public utils::Castable<Function, FlowNode> {
   public:
+    /// The pipeline stage for an entry point
+    enum class PipelineStage {
+        /// Not a pipeline entry point
+        kUndefined,
+        /// Vertex
+        kCompute,
+        /// Fragment
+        kFragment,
+        /// Vertex
+        kVertex,
+    };
+
     /// Constructor
     Function();
     ~Function() override;
@@ -36,6 +50,12 @@
     /// The function name
     Symbol name;
 
+    /// The pipeline stage for the function, `kUndefined` if the function is not an entry point
+    PipelineStage pipeline_stage = PipelineStage::kUndefined;
+
+    /// If this is a `compute` entry point, holds the workgroup size information
+    std::optional<std::array<uint32_t, 3>> workgroup_size;
+
     /// The start target is the first block in a function.
     Block* start_target = nullptr;
     /// The end target is the end of the function. It is used as the branch target if a return is
@@ -43,6 +63,8 @@
     Terminator* end_target = nullptr;
 };
 
+utils::StringStream& operator<<(utils::StringStream& out, Function::PipelineStage value);
+
 }  // namespace tint::ir
 
 #endif  // SRC_TINT_IR_FUNCTION_H_