Add tint::Switch()
A type dispatch helper which replaces chains of:
if (auto* a = obj->As<A>()) {
...
} else if (auto* b = obj->As<B>()) {
...
} else {
...
}
with:
Switch(obj,
[&](A* a) { ... },
[&](B* b) { ... },
[&](Default) { ... });
This new helper provides greater opportunities for optimizations, avoids
scoping issues with if-else blocks, and is slightly cleaner (IMO).
Bug: tint:1383
Change-Id: Ice469a03342ef57cbcf65f69753e4b528ac50137
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/78543
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 8c0f20a..600e91f 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1160,6 +1160,7 @@
endif()
set(TINT_BENCHMARK_SRC
+ "castable_bench.cc"
"bench/benchmark.cc"
"reader/wgsl/parser_bench.cc"
)
diff --git a/src/ast/module.cc b/src/ast/module.cc
index 24999d2..3f06a31 100644
--- a/src/ast/module.cc
+++ b/src/ast/module.cc
@@ -35,16 +35,15 @@
continue;
}
- if (auto* ty = decl->As<ast::TypeDecl>()) {
- type_decls_.push_back(ty);
- } else if (auto* func = decl->As<Function>()) {
- functions_.push_back(func);
- } else if (auto* var = decl->As<Variable>()) {
- global_variables_.push_back(var);
- } else {
- diag::List diagnostics;
- TINT_ICE(AST, diagnostics) << "Unknown global declaration type";
- }
+ Switch(
+ decl, //
+ [&](const ast::TypeDecl* type) { type_decls_.push_back(type); },
+ [&](const Function* func) { functions_.push_back(func); },
+ [&](const Variable* var) { global_variables_.push_back(var); },
+ [&](Default) {
+ diag::List diagnostics;
+ TINT_ICE(AST, diagnostics) << "Unknown global declaration type";
+ });
}
}
@@ -101,19 +100,24 @@
<< "src global declaration was nullptr";
continue;
}
- if (auto* type = decl->As<ast::TypeDecl>()) {
- TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
- type_decls_.push_back(type);
- } else if (auto* func = decl->As<Function>()) {
- TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
- functions_.push_back(func);
- } else if (auto* var = decl->As<Variable>()) {
- TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
- global_variables_.push_back(var);
- } else {
- TINT_ICE(AST, ctx->dst->Diagnostics())
- << "Unknown global declaration type";
- }
+ Switch(
+ decl,
+ [&](const ast::TypeDecl* type) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
+ type_decls_.push_back(type);
+ },
+ [&](const Function* func) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
+ functions_.push_back(func);
+ },
+ [&](const Variable* var) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
+ global_variables_.push_back(var);
+ },
+ [&](Default) {
+ TINT_ICE(AST, ctx->dst->Diagnostics())
+ << "Unknown global declaration type";
+ });
}
}
diff --git a/src/ast/traverse_expressions.h b/src/ast/traverse_expressions.h
index 88d3dfc..b578941 100644
--- a/src/ast/traverse_expressions.h
+++ b/src/ast/traverse_expressions.h
@@ -101,30 +101,47 @@
}
}
- if (auto* idx = expr->As<IndexAccessorExpression>()) {
- push_pair(idx->object, idx->index);
- } else if (auto* bin_op = expr->As<BinaryExpression>()) {
- push_pair(bin_op->lhs, bin_op->rhs);
- } else if (auto* bitcast = expr->As<BitcastExpression>()) {
- to_visit.push_back(bitcast->expr);
- } else if (auto* call = expr->As<CallExpression>()) {
- // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
- // function name in the traversal.
- // to_visit.push_back(call->func);
- push_list(call->args);
- } else if (auto* member = expr->As<MemberAccessorExpression>()) {
- // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
- // member name in the traversal.
- // push_pair(member->structure, member->member);
- to_visit.push_back(member->structure);
- } else if (auto* unary = expr->As<UnaryOpExpression>()) {
- to_visit.push_back(unary->expr);
- } else if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
- PhonyExpression>()) {
- // Leaf expression
- } else {
- TINT_ICE(AST, diags) << "unhandled expression type: "
- << expr->TypeInfo().name;
+ bool ok = Switch(
+ expr,
+ [&](const IndexAccessorExpression* idx) {
+ push_pair(idx->object, idx->index);
+ return true;
+ },
+ [&](const BinaryExpression* bin_op) {
+ push_pair(bin_op->lhs, bin_op->rhs);
+ return true;
+ },
+ [&](const BitcastExpression* bitcast) {
+ to_visit.push_back(bitcast->expr);
+ return true;
+ },
+ [&](const CallExpression* call) {
+ // TODO(crbug.com/tint/1257): Resolver breaks if we include call->func:
+ // to_visit.push_back(call->func);
+ push_list(call->args);
+ return true;
+ },
+ [&](const MemberAccessorExpression* member) {
+ // TODO(crbug.com/tint/1257): Resolver breaks if we actually include
+ // the member name in the traversal:
+ // push_pair(member->structure, member->member);
+ to_visit.push_back(member->structure);
+ return true;
+ },
+ [&](const UnaryOpExpression* unary) {
+ to_visit.push_back(unary->expr);
+ return true;
+ },
+ [&](Default) {
+ if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
+ PhonyExpression>()) {
+ return true; // Leaf expression
+ }
+ TINT_ICE(AST, diags)
+ << "unhandled expression type: " << expr->TypeInfo().name;
+ return false;
+ });
+ if (!ok) {
return false;
}
}
diff --git a/src/castable.h b/src/castable.h
index 3104492..04f2dcc 100644
--- a/src/castable.h
+++ b/src/castable.h
@@ -453,6 +453,105 @@
}
};
+/// Default can be used as the default case for a Switch(), when all previous
+/// cases failed to match.
+///
+/// Example:
+/// ```
+/// Switch(object,
+/// [&](TypeA*) { /* ... */ },
+/// [&](TypeB*) { /* ... */ },
+/// [&](Default) { /* If not TypeA or TypeB */ });
+/// ```
+struct Default {};
+
+/// Switch is used to dispatch one of the provided callback case handler
+/// functions based on the type of `object` and the parameter type of the case
+/// handlers. Switch will sequentially check the type of `object` against each
+/// of the switch case handler functions, and will invoke the first case handler
+/// function which has a parameter type that matches the object type. When a
+/// case handler is matched, it will be called with the single argument of
+/// `object` cast to the case handler's parameter type. Switch will invoke at
+/// most one case handler. Each of the case functions must have the signature
+/// `R(T*)` or `R(const T*)`, where `T` is the type matched by that case and `R`
+/// is the return type, consistent across all case handlers.
+///
+/// An optional default case function with the signature `R(Default)` can be
+/// used as the last case. This default case will be called if all previous
+/// cases failed to match.
+///
+/// Example:
+/// ```
+/// Switch(object,
+/// [&](TypeA*) { /* ... */ },
+/// [&](TypeB*) { /* ... */ });
+///
+/// Switch(object,
+/// [&](TypeA*) { /* ... */ },
+/// [&](TypeB*) { /* ... */ },
+/// [&](Default) { /* Called if object is not TypeA or TypeB */ });
+/// ```
+///
+/// @param object the object whose type is used to select the case to call
+/// @param first_case the first switch case
+/// @param other_cases additional switch cases (optional)
+/// @return the value returned by the called case. If no cases matched, then the
+/// zero value for the consistent case type.
+template <typename T, typename FIRST_CASE, typename... OTHER_CASES>
+traits::ReturnType<FIRST_CASE> //
+Switch(T* object, FIRST_CASE&& first_case, OTHER_CASES&&... other_cases) {
+ using ReturnType = traits::ReturnType<FIRST_CASE>;
+ using CaseType = std::remove_pointer_t<traits::ParameterType<FIRST_CASE, 0>>;
+ static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
+ static_assert(traits::SignatureOfT<FIRST_CASE>::parameter_count == 1,
+ "Switch case must have a single parameter");
+ if constexpr (std::is_same_v<CaseType, Default>) {
+ // Default case. Must be last.
+ (void)object; // 'object' is not used by the Default case.
+ static_assert(sizeof...(other_cases) == 0,
+ "Switch Default case must come last");
+ if constexpr (kHasReturnType) {
+ return first_case({});
+ } else {
+ first_case({});
+ return;
+ }
+ } else {
+ // Regular case.
+ static_assert(traits::IsTypeOrDerived<CaseType, CastableBase>::value,
+ "Switch case parameter is not a Castable pointer");
+ // Does the case match?
+ if (auto* ptr = As<CaseType>(object)) {
+ if constexpr (kHasReturnType) {
+ return first_case(ptr);
+ } else {
+ first_case(ptr);
+ return;
+ }
+ }
+ // Case did not match. Got any more cases to try?
+ if constexpr (sizeof...(other_cases) > 0) {
+ // Try the next cases...
+ if constexpr (kHasReturnType) {
+ auto res = Switch(object, std::forward<OTHER_CASES>(other_cases)...);
+ static_assert(std::is_same_v<decltype(res), ReturnType>,
+ "Switch case types do not have consistent return type");
+ return res;
+ } else {
+ Switch(object, std::forward<OTHER_CASES>(other_cases)...);
+ return;
+ }
+ } else {
+ // That was the last case. No cases matched.
+ if constexpr (kHasReturnType) {
+ return {};
+ } else {
+ return;
+ }
+ }
+ }
+}
+
} // namespace tint
TINT_CASTABLE_POP_DISABLE_WARNINGS();
diff --git a/src/castable_bench.cc b/src/castable_bench.cc
new file mode 100644
index 0000000..839a932
--- /dev/null
+++ b/src/castable_bench.cc
@@ -0,0 +1,270 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bench/benchmark.h"
+
+namespace tint {
+namespace {
+
+struct Base : public tint::Castable<Base> {};
+struct A : public tint::Castable<A, Base> {};
+struct AA : public tint::Castable<AA, A> {};
+struct AAA : public tint::Castable<AAA, AA> {};
+struct AAB : public tint::Castable<AAB, AA> {};
+struct AAC : public tint::Castable<AAC, AA> {};
+struct AB : public tint::Castable<AB, A> {};
+struct ABA : public tint::Castable<ABA, AB> {};
+struct ABB : public tint::Castable<ABB, AB> {};
+struct ABC : public tint::Castable<ABC, AB> {};
+struct AC : public tint::Castable<AC, A> {};
+struct ACA : public tint::Castable<ACA, AC> {};
+struct ACB : public tint::Castable<ACB, AC> {};
+struct ACC : public tint::Castable<ACC, AC> {};
+struct B : public tint::Castable<B, Base> {};
+struct BA : public tint::Castable<BA, B> {};
+struct BAA : public tint::Castable<BAA, BA> {};
+struct BAB : public tint::Castable<BAB, BA> {};
+struct BAC : public tint::Castable<BAC, BA> {};
+struct BB : public tint::Castable<BB, B> {};
+struct BBA : public tint::Castable<BBA, BB> {};
+struct BBB : public tint::Castable<BBB, BB> {};
+struct BBC : public tint::Castable<BBC, BB> {};
+struct BC : public tint::Castable<BC, B> {};
+struct BCA : public tint::Castable<BCA, BC> {};
+struct BCB : public tint::Castable<BCB, BC> {};
+struct BCC : public tint::Castable<BCC, BC> {};
+struct C : public tint::Castable<C, Base> {};
+struct CA : public tint::Castable<CA, C> {};
+struct CAA : public tint::Castable<CAA, CA> {};
+struct CAB : public tint::Castable<CAB, CA> {};
+struct CAC : public tint::Castable<CAC, CA> {};
+struct CB : public tint::Castable<CB, C> {};
+struct CBA : public tint::Castable<CBA, CB> {};
+struct CBB : public tint::Castable<CBB, CB> {};
+struct CBC : public tint::Castable<CBC, CB> {};
+struct CC : public tint::Castable<CC, C> {};
+struct CCA : public tint::Castable<CCA, CC> {};
+struct CCB : public tint::Castable<CCB, CC> {};
+struct CCC : public tint::Castable<CCC, CC> {};
+
+using AllTypes = std::tuple<Base,
+ A,
+ AA,
+ AAA,
+ AAB,
+ AAC,
+ AB,
+ ABA,
+ ABB,
+ ABC,
+ AC,
+ ACA,
+ ACB,
+ ACC,
+ B,
+ BA,
+ BAA,
+ BAB,
+ BAC,
+ BB,
+ BBA,
+ BBB,
+ BBC,
+ BC,
+ BCA,
+ BCB,
+ BCC,
+ C,
+ CA,
+ CAA,
+ CAB,
+ CAC,
+ CB,
+ CBA,
+ CBB,
+ CBC,
+ CC,
+ CCA,
+ CCB,
+ CCC>;
+
+std::vector<std::unique_ptr<Base>> MakeObjects() {
+ std::vector<std::unique_ptr<Base>> out;
+ out.emplace_back(std::make_unique<Base>());
+ out.emplace_back(std::make_unique<A>());
+ out.emplace_back(std::make_unique<AA>());
+ out.emplace_back(std::make_unique<AAA>());
+ out.emplace_back(std::make_unique<AAB>());
+ out.emplace_back(std::make_unique<AAC>());
+ out.emplace_back(std::make_unique<AB>());
+ out.emplace_back(std::make_unique<ABA>());
+ out.emplace_back(std::make_unique<ABB>());
+ out.emplace_back(std::make_unique<ABC>());
+ out.emplace_back(std::make_unique<AC>());
+ out.emplace_back(std::make_unique<ACA>());
+ out.emplace_back(std::make_unique<ACB>());
+ out.emplace_back(std::make_unique<ACC>());
+ out.emplace_back(std::make_unique<B>());
+ out.emplace_back(std::make_unique<BA>());
+ out.emplace_back(std::make_unique<BAA>());
+ out.emplace_back(std::make_unique<BAB>());
+ out.emplace_back(std::make_unique<BAC>());
+ out.emplace_back(std::make_unique<BB>());
+ out.emplace_back(std::make_unique<BBA>());
+ out.emplace_back(std::make_unique<BBB>());
+ out.emplace_back(std::make_unique<BBC>());
+ out.emplace_back(std::make_unique<BC>());
+ out.emplace_back(std::make_unique<BCA>());
+ out.emplace_back(std::make_unique<BCB>());
+ out.emplace_back(std::make_unique<BCC>());
+ out.emplace_back(std::make_unique<C>());
+ out.emplace_back(std::make_unique<CA>());
+ out.emplace_back(std::make_unique<CAA>());
+ out.emplace_back(std::make_unique<CAB>());
+ out.emplace_back(std::make_unique<CAC>());
+ out.emplace_back(std::make_unique<CB>());
+ out.emplace_back(std::make_unique<CBA>());
+ out.emplace_back(std::make_unique<CBB>());
+ out.emplace_back(std::make_unique<CBC>());
+ out.emplace_back(std::make_unique<CC>());
+ out.emplace_back(std::make_unique<CCA>());
+ out.emplace_back(std::make_unique<CCB>());
+ out.emplace_back(std::make_unique<CCC>());
+ return out;
+}
+
+void CastableLargeSwitch(::benchmark::State& state) {
+ auto objects = MakeObjects();
+ size_t i = 0;
+ for (auto _ : state) {
+ auto* object = objects[i % objects.size()].get();
+ Switch(
+ object, //
+ [&](const AAA*) { ::benchmark::DoNotOptimize(i += 40); },
+ [&](const AAB*) { ::benchmark::DoNotOptimize(i += 50); },
+ [&](const AAC*) { ::benchmark::DoNotOptimize(i += 60); },
+ [&](const ABA*) { ::benchmark::DoNotOptimize(i += 80); },
+ [&](const ABB*) { ::benchmark::DoNotOptimize(i += 90); },
+ [&](const ABC*) { ::benchmark::DoNotOptimize(i += 100); },
+ [&](const ACA*) { ::benchmark::DoNotOptimize(i += 120); },
+ [&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
+ [&](const ACC*) { ::benchmark::DoNotOptimize(i += 140); },
+ [&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
+ [&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
+ [&](const BAC*) { ::benchmark::DoNotOptimize(i += 190); },
+ [&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
+ [&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
+ [&](const BBC*) { ::benchmark::DoNotOptimize(i += 230); },
+ [&](const BCA*) { ::benchmark::DoNotOptimize(i += 250); },
+ [&](const BCB*) { ::benchmark::DoNotOptimize(i += 260); },
+ [&](const BCC*) { ::benchmark::DoNotOptimize(i += 270); },
+ [&](const CA*) { ::benchmark::DoNotOptimize(i += 290); },
+ [&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
+ [&](const CAB*) { ::benchmark::DoNotOptimize(i += 310); },
+ [&](const CAC*) { ::benchmark::DoNotOptimize(i += 320); },
+ [&](const CBA*) { ::benchmark::DoNotOptimize(i += 340); },
+ [&](const CBB*) { ::benchmark::DoNotOptimize(i += 350); },
+ [&](const CBC*) { ::benchmark::DoNotOptimize(i += 360); },
+ [&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
+ [&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
+ [&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
+ [&](Default) { ::benchmark::DoNotOptimize(i += 123); });
+ i = (i * 31) ^ (i << 5);
+ }
+}
+
+BENCHMARK(CastableLargeSwitch);
+
+void CastableMediumSwitch(::benchmark::State& state) {
+ auto objects = MakeObjects();
+ size_t i = 0;
+ for (auto _ : state) {
+ auto* object = objects[i % objects.size()].get();
+ Switch(
+ object, //
+ [&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
+ [&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
+ [&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
+ [&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
+ [&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
+ [&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
+ [&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
+ [&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
+ [&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
+ [&](Default) { ::benchmark::DoNotOptimize(i += 123); });
+ i = (i * 31) ^ (i << 5);
+ }
+}
+
+BENCHMARK(CastableMediumSwitch);
+
+void CastableSmallSwitch(::benchmark::State& state) {
+ auto objects = MakeObjects();
+ size_t i = 0;
+ for (auto _ : state) {
+ auto* object = objects[i % objects.size()].get();
+ Switch(
+ object, //
+ [&](const AAB*) { ::benchmark::DoNotOptimize(i += 30); },
+ [&](const CAC*) { ::benchmark::DoNotOptimize(i += 290); },
+ [&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); });
+ i = (i * 31) ^ (i << 5);
+ }
+}
+
+BENCHMARK(CastableSmallSwitch);
+
+} // namespace
+} // namespace tint
+
+TINT_INSTANTIATE_TYPEINFO(tint::Base);
+TINT_INSTANTIATE_TYPEINFO(tint::A);
+TINT_INSTANTIATE_TYPEINFO(tint::AA);
+TINT_INSTANTIATE_TYPEINFO(tint::AAA);
+TINT_INSTANTIATE_TYPEINFO(tint::AAB);
+TINT_INSTANTIATE_TYPEINFO(tint::AAC);
+TINT_INSTANTIATE_TYPEINFO(tint::AB);
+TINT_INSTANTIATE_TYPEINFO(tint::ABA);
+TINT_INSTANTIATE_TYPEINFO(tint::ABB);
+TINT_INSTANTIATE_TYPEINFO(tint::ABC);
+TINT_INSTANTIATE_TYPEINFO(tint::AC);
+TINT_INSTANTIATE_TYPEINFO(tint::ACA);
+TINT_INSTANTIATE_TYPEINFO(tint::ACB);
+TINT_INSTANTIATE_TYPEINFO(tint::ACC);
+TINT_INSTANTIATE_TYPEINFO(tint::B);
+TINT_INSTANTIATE_TYPEINFO(tint::BA);
+TINT_INSTANTIATE_TYPEINFO(tint::BAA);
+TINT_INSTANTIATE_TYPEINFO(tint::BAB);
+TINT_INSTANTIATE_TYPEINFO(tint::BAC);
+TINT_INSTANTIATE_TYPEINFO(tint::BB);
+TINT_INSTANTIATE_TYPEINFO(tint::BBA);
+TINT_INSTANTIATE_TYPEINFO(tint::BBB);
+TINT_INSTANTIATE_TYPEINFO(tint::BBC);
+TINT_INSTANTIATE_TYPEINFO(tint::BC);
+TINT_INSTANTIATE_TYPEINFO(tint::BCA);
+TINT_INSTANTIATE_TYPEINFO(tint::BCB);
+TINT_INSTANTIATE_TYPEINFO(tint::BCC);
+TINT_INSTANTIATE_TYPEINFO(tint::C);
+TINT_INSTANTIATE_TYPEINFO(tint::CA);
+TINT_INSTANTIATE_TYPEINFO(tint::CAA);
+TINT_INSTANTIATE_TYPEINFO(tint::CAB);
+TINT_INSTANTIATE_TYPEINFO(tint::CAC);
+TINT_INSTANTIATE_TYPEINFO(tint::CB);
+TINT_INSTANTIATE_TYPEINFO(tint::CBA);
+TINT_INSTANTIATE_TYPEINFO(tint::CBB);
+TINT_INSTANTIATE_TYPEINFO(tint::CBC);
+TINT_INSTANTIATE_TYPEINFO(tint::CC);
+TINT_INSTANTIATE_TYPEINFO(tint::CCA);
+TINT_INSTANTIATE_TYPEINFO(tint::CCB);
+TINT_INSTANTIATE_TYPEINFO(tint::CCC);
diff --git a/src/castable_test.cc b/src/castable_test.cc
index e44983b..2a9a71a 100644
--- a/src/castable_test.cc
+++ b/src/castable_test.cc
@@ -252,6 +252,151 @@
ASSERT_EQ(gecko->As<Reptile>(), static_cast<Reptile*>(gecko.get()));
}
+TEST(Castable, SwitchNoDefault) {
+ std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+ std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+ std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+ {
+ bool frog_matched_amphibian = false;
+ Switch(
+ frog.get(), //
+ [&](Reptile*) { FAIL() << "frog is not reptile"; },
+ [&](Mammal*) { FAIL() << "frog is not mammal"; },
+ [&](Amphibian* amphibian) {
+ EXPECT_EQ(amphibian, frog.get());
+ frog_matched_amphibian = true;
+ });
+ EXPECT_TRUE(frog_matched_amphibian);
+ }
+ {
+ bool bear_matched_mammal = false;
+ Switch(
+ bear.get(), //
+ [&](Reptile*) { FAIL() << "bear is not reptile"; },
+ [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
+ [&](Mammal* mammal) {
+ EXPECT_EQ(mammal, bear.get());
+ bear_matched_mammal = true;
+ });
+ EXPECT_TRUE(bear_matched_mammal);
+ }
+ {
+ bool gecko_matched_reptile = false;
+ Switch(
+ gecko.get(), //
+ [&](Mammal*) { FAIL() << "gecko is not mammal"; },
+ [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
+ [&](Reptile* reptile) {
+ EXPECT_EQ(reptile, gecko.get());
+ gecko_matched_reptile = true;
+ });
+ EXPECT_TRUE(gecko_matched_reptile);
+ }
+}
+
+TEST(Castable, SwitchWithUnusedDefault) {
+ std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+ std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+ std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+ {
+ bool frog_matched_amphibian = false;
+ Switch(
+ frog.get(), //
+ [&](Reptile*) { FAIL() << "frog is not reptile"; },
+ [&](Mammal*) { FAIL() << "frog is not mammal"; },
+ [&](Amphibian* amphibian) {
+ EXPECT_EQ(amphibian, frog.get());
+ frog_matched_amphibian = true;
+ },
+ [&](Default) { FAIL() << "default should not have been selected"; });
+ EXPECT_TRUE(frog_matched_amphibian);
+ }
+ {
+ bool bear_matched_mammal = false;
+ Switch(
+ bear.get(), //
+ [&](Reptile*) { FAIL() << "bear is not reptile"; },
+ [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
+ [&](Mammal* mammal) {
+ EXPECT_EQ(mammal, bear.get());
+ bear_matched_mammal = true;
+ },
+ [&](Default) { FAIL() << "default should not have been selected"; });
+ EXPECT_TRUE(bear_matched_mammal);
+ }
+ {
+ bool gecko_matched_reptile = false;
+ Switch(
+ gecko.get(), //
+ [&](Mammal*) { FAIL() << "gecko is not mammal"; },
+ [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
+ [&](Reptile* reptile) {
+ EXPECT_EQ(reptile, gecko.get());
+ gecko_matched_reptile = true;
+ },
+ [&](Default) { FAIL() << "default should not have been selected"; });
+ EXPECT_TRUE(gecko_matched_reptile);
+ }
+}
+
+TEST(Castable, SwitchDefault) {
+ std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+ std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+ std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+ {
+ bool frog_matched_default = false;
+ Switch(
+ frog.get(), //
+ [&](Reptile*) { FAIL() << "frog is not reptile"; },
+ [&](Mammal*) { FAIL() << "frog is not mammal"; },
+ [&](Default) { frog_matched_default = true; });
+ EXPECT_TRUE(frog_matched_default);
+ }
+ {
+ bool bear_matched_default = false;
+ Switch(
+ bear.get(), //
+ [&](Reptile*) { FAIL() << "bear is not reptile"; },
+ [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
+ [&](Default) { bear_matched_default = true; });
+ EXPECT_TRUE(bear_matched_default);
+ }
+ {
+ bool gecko_matched_default = false;
+ Switch(
+ gecko.get(), //
+ [&](Mammal*) { FAIL() << "gecko is not mammal"; },
+ [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
+ [&](Default) { gecko_matched_default = true; });
+ EXPECT_TRUE(gecko_matched_default);
+ }
+}
+TEST(Castable, SwitchMatchFirst) {
+ std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+ {
+ bool frog_matched_animal = false;
+ Switch(
+ frog.get(),
+ [&](Animal* animal) {
+ EXPECT_EQ(animal, frog.get());
+ frog_matched_animal = true;
+ },
+ [&](Amphibian*) { FAIL() << "animal should have been matched first"; });
+ EXPECT_TRUE(frog_matched_animal);
+ }
+ {
+ bool frog_matched_amphibian = false;
+ Switch(
+ frog.get(),
+ [&](Amphibian* amphibain) {
+ EXPECT_EQ(amphibain, frog.get());
+ frog_matched_amphibian = true;
+ },
+ [&](Animal*) { FAIL() << "amphibian should have been matched first"; });
+ EXPECT_TRUE(frog_matched_amphibian);
+ }
+}
+
} // namespace
TINT_INSTANTIATE_TYPEINFO(Animal);
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index b88916e..ed35ca4 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -953,7 +953,7 @@
bool FunctionEmitter::EmitPipelineInput(std::string var_name,
const Type* var_type,
- ast::AttributeList* decos,
+ ast::AttributeList* attrs,
std::vector<int> index_prefix,
const Type* tip_type,
const Type* forced_param_type,
@@ -966,105 +966,121 @@
}
// Recursively flatten matrices, arrays, and structures.
- if (auto* matrix_type = tip_type->As<Matrix>()) {
- index_prefix.push_back(0);
- const auto num_columns = static_cast<int>(matrix_type->columns);
- const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
- for (int col = 0; col < num_columns; col++) {
- index_prefix.back() = col;
- if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, vec_ty,
- forced_param_type, params, statements)) {
- return false;
- }
- }
- return success();
- } else if (auto* array_type = tip_type->As<Array>()) {
- if (array_type->size == 0) {
- return Fail() << "runtime-size array not allowed on pipeline IO";
- }
- index_prefix.push_back(0);
- const Type* elem_ty = array_type->type;
- for (int i = 0; i < static_cast<int>(array_type->size); i++) {
- index_prefix.back() = i;
- if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, elem_ty,
- forced_param_type, params, statements)) {
- return false;
- }
- }
- return success();
- } else if (auto* struct_type = tip_type->As<Struct>()) {
- const auto& members = struct_type->members;
- index_prefix.push_back(0);
- for (int i = 0; i < static_cast<int>(members.size()); ++i) {
- index_prefix.back() = i;
- ast::AttributeList member_decos(*decos);
- if (!parser_impl_.ConvertPipelineDecorations(
- struct_type,
- parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
- &member_decos)) {
- return false;
- }
- if (!EmitPipelineInput(var_name, var_type, &member_decos, index_prefix,
- members[i], forced_param_type, params,
- statements)) {
- return false;
- }
- // Copy the location as updated by nested expansion of the member.
- parser_impl_.SetLocation(decos, GetLocation(member_decos));
- }
- return success();
- }
+ return Switch(
+ tip_type,
+ [&](const Matrix* matrix_type) -> bool {
+ index_prefix.push_back(0);
+ const auto num_columns = static_cast<int>(matrix_type->columns);
+ const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
+ for (int col = 0; col < num_columns; col++) {
+ index_prefix.back() = col;
+ if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
+ vec_ty, forced_param_type, params,
+ statements)) {
+ return false;
+ }
+ }
+ return success();
+ },
+ [&](const Array* array_type) -> bool {
+ if (array_type->size == 0) {
+ return Fail() << "runtime-size array not allowed on pipeline IO";
+ }
+ index_prefix.push_back(0);
+ const Type* elem_ty = array_type->type;
+ for (int i = 0; i < static_cast<int>(array_type->size); i++) {
+ index_prefix.back() = i;
+ if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
+ elem_ty, forced_param_type, params,
+ statements)) {
+ return false;
+ }
+ }
+ return success();
+ },
+ [&](const Struct* struct_type) -> bool {
+ const auto& members = struct_type->members;
+ index_prefix.push_back(0);
+ for (int i = 0; i < static_cast<int>(members.size()); ++i) {
+ index_prefix.back() = i;
+ ast::AttributeList member_attrs(*attrs);
+ if (!parser_impl_.ConvertPipelineDecorations(
+ struct_type,
+ parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
+ &member_attrs)) {
+ return false;
+ }
+ if (!EmitPipelineInput(var_name, var_type, &member_attrs,
+ index_prefix, members[i], forced_param_type,
+ params, statements)) {
+ return false;
+ }
+ // Copy the location as updated by nested expansion of the member.
+ parser_impl_.SetLocation(attrs, GetLocation(member_attrs));
+ }
+ return success();
+ },
+ [&](Default) {
+ const bool is_builtin =
+ ast::HasAttribute<ast::BuiltinAttribute>(*attrs);
- const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
+ const Type* param_type = is_builtin ? forced_param_type : tip_type;
- const Type* param_type = is_builtin ? forced_param_type : tip_type;
+ const auto param_name = namer_.MakeDerivedName(var_name + "_param");
+ // Create the parameter.
+ // TODO(dneto): Note: If the parameter has non-location decorations,
+ // then those decoration AST nodes will be reused between multiple
+ // elements of a matrix, array, or structure. Normally that's
+ // disallowed but currently the SPIR-V reader will make duplicates when
+ // the entire AST is cloned at the top level of the SPIR-V reader flow.
+ // Consider rewriting this to avoid this node-sharing.
+ params->push_back(
+ builder_.Param(param_name, param_type->Build(builder_), *attrs));
- const auto param_name = namer_.MakeDerivedName(var_name + "_param");
- // Create the parameter.
- // TODO(dneto): Note: If the parameter has non-location decorations,
- // then those decoration AST nodes will be reused between multiple elements
- // of a matrix, array, or structure. Normally that's disallowed but currently
- // the SPIR-V reader will make duplicates when the entire AST is cloned
- // at the top level of the SPIR-V reader flow. Consider rewriting this
- // to avoid this node-sharing.
- params->push_back(
- builder_.Param(param_name, param_type->Build(builder_), *decos));
+ // Add a body statement to copy the parameter to the corresponding
+ // private variable.
+ const ast::Expression* param_value = builder_.Expr(param_name);
+ const ast::Expression* store_dest = builder_.Expr(var_name);
- // Add a body statement to copy the parameter to the corresponding private
- // variable.
- const ast::Expression* param_value = builder_.Expr(param_name);
- const ast::Expression* store_dest = builder_.Expr(var_name);
+ // Index into the LHS as needed.
+ auto* current_type =
+ var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
+ for (auto index : index_prefix) {
+ Switch(
+ current_type,
+ [&](const Matrix* matrix_type) {
+ store_dest =
+ builder_.IndexAccessor(store_dest, builder_.Expr(index));
+ current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
+ },
+ [&](const Array* array_type) {
+ store_dest =
+ builder_.IndexAccessor(store_dest, builder_.Expr(index));
+ current_type = array_type->type->UnwrapAlias();
+ },
+ [&](const Struct* struct_type) {
+ store_dest = builder_.MemberAccessor(
+ store_dest, builder_.Expr(parser_impl_.GetMemberName(
+ *struct_type, index)));
+ current_type = struct_type->members[index];
+ });
+ }
- // Index into the LHS as needed.
- auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
- for (auto index : index_prefix) {
- if (auto* matrix_type = current_type->As<Matrix>()) {
- store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
- current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
- } else if (auto* array_type = current_type->As<Array>()) {
- store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
- current_type = array_type->type->UnwrapAlias();
- } else if (auto* struct_type = current_type->As<Struct>()) {
- store_dest = builder_.MemberAccessor(
- store_dest,
- builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
- current_type = struct_type->members[index];
- }
- }
+ if (is_builtin && (tip_type != forced_param_type)) {
+ // The parameter will have the WGSL type, but we need bitcast to
+ // the variable store type.
+ param_value = create<ast::BitcastExpression>(
+ tip_type->Build(builder_), param_value);
+ }
- if (is_builtin && (tip_type != forced_param_type)) {
- // The parameter will have the WGSL type, but we need bitcast to
- // the variable store type.
- param_value =
- create<ast::BitcastExpression>(tip_type->Build(builder_), param_value);
- }
+ statements->push_back(builder_.Assign(store_dest, param_value));
- statements->push_back(builder_.Assign(store_dest, param_value));
+ // Increment the location attribute, in case more parameters will
+ // follow.
+ IncrementLocation(attrs);
- // Increment the location attribute, in case more parameters will follow.
- IncrementLocation(decos);
-
- return success();
+ return success();
+ });
}
void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) {
@@ -1102,106 +1118,120 @@
}
// Recursively flatten matrices, arrays, and structures.
- if (auto* matrix_type = tip_type->As<Matrix>()) {
- index_prefix.push_back(0);
- const auto num_columns = static_cast<int>(matrix_type->columns);
- const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
- for (int col = 0; col < num_columns; col++) {
- index_prefix.back() = col;
- if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty,
- forced_member_type, return_members,
- return_exprs)) {
- return false;
- }
- }
- return success();
- } else if (auto* array_type = tip_type->As<Array>()) {
- if (array_type->size == 0) {
- return Fail() << "runtime-size array not allowed on pipeline IO";
- }
- index_prefix.push_back(0);
- const Type* elem_ty = array_type->type;
- for (int i = 0; i < static_cast<int>(array_type->size); i++) {
- index_prefix.back() = i;
- if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty,
- forced_member_type, return_members,
- return_exprs)) {
- return false;
- }
- }
- return success();
- } else if (auto* struct_type = tip_type->As<Struct>()) {
- const auto& members = struct_type->members;
- index_prefix.push_back(0);
- for (int i = 0; i < static_cast<int>(members.size()); ++i) {
- index_prefix.back() = i;
- ast::AttributeList member_decos(*decos);
- if (!parser_impl_.ConvertPipelineDecorations(
- struct_type,
- parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
- &member_decos)) {
- return false;
- }
- if (!EmitPipelineOutput(var_name, var_type, &member_decos, index_prefix,
- members[i], forced_member_type, return_members,
- return_exprs)) {
- return false;
- }
- // Copy the location as updated by nested expansion of the member.
- parser_impl_.SetLocation(decos, GetLocation(member_decos));
- }
- return success();
- }
+ return Switch(
+ tip_type,
+ [&](const Matrix* matrix_type) -> bool {
+ index_prefix.push_back(0);
+ const auto num_columns = static_cast<int>(matrix_type->columns);
+ const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
+ for (int col = 0; col < num_columns; col++) {
+ index_prefix.back() = col;
+ if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
+ vec_ty, forced_member_type, return_members,
+ return_exprs)) {
+ return false;
+ }
+ }
+ return success();
+ },
+ [&](const Array* array_type) -> bool {
+ if (array_type->size == 0) {
+ return Fail() << "runtime-size array not allowed on pipeline IO";
+ }
+ index_prefix.push_back(0);
+ const Type* elem_ty = array_type->type;
+ for (int i = 0; i < static_cast<int>(array_type->size); i++) {
+ index_prefix.back() = i;
+ if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
+ elem_ty, forced_member_type, return_members,
+ return_exprs)) {
+ return false;
+ }
+ }
+ return success();
+ },
+ [&](const Struct* struct_type) -> bool {
+ const auto& members = struct_type->members;
+ index_prefix.push_back(0);
+ for (int i = 0; i < static_cast<int>(members.size()); ++i) {
+ index_prefix.back() = i;
+ ast::AttributeList member_attrs(*decos);
+ if (!parser_impl_.ConvertPipelineDecorations(
+ struct_type,
+ parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
+ &member_attrs)) {
+ return false;
+ }
+ if (!EmitPipelineOutput(var_name, var_type, &member_attrs,
+ index_prefix, members[i], forced_member_type,
+ return_members, return_exprs)) {
+ return false;
+ }
+ // Copy the location as updated by nested expansion of the member.
+ parser_impl_.SetLocation(decos, GetLocation(member_attrs));
+ }
+ return success();
+ },
+ [&](Default) {
+ const bool is_builtin =
+ ast::HasAttribute<ast::BuiltinAttribute>(*decos);
- const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
+ const Type* member_type = is_builtin ? forced_member_type : tip_type;
+ // Derive the member name directly from the variable name. They can't
+ // collide.
+ const auto member_name = namer_.MakeDerivedName(var_name);
+ // Create the member.
+ // TODO(dneto): Note: If the parameter has non-location decorations,
+ // then those decoration AST nodes will be reused between multiple
+ // elements of a matrix, array, or structure. Normally that's
+ // disallowed but currently the SPIR-V reader will make duplicates when
+ // the entire AST is cloned at the top level of the SPIR-V reader flow.
+ // Consider rewriting this to avoid this node-sharing.
+ return_members->push_back(
+ builder_.Member(member_name, member_type->Build(builder_), *decos));
- const Type* member_type = is_builtin ? forced_member_type : tip_type;
- // Derive the member name directly from the variable name. They can't
- // collide.
- const auto member_name = namer_.MakeDerivedName(var_name);
- // Create the member.
- // TODO(dneto): Note: If the parameter has non-location decorations,
- // then those decoration AST nodes will be reused between multiple elements
- // of a matrix, array, or structure. Normally that's disallowed but currently
- // the SPIR-V reader will make duplicates when the entire AST is cloned
- // at the top level of the SPIR-V reader flow. Consider rewriting this
- // to avoid this node-sharing.
- return_members->push_back(
- builder_.Member(member_name, member_type->Build(builder_), *decos));
+ // Create an expression to evaluate the part of the variable indexed by
+ // the index_prefix.
+ const ast::Expression* load_source = builder_.Expr(var_name);
- // Create an expression to evaluate the part of the variable indexed by
- // the index_prefix.
- const ast::Expression* load_source = builder_.Expr(var_name);
+ // Index into the variable as needed to pick out the flattened member.
+ auto* current_type =
+ var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
+ for (auto index : index_prefix) {
+ Switch(
+ current_type,
+ [&](const Matrix* matrix_type) {
+ load_source =
+ builder_.IndexAccessor(load_source, builder_.Expr(index));
+ current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
+ },
+ [&](const Array* array_type) {
+ load_source =
+ builder_.IndexAccessor(load_source, builder_.Expr(index));
+ current_type = array_type->type->UnwrapAlias();
+ },
+ [&](const Struct* struct_type) {
+ load_source = builder_.MemberAccessor(
+ load_source, builder_.Expr(parser_impl_.GetMemberName(
+ *struct_type, index)));
+ current_type = struct_type->members[index];
+ });
+ }
- // Index into the variable as needed to pick out the flattened member.
- auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
- for (auto index : index_prefix) {
- if (auto* matrix_type = current_type->As<Matrix>()) {
- load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
- current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
- } else if (auto* array_type = current_type->As<Array>()) {
- load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
- current_type = array_type->type->UnwrapAlias();
- } else if (auto* struct_type = current_type->As<Struct>()) {
- load_source = builder_.MemberAccessor(
- load_source,
- builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
- current_type = struct_type->members[index];
- }
- }
+ if (is_builtin && (tip_type != forced_member_type)) {
+ // The member will have the WGSL type, but we need bitcast to
+ // the variable store type.
+ load_source = create<ast::BitcastExpression>(
+ forced_member_type->Build(builder_), load_source);
+ }
+ return_exprs->push_back(load_source);
- if (is_builtin && (tip_type != forced_member_type)) {
- // The member will have the WGSL type, but we need bitcast to
- // the variable store type.
- load_source = create<ast::BitcastExpression>(
- forced_member_type->Build(builder_), load_source);
- }
- return_exprs->push_back(load_source);
+ // Increment the location attribute, in case more parameters will
+ // follow.
+ IncrementLocation(decos);
- // Increment the location attribute, in case more parameters will follow.
- IncrementLocation(decos);
-
- return success();
+ return success();
+ });
}
bool FunctionEmitter::EmitEntryPointAsWrapper() {
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 6d20ace..8502427 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -239,39 +239,41 @@
}
last_kind = kind;
- if (auto* global = decl->As<ast::Variable>()) {
- if (!EmitGlobalVariable(global)) {
- return false;
- }
- } else if (auto* str = decl->As<ast::Struct>()) {
- auto* ty = builder_.Sem().Get(str);
- auto storage_class_uses = ty->StorageClassUsage();
- if (storage_class_uses.size() !=
- (storage_class_uses.count(ast::StorageClass::kStorage) +
- storage_class_uses.count(ast::StorageClass::kUniform))) {
- // The structure is used as something other than a storage buffer or
- // uniform buffer, so it needs to be emitted.
- // Storage buffer are read and written to via a ByteAddressBuffer
- // instead of true structure.
- // Structures used as uniform buffer are read from an array of vectors
- // instead of true structure.
- if (!EmitStructType(current_buffer_, ty)) {
+ bool ok = Switch(
+ decl,
+ [&](const ast::Variable* global) { //
+ return EmitGlobalVariable(global);
+ },
+ [&](const ast::Struct* str) {
+ auto* ty = builder_.Sem().Get(str);
+ auto storage_class_uses = ty->StorageClassUsage();
+ if (storage_class_uses.size() !=
+ (storage_class_uses.count(ast::StorageClass::kStorage) +
+ storage_class_uses.count(ast::StorageClass::kUniform))) {
+ // The structure is used as something other than a storage buffer or
+ // uniform buffer, so it needs to be emitted.
+            // Storage buffers are read and written via a ByteAddressBuffer
+            // instead of a true structure.
+            // Structures used as uniform buffers are read from an array of
+            // vectors instead of a true structure.
+ return EmitStructType(current_buffer_, ty);
+ }
+ return true;
+ },
+ [&](const ast::Function* func) {
+ if (func->IsEntryPoint()) {
+ return EmitEntryPointFunction(func);
+ }
+ return EmitFunction(func);
+ },
+ [&](Default) {
+ TINT_ICE(Writer, diagnostics_)
+ << "unhandled module-scope declaration: "
+ << decl->TypeInfo().name;
return false;
- }
- }
- } else if (auto* func = decl->As<ast::Function>()) {
- if (func->IsEntryPoint()) {
- if (!EmitEntryPointFunction(func)) {
- return false;
- }
- } else {
- if (!EmitFunction(func)) {
- return false;
- }
- }
- } else {
- TINT_ICE(Writer, diagnostics_)
- << "unhandled module-scope declaration: " << decl->TypeInfo().name;
+ });
+
+ if (!ok) {
return false;
}
}
@@ -929,22 +931,25 @@
const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get(expr);
auto* target = call->Target();
-
- if (auto* func = target->As<sem::Function>()) {
- return EmitFunctionCall(out, call, func);
- }
- if (auto* builtin = target->As<sem::Builtin>()) {
- return EmitBuiltinCall(out, call, builtin);
- }
- if (auto* conv = target->As<sem::TypeConversion>()) {
- return EmitTypeConversion(out, call, conv);
- }
- if (auto* ctor = target->As<sem::TypeConstructor>()) {
- return EmitTypeConstructor(out, call, ctor);
- }
- TINT_ICE(Writer, diagnostics_)
- << "unhandled call target: " << target->TypeInfo().name;
- return false;
+ return Switch(
+ target,
+ [&](const sem::Function* func) {
+ return EmitFunctionCall(out, call, func);
+ },
+ [&](const sem::Builtin* builtin) {
+ return EmitBuiltinCall(out, call, builtin);
+ },
+ [&](const sem::TypeConversion* conv) {
+ return EmitTypeConversion(out, call, conv);
+ },
+ [&](const sem::TypeConstructor* ctor) {
+ return EmitTypeConstructor(out, call, ctor);
+ },
+ [&](Default) {
+ TINT_ICE(Writer, diagnostics_)
+ << "unhandled call target: " << target->TypeInfo().name;
+ return false;
+ });
}
bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
@@ -2639,35 +2644,38 @@
bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) {
- if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
- return EmitIndexAccessor(out, a);
- }
- if (auto* b = expr->As<ast::BinaryExpression>()) {
- return EmitBinary(out, b);
- }
- if (auto* b = expr->As<ast::BitcastExpression>()) {
- return EmitBitcast(out, b);
- }
- if (auto* c = expr->As<ast::CallExpression>()) {
- return EmitCall(out, c);
- }
- if (auto* i = expr->As<ast::IdentifierExpression>()) {
- return EmitIdentifier(out, i);
- }
- if (auto* l = expr->As<ast::LiteralExpression>()) {
- return EmitLiteral(out, l);
- }
- if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
- return EmitMemberAccessor(out, m);
- }
- if (auto* u = expr->As<ast::UnaryOpExpression>()) {
- return EmitUnaryOp(out, u);
- }
-
- diagnostics_.add_error(
- diag::System::Writer,
- "unknown expression type: " + std::string(expr->TypeInfo().name));
- return false;
+ return Switch(
+ expr,
+ [&](const ast::IndexAccessorExpression* a) { //
+ return EmitIndexAccessor(out, a);
+ },
+ [&](const ast::BinaryExpression* b) { //
+ return EmitBinary(out, b);
+ },
+ [&](const ast::BitcastExpression* b) { //
+ return EmitBitcast(out, b);
+ },
+ [&](const ast::CallExpression* c) { //
+ return EmitCall(out, c);
+ },
+ [&](const ast::IdentifierExpression* i) { //
+ return EmitIdentifier(out, i);
+ },
+ [&](const ast::LiteralExpression* l) { //
+ return EmitLiteral(out, l);
+ },
+ [&](const ast::MemberAccessorExpression* m) { //
+ return EmitMemberAccessor(out, m);
+ },
+ [&](const ast::UnaryOpExpression* u) { //
+ return EmitUnaryOp(out, u);
+ },
+ [&](Default) { //
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "unknown expression type: " + std::string(expr->TypeInfo().name));
+ return false;
+ });
}
bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@@ -3127,80 +3135,108 @@
bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) {
- if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
- out << (l->value ? "true" : "false");
- } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
- if (std::isinf(fl->value)) {
- out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)");
- } else if (std::isnan(fl->value)) {
- out << "asfloat(0x7fc00000u)";
- } else {
- out << FloatToString(fl->value) << "f";
- }
- } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
- out << sl->value;
- } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
- out << ul->value << "u";
- } else {
- diagnostics_.add_error(diag::System::Writer, "unknown literal type");
- return false;
- }
- return true;
+ return Switch(
+ lit,
+ [&](const ast::BoolLiteralExpression* l) {
+ out << (l->value ? "true" : "false");
+ return true;
+ },
+ [&](const ast::FloatLiteralExpression* fl) {
+ if (std::isinf(fl->value)) {
+ out << (fl->value >= 0 ? "asfloat(0x7f800000u)"
+ : "asfloat(0xff800000u)");
+ } else if (std::isnan(fl->value)) {
+ out << "asfloat(0x7fc00000u)";
+ } else {
+ out << FloatToString(fl->value) << "f";
+ }
+ return true;
+ },
+ [&](const ast::SintLiteralExpression* sl) {
+ out << sl->value;
+ return true;
+ },
+ [&](const ast::UintLiteralExpression* ul) {
+ out << ul->value << "u";
+ return true;
+ },
+ [&](Default) {
+ diagnostics_.add_error(diag::System::Writer, "unknown literal type");
+ return false;
+ });
}
bool GeneratorImpl::EmitValue(std::ostream& out,
const sem::Type* type,
int value) {
- if (type->Is<sem::Bool>()) {
- out << (value == 0 ? "false" : "true");
- } else if (type->Is<sem::F32>()) {
- out << value << ".0f";
- } else if (type->Is<sem::I32>()) {
- out << value;
- } else if (type->Is<sem::U32>()) {
- out << value << "u";
- } else if (auto* vec = type->As<sem::Vector>()) {
- if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
- "")) {
- return false;
- }
- ScopedParen sp(out);
- for (uint32_t i = 0; i < vec->Width(); i++) {
- if (i != 0) {
- out << ", ";
- }
- if (!EmitValue(out, vec->type(), value)) {
+ return Switch(
+ type,
+ [&](const sem::Bool*) {
+ out << (value == 0 ? "false" : "true");
+ return true;
+ },
+ [&](const sem::F32*) {
+ out << value << ".0f";
+ return true;
+ },
+ [&](const sem::I32*) {
+ out << value;
+ return true;
+ },
+ [&](const sem::U32*) {
+ out << value << "u";
+ return true;
+ },
+ [&](const sem::Vector* vec) {
+ if (!EmitType(out, type, ast::StorageClass::kNone,
+ ast::Access::kReadWrite, "")) {
+ return false;
+ }
+ ScopedParen sp(out);
+ for (uint32_t i = 0; i < vec->Width(); i++) {
+ if (i != 0) {
+ out << ", ";
+ }
+ if (!EmitValue(out, vec->type(), value)) {
+ return false;
+ }
+ }
+ return true;
+ },
+ [&](const sem::Matrix* mat) {
+ if (!EmitType(out, type, ast::StorageClass::kNone,
+ ast::Access::kReadWrite, "")) {
+ return false;
+ }
+ ScopedParen sp(out);
+ for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
+ if (i != 0) {
+ out << ", ";
+ }
+ if (!EmitValue(out, mat->type(), value)) {
+ return false;
+ }
+ }
+ return true;
+ },
+ [&](const sem::Struct*) {
+ out << "(";
+ TINT_DEFER(out << ")" << value);
+ return EmitType(out, type, ast::StorageClass::kNone,
+ ast::Access::kUndefined, "");
+ },
+ [&](const sem::Array*) {
+ out << "(";
+ TINT_DEFER(out << ")" << value);
+ return EmitType(out, type, ast::StorageClass::kNone,
+ ast::Access::kUndefined, "");
+ },
+ [&](Default) {
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "Invalid type for value emission: " + type->type_name());
return false;
- }
- }
- } else if (auto* mat = type->As<sem::Matrix>()) {
- if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
- "")) {
- return false;
- }
- ScopedParen sp(out);
- for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
- if (i != 0) {
- out << ", ";
- }
- if (!EmitValue(out, mat->type(), value)) {
- return false;
- }
- }
- } else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
- out << "(";
- if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
- "")) {
- return false;
- }
- out << ")" << value;
- } else {
- diagnostics_.add_error(
- diag::System::Writer,
- "Invalid type for value emission: " + type->type_name());
- return false;
- }
- return true;
+ });
}
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
@@ -3375,56 +3411,59 @@
}
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
- if (auto* a = stmt->As<ast::AssignmentStatement>()) {
- return EmitAssign(a);
- }
- if (auto* b = stmt->As<ast::BlockStatement>()) {
- return EmitBlock(b);
- }
- if (auto* b = stmt->As<ast::BreakStatement>()) {
- return EmitBreak(b);
- }
- if (auto* c = stmt->As<ast::CallStatement>()) {
- auto out = line();
- if (!EmitCall(out, c->expr)) {
- return false;
- }
- out << ";";
- return true;
- }
- if (auto* c = stmt->As<ast::ContinueStatement>()) {
- return EmitContinue(c);
- }
- if (auto* d = stmt->As<ast::DiscardStatement>()) {
- return EmitDiscard(d);
- }
- if (stmt->As<ast::FallthroughStatement>()) {
- line() << "/* fallthrough */";
- return true;
- }
- if (auto* i = stmt->As<ast::IfStatement>()) {
- return EmitIf(i);
- }
- if (auto* l = stmt->As<ast::LoopStatement>()) {
- return EmitLoop(l);
- }
- if (auto* l = stmt->As<ast::ForLoopStatement>()) {
- return EmitForLoop(l);
- }
- if (auto* r = stmt->As<ast::ReturnStatement>()) {
- return EmitReturn(r);
- }
- if (auto* s = stmt->As<ast::SwitchStatement>()) {
- return EmitSwitch(s);
- }
- if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
- return EmitVariable(v->variable);
- }
-
- diagnostics_.add_error(
- diag::System::Writer,
- "unknown statement type: " + std::string(stmt->TypeInfo().name));
- return false;
+ return Switch(
+ stmt,
+ [&](const ast::AssignmentStatement* a) { //
+ return EmitAssign(a);
+ },
+ [&](const ast::BlockStatement* b) { //
+ return EmitBlock(b);
+ },
+ [&](const ast::BreakStatement* b) { //
+ return EmitBreak(b);
+ },
+ [&](const ast::CallStatement* c) { //
+ auto out = line();
+ if (!EmitCall(out, c->expr)) {
+ return false;
+ }
+ out << ";";
+ return true;
+ },
+ [&](const ast::ContinueStatement* c) { //
+ return EmitContinue(c);
+ },
+ [&](const ast::DiscardStatement* d) { //
+ return EmitDiscard(d);
+ },
+ [&](const ast::FallthroughStatement*) { //
+ line() << "/* fallthrough */";
+ return true;
+ },
+ [&](const ast::IfStatement* i) { //
+ return EmitIf(i);
+ },
+ [&](const ast::LoopStatement* l) { //
+ return EmitLoop(l);
+ },
+ [&](const ast::ForLoopStatement* l) { //
+ return EmitForLoop(l);
+ },
+ [&](const ast::ReturnStatement* r) { //
+ return EmitReturn(r);
+ },
+ [&](const ast::SwitchStatement* s) { //
+ return EmitSwitch(s);
+ },
+ [&](const ast::VariableDeclStatement* v) { //
+ return EmitVariable(v->variable);
+ },
+ [&](Default) { //
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "unknown statement type: " + std::string(stmt->TypeInfo().name));
+ return false;
+ });
}
bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
@@ -3516,156 +3555,181 @@
break;
}
- if (auto* ary = type->As<sem::Array>()) {
- const sem::Type* base_type = ary;
- std::vector<uint32_t> sizes;
- while (auto* arr = base_type->As<sem::Array>()) {
- if (arr->IsRuntimeSized()) {
+ return Switch(
+ type,
+ [&](const sem::Array* ary) {
+ const sem::Type* base_type = ary;
+ std::vector<uint32_t> sizes;
+ while (auto* arr = base_type->As<sem::Array>()) {
+ if (arr->IsRuntimeSized()) {
+ TINT_ICE(Writer, diagnostics_)
+ << "Runtime arrays may only exist in storage buffers, which "
+ "should "
+ "have been transformed into a ByteAddressBuffer";
+ return false;
+ }
+ sizes.push_back(arr->Count());
+ base_type = arr->ElemType();
+ }
+ if (!EmitType(out, base_type, storage_class, access, "")) {
+ return false;
+ }
+ if (!name.empty()) {
+ out << " " << name;
+ if (name_printed) {
+ *name_printed = true;
+ }
+ }
+ for (uint32_t size : sizes) {
+ out << "[" << size << "]";
+ }
+ return true;
+ },
+ [&](const sem::Bool*) {
+ out << "bool";
+ return true;
+ },
+ [&](const sem::F32*) {
+ out << "float";
+ return true;
+ },
+ [&](const sem::I32*) {
+ out << "int";
+ return true;
+ },
+ [&](const sem::Matrix* mat) {
+ if (!EmitType(out, mat->type(), storage_class, access, "")) {
+ return false;
+ }
+ // Note: HLSL's matrices are declared as <type>NxM, where N is the
+ // number of rows and M is the number of columns. Despite HLSL's
+ // matrices being column-major by default, the index operator and
+        // constructors actually operate on row-vectors, whereas WGSL operates
+ // on column vectors. To simplify everything we use the transpose of the
+ // matrices. See:
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
+ out << mat->columns() << "x" << mat->rows();
+ return true;
+ },
+ [&](const sem::Pointer*) {
TINT_ICE(Writer, diagnostics_)
- << "Runtime arrays may only exist in storage buffers, which should "
- "have been transformed into a ByteAddressBuffer";
+ << "Attempting to emit pointer type. These should have been "
+ "removed with the InlinePointerLets transform";
return false;
- }
- sizes.push_back(arr->Count());
- base_type = arr->ElemType();
- }
- if (!EmitType(out, base_type, storage_class, access, "")) {
- return false;
- }
- if (!name.empty()) {
- out << " " << name;
- if (name_printed) {
- *name_printed = true;
- }
- }
- for (uint32_t size : sizes) {
- out << "[" << size << "]";
- }
- } else if (type->Is<sem::Bool>()) {
- out << "bool";
- } else if (type->Is<sem::F32>()) {
- out << "float";
- } else if (type->Is<sem::I32>()) {
- out << "int";
- } else if (auto* mat = type->As<sem::Matrix>()) {
- if (!EmitType(out, mat->type(), storage_class, access, "")) {
- return false;
- }
- // Note: HLSL's matrices are declared as <type>NxM, where N is the number of
- // rows and M is the number of columns. Despite HLSL's matrices being
- // column-major by default, the index operator and constructors actually
- // operate on row-vectors, where as WGSL operates on column vectors.
- // To simplify everything we use the transpose of the matrices.
- // See:
- // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
- out << mat->columns() << "x" << mat->rows();
- } else if (type->Is<sem::Pointer>()) {
- TINT_ICE(Writer, diagnostics_)
- << "Attempting to emit pointer type. These should have been removed "
- "with the InlinePointerLets transform";
- return false;
- } else if (auto* sampler = type->As<sem::Sampler>()) {
- out << "Sampler";
- if (sampler->IsComparison()) {
- out << "Comparison";
- }
- out << "State";
- } else if (auto* str = type->As<sem::Struct>()) {
- out << StructName(str);
- } else if (auto* tex = type->As<sem::Texture>()) {
- auto* storage = tex->As<sem::StorageTexture>();
- auto* ms = tex->As<sem::MultisampledTexture>();
- auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
- auto* sampled = tex->As<sem::SampledTexture>();
+ },
+ [&](const sem::Sampler* sampler) {
+ out << "Sampler";
+ if (sampler->IsComparison()) {
+ out << "Comparison";
+ }
+ out << "State";
+ return true;
+ },
+ [&](const sem::Struct* str) {
+ out << StructName(str);
+ return true;
+ },
+ [&](const sem::Texture* tex) {
+ auto* storage = tex->As<sem::StorageTexture>();
+ auto* ms = tex->As<sem::MultisampledTexture>();
+ auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
+ auto* sampled = tex->As<sem::SampledTexture>();
- if (storage && storage->access() != ast::Access::kRead) {
- out << "RW";
- }
- out << "Texture";
+ if (storage && storage->access() != ast::Access::kRead) {
+ out << "RW";
+ }
+ out << "Texture";
- switch (tex->dim()) {
- case ast::TextureDimension::k1d:
- out << "1D";
- break;
- case ast::TextureDimension::k2d:
- out << ((ms || depth_ms) ? "2DMS" : "2D");
- break;
- case ast::TextureDimension::k2dArray:
- out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
- break;
- case ast::TextureDimension::k3d:
- out << "3D";
- break;
- case ast::TextureDimension::kCube:
- out << "Cube";
- break;
- case ast::TextureDimension::kCubeArray:
- out << "CubeArray";
- break;
- default:
- TINT_UNREACHABLE(Writer, diagnostics_)
- << "unexpected TextureDimension " << tex->dim();
- return false;
- }
+ switch (tex->dim()) {
+ case ast::TextureDimension::k1d:
+ out << "1D";
+ break;
+ case ast::TextureDimension::k2d:
+ out << ((ms || depth_ms) ? "2DMS" : "2D");
+ break;
+ case ast::TextureDimension::k2dArray:
+ out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
+ break;
+ case ast::TextureDimension::k3d:
+ out << "3D";
+ break;
+ case ast::TextureDimension::kCube:
+ out << "Cube";
+ break;
+ case ast::TextureDimension::kCubeArray:
+ out << "CubeArray";
+ break;
+ default:
+ TINT_UNREACHABLE(Writer, diagnostics_)
+ << "unexpected TextureDimension " << tex->dim();
+ return false;
+ }
- if (storage) {
- auto* component = image_format_to_rwtexture_type(storage->texel_format());
- if (component == nullptr) {
- TINT_ICE(Writer, diagnostics_)
- << "Unsupported StorageTexture TexelFormat: "
- << static_cast<int>(storage->texel_format());
+ if (storage) {
+ auto* component =
+ image_format_to_rwtexture_type(storage->texel_format());
+ if (component == nullptr) {
+ TINT_ICE(Writer, diagnostics_)
+ << "Unsupported StorageTexture TexelFormat: "
+ << static_cast<int>(storage->texel_format());
+ return false;
+ }
+ out << "<" << component << ">";
+ } else if (depth_ms) {
+ out << "<float4>";
+ } else if (sampled || ms) {
+ auto* subtype = sampled ? sampled->type() : ms->type();
+ out << "<";
+ if (subtype->Is<sem::F32>()) {
+ out << "float4";
+ } else if (subtype->Is<sem::I32>()) {
+ out << "int4";
+ } else if (subtype->Is<sem::U32>()) {
+ out << "uint4";
+ } else {
+ TINT_ICE(Writer, diagnostics_)
+ << "Unsupported multisampled texture type";
+ return false;
+ }
+ out << ">";
+ }
+ return true;
+ },
+ [&](const sem::U32*) {
+ out << "uint";
+ return true;
+ },
+ [&](const sem::Vector* vec) {
+ auto width = vec->Width();
+ if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
+ out << "float" << width;
+ } else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
+ out << "int" << width;
+ } else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
+ out << "uint" << width;
+ } else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
+ out << "bool" << width;
+ } else {
+ out << "vector<";
+ if (!EmitType(out, vec->type(), storage_class, access, "")) {
+ return false;
+ }
+ out << ", " << width << ">";
+ }
+ return true;
+ },
+ [&](const sem::Atomic* atomic) {
+ return EmitType(out, atomic->Type(), storage_class, access, name);
+ },
+ [&](const sem::Void*) {
+ out << "void";
+ return true;
+ },
+ [&](Default) {
+ diagnostics_.add_error(diag::System::Writer,
+ "unknown type in EmitType");
return false;
- }
- out << "<" << component << ">";
- } else if (depth_ms) {
- out << "<float4>";
- } else if (sampled || ms) {
- auto* subtype = sampled ? sampled->type() : ms->type();
- out << "<";
- if (subtype->Is<sem::F32>()) {
- out << "float4";
- } else if (subtype->Is<sem::I32>()) {
- out << "int4";
- } else if (subtype->Is<sem::U32>()) {
- out << "uint4";
- } else {
- TINT_ICE(Writer, diagnostics_)
- << "Unsupported multisampled texture type";
- return false;
- }
- out << ">";
- }
- } else if (type->Is<sem::U32>()) {
- out << "uint";
- } else if (auto* vec = type->As<sem::Vector>()) {
- auto width = vec->Width();
- if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
- out << "float" << width;
- } else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
- out << "int" << width;
- } else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
- out << "uint" << width;
- } else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
- out << "bool" << width;
- } else {
- out << "vector<";
- if (!EmitType(out, vec->type(), storage_class, access, "")) {
- return false;
- }
- out << ", " << width << ">";
- }
- } else if (auto* atomic = type->As<sem::Atomic>()) {
- if (!EmitType(out, atomic->Type(), storage_class, access, name)) {
- return false;
- }
- } else if (type->Is<sem::Void>()) {
- out << "void";
- } else {
- diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
- return false;
- }
-
- return true;
+ });
}
bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index ec2d748..129e969 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -538,23 +538,25 @@
const ast::CallExpression* expr) {
auto* call = program_->Sem().Get(expr);
auto* target = call->Target();
-
- if (auto* func = target->As<sem::Function>()) {
- return EmitFunctionCall(out, call, func);
- }
- if (auto* builtin = target->As<sem::Builtin>()) {
- return EmitBuiltinCall(out, call, builtin);
- }
- if (auto* conv = target->As<sem::TypeConversion>()) {
- return EmitTypeConversion(out, call, conv);
- }
- if (auto* ctor = target->As<sem::TypeConstructor>()) {
- return EmitTypeConstructor(out, call, ctor);
- }
-
- TINT_ICE(Writer, diagnostics_)
- << "unhandled call target: " << target->TypeInfo().name;
- return false;
+ return Switch(
+ target,
+ [&](const sem::Function* func) {
+ return EmitFunctionCall(out, call, func);
+ },
+ [&](const sem::Builtin* builtin) {
+ return EmitBuiltinCall(out, call, builtin);
+ },
+ [&](const sem::TypeConversion* conv) {
+ return EmitTypeConversion(out, call, conv);
+ },
+ [&](const sem::TypeConstructor* ctor) {
+ return EmitTypeConstructor(out, call, ctor);
+ },
+ [&](Default) {
+ TINT_ICE(Writer, diagnostics_)
+ << "unhandled call target: " << target->TypeInfo().name;
+ return false;
+ });
}
bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
@@ -1476,106 +1478,128 @@
}
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
- if (type->Is<sem::Bool>()) {
- out << "false";
- } else if (type->Is<sem::F32>()) {
- out << "0.0f";
- } else if (type->Is<sem::I32>()) {
- out << "0";
- } else if (type->Is<sem::U32>()) {
- out << "0u";
- } else if (auto* vec = type->As<sem::Vector>()) {
- return EmitZeroValue(out, vec->type());
- } else if (auto* mat = type->As<sem::Matrix>()) {
- if (!EmitType(out, mat, "")) {
- return false;
- }
- out << "(";
- if (!EmitZeroValue(out, mat->type())) {
- return false;
- }
- out << ")";
- } else if (auto* arr = type->As<sem::Array>()) {
- out << "{";
- if (!EmitZeroValue(out, arr->ElemType())) {
- return false;
- }
- out << "}";
- } else if (type->As<sem::Struct>()) {
- out << "{}";
- } else {
- diagnostics_.add_error(
- diag::System::Writer,
- "Invalid type for zero emission: " + type->type_name());
- return false;
- }
- return true;
+ return Switch(  // emit the MSL zero-value expression for `type`
+ type,
+ [&](const sem::Bool*) {
+ out << "false";
+ return true;
+ },
+ [&](const sem::F32*) {
+ out << "0.0f";
+ return true;
+ },
+ [&](const sem::I32*) {
+ out << "0";
+ return true;
+ },
+ [&](const sem::U32*) {
+ out << "0u";
+ return true;
+ },
+ [&](const sem::Vector* vec) {  // emits only the element type's zero
+ return EmitZeroValue(out, vec->type());
+ },
+ [&](const sem::Matrix* mat) {
+ if (!EmitType(out, mat, "")) {
+ return false;
+ }
+ out << "(";
+ TINT_DEFER(out << ")");
+ return EmitZeroValue(out, mat->type());
+ },
+ [&](const sem::Array* arr) {
+ out << "{";
+ TINT_DEFER(out << "}");
+ return EmitZeroValue(out, arr->ElemType());
+ },
+ [&](const sem::Struct*) {
+ out << "{}";
+ return true;
+ },
+ [&](Default) {  // no zero-value form for this type
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "Invalid type for zero emission: " + type->type_name());
+ return false;
+ });
}
bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) {
- if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
- out << (l->value ? "true" : "false");
- } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
- if (std::isinf(fl->value)) {
- out << (fl->value >= 0 ? "INFINITY" : "-INFINITY");
- } else if (std::isnan(fl->value)) {
- out << "NAN";
- } else {
- out << FloatToString(fl->value) << "f";
- }
- } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
- // MSL (and C++) parse `-2147483648` as a `long` because it parses unary
- // minus and `2147483648` as separate tokens, and the latter doesn't
- // fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To avoid
- // issues with `long` to `int` casts, emit `(2147483647 - 1)` instead, which
- // ensures the expression type is `int`.
- const auto int_min = std::numeric_limits<int32_t>::min();
- if (sl->ValueAsI32() == int_min) {
- out << "(" << int_min + 1 << " - 1)";
- } else {
- out << sl->value;
- }
- } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
- out << ul->value << "u";
- } else {
- diagnostics_.add_error(diag::System::Writer, "unknown literal type");
- return false;
- }
- return true;
+ return Switch(  // emit `lit` as an MSL literal
+ lit,
+ [&](const ast::BoolLiteralExpression* l) {
+ out << (l->value ? "true" : "false");
+ return true;
+ },
+ [&](const ast::FloatLiteralExpression* l) {
+ if (std::isinf(l->value)) {
+ out << (l->value >= 0 ? "INFINITY" : "-INFINITY");
+ } else if (std::isnan(l->value)) {
+ out << "NAN";
+ } else {
+ out << FloatToString(l->value) << "f";
+ }
+ return true;
+ },
+ [&](const ast::SintLiteralExpression* l) {
+ // MSL (and C++) parse `-2147483648` as a `long` because it parses unary
+ // minus and `2147483648` as separate tokens, and the latter doesn't
+ // fit into a (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To
+ // avoid issues with `long` to `int` casts, emit `(-2147483647 - 1)`
+ // instead, which ensures the expression type is `int`.
+ const auto int_min = std::numeric_limits<int32_t>::min();
+ if (l->ValueAsI32() == int_min) {
+ out << "(" << int_min + 1 << " - 1)";
+ } else {
+ out << l->value;
+ }
+ return true;
+ },
+ [&](const ast::UintLiteralExpression* l) {
+ out << l->value << "u";
+ return true;
+ },
+ [&](Default) {  // unrecognized literal node
+ diagnostics_.add_error(diag::System::Writer, "unknown literal type");
+ return false;
+ });
}
bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) {
- if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
- return EmitIndexAccessor(out, a);
- }
- if (auto* b = expr->As<ast::BinaryExpression>()) {
- return EmitBinary(out, b);
- }
- if (auto* b = expr->As<ast::BitcastExpression>()) {
- return EmitBitcast(out, b);
- }
- if (auto* c = expr->As<ast::CallExpression>()) {
- return EmitCall(out, c);
- }
- if (auto* i = expr->As<ast::IdentifierExpression>()) {
- return EmitIdentifier(out, i);
- }
- if (auto* l = expr->As<ast::LiteralExpression>()) {
- return EmitLiteral(out, l);
- }
- if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
- return EmitMemberAccessor(out, m);
- }
- if (auto* u = expr->As<ast::UnaryOpExpression>()) {
- return EmitUnaryOp(out, u);
- }
-
- diagnostics_.add_error(
- diag::System::Writer,
- "unknown expression type: " + std::string(expr->TypeInfo().name));
- return false;
+ return Switch(  // forward to the emitter for the expression's node type
+ expr,
+ [&](const ast::IndexAccessorExpression* a) { //
+ return EmitIndexAccessor(out, a);
+ },
+ [&](const ast::BinaryExpression* b) { //
+ return EmitBinary(out, b);
+ },
+ [&](const ast::BitcastExpression* b) { //
+ return EmitBitcast(out, b);
+ },
+ [&](const ast::CallExpression* c) { //
+ return EmitCall(out, c);
+ },
+ [&](const ast::IdentifierExpression* i) { //
+ return EmitIdentifier(out, i);
+ },
+ [&](const ast::LiteralExpression* l) { //
+ return EmitLiteral(out, l);
+ },
+ [&](const ast::MemberAccessorExpression* m) { //
+ return EmitMemberAccessor(out, m);
+ },
+ [&](const ast::UnaryOpExpression* u) { //
+ return EmitUnaryOp(out, u);
+ },
+ [&](Default) {  // unrecognized expression node
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "unknown expression type: " + std::string(expr->TypeInfo().name));
+ return false;
+ });
}
void GeneratorImpl::EmitStage(std::ostream& out, ast::PipelineStage stage) {
@@ -2106,57 +2130,60 @@
}
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
- if (auto* a = stmt->As<ast::AssignmentStatement>()) {
- return EmitAssign(a);
- }
- if (auto* b = stmt->As<ast::BlockStatement>()) {
- return EmitBlock(b);
- }
- if (auto* b = stmt->As<ast::BreakStatement>()) {
- return EmitBreak(b);
- }
- if (auto* c = stmt->As<ast::CallStatement>()) {
- auto out = line();
- if (!EmitCall(out, c->expr)) {
- return false;
- }
- out << ";";
- return true;
- }
- if (auto* c = stmt->As<ast::ContinueStatement>()) {
- return EmitContinue(c);
- }
- if (auto* d = stmt->As<ast::DiscardStatement>()) {
- return EmitDiscard(d);
- }
- if (stmt->As<ast::FallthroughStatement>()) {
- line() << "/* fallthrough */";
- return true;
- }
- if (auto* i = stmt->As<ast::IfStatement>()) {
- return EmitIf(i);
- }
- if (auto* l = stmt->As<ast::LoopStatement>()) {
- return EmitLoop(l);
- }
- if (auto* l = stmt->As<ast::ForLoopStatement>()) {
- return EmitForLoop(l);
- }
- if (auto* r = stmt->As<ast::ReturnStatement>()) {
- return EmitReturn(r);
- }
- if (auto* s = stmt->As<ast::SwitchStatement>()) {
- return EmitSwitch(s);
- }
- if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
- auto* var = program_->Sem().Get(v->variable);
- return EmitVariable(var);
- }
-
- diagnostics_.add_error(
- diag::System::Writer,
- "unknown statement type: " + std::string(stmt->TypeInfo().name));
- return false;
+ return Switch(  // forward to the emitter for the statement's node type
+ stmt,
+ [&](const ast::AssignmentStatement* a) { //
+ return EmitAssign(a);
+ },
+ [&](const ast::BlockStatement* b) { //
+ return EmitBlock(b);
+ },
+ [&](const ast::BreakStatement* b) { //
+ return EmitBreak(b);
+ },
+ [&](const ast::CallStatement* c) { //
+ auto out = line();
+ if (!EmitCall(out, c->expr)) {  // emit the call, then terminate with ';'
+ return false;
+ }
+ out << ";";
+ return true;
+ },
+ [&](const ast::ContinueStatement* c) { //
+ return EmitContinue(c);
+ },
+ [&](const ast::DiscardStatement* d) { //
+ return EmitDiscard(d);
+ },
+ [&](const ast::FallthroughStatement*) { //
+ line() << "/* fallthrough */";
+ return true;
+ },
+ [&](const ast::IfStatement* i) { //
+ return EmitIf(i);
+ },
+ [&](const ast::LoopStatement* l) { //
+ return EmitLoop(l);
+ },
+ [&](const ast::ForLoopStatement* l) { //
+ return EmitForLoop(l);
+ },
+ [&](const ast::ReturnStatement* r) { //
+ return EmitReturn(r);
+ },
+ [&](const ast::SwitchStatement* s) { //
+ return EmitSwitch(s);
+ },
+ [&](const ast::VariableDeclStatement* v) { //
+ auto* var = program_->Sem().Get(v->variable);
+ return EmitVariable(var);
+ },
+ [&](Default) {  // unrecognized statement node
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "unknown statement type: " + std::string(stmt->TypeInfo().name));
+ return false;
+ });
}
bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
@@ -2204,203 +2231,210 @@
if (name_printed) {
*name_printed = false;
}
- if (auto* atomic = type->As<sem::Atomic>()) {
- if (atomic->Type()->Is<sem::I32>()) {
- out << "atomic_int";
- return true;
- }
- if (atomic->Type()->Is<sem::U32>()) {
- out << "atomic_uint";
- return true;
- }
- TINT_ICE(Writer, diagnostics_)
- << "unhandled atomic type " << atomic->Type()->type_name();
- return false;
- }
- if (auto* ary = type->As<sem::Array>()) {
- const sem::Type* base_type = ary;
- std::vector<uint32_t> sizes;
- while (auto* arr = base_type->As<sem::Array>()) {
- if (arr->IsRuntimeSized()) {
- sizes.push_back(1);
- } else {
- sizes.push_back(arr->Count());
- }
- base_type = arr->ElemType();
- }
- if (!EmitType(out, base_type, "")) {
- return false;
- }
- if (!name.empty()) {
- out << " " << name;
- if (name_printed) {
- *name_printed = true;
- }
- }
- for (uint32_t size : sizes) {
- out << "[" << size << "]";
- }
- return true;
- }
-
- if (type->Is<sem::Bool>()) {
- out << "bool";
- return true;
- }
-
- if (type->Is<sem::F32>()) {
- out << "float";
- return true;
- }
-
- if (type->Is<sem::I32>()) {
- out << "int";
- return true;
- }
-
- if (auto* mat = type->As<sem::Matrix>()) {
- if (!EmitType(out, mat->type(), "")) {
- return false;
- }
- out << mat->columns() << "x" << mat->rows();
- return true;
- }
-
- if (auto* ptr = type->As<sem::Pointer>()) {
- if (ptr->Access() == ast::Access::kRead) {
- out << "const ";
- }
- if (!EmitStorageClass(out, ptr->StorageClass())) {
- return false;
- }
- out << " ";
- if (ptr->StoreType()->Is<sem::Array>()) {
- std::string inner = "(*" + name + ")";
- if (!EmitType(out, ptr->StoreType(), inner)) {
+ return Switch(
+ type,
+ [&](const sem::Atomic* atomic) {
+ if (atomic->Type()->Is<sem::I32>()) {
+ out << "atomic_int";
+ return true;
+ }
+ if (atomic->Type()->Is<sem::U32>()) {
+ out << "atomic_uint";
+ return true;
+ }
+ TINT_ICE(Writer, diagnostics_)
+ << "unhandled atomic type " << atomic->Type()->type_name();
return false;
- }
- if (name_printed) {
- *name_printed = true;
- }
- } else {
- if (!EmitType(out, ptr->StoreType(), "")) {
+ },
+ [&](const sem::Array* ary) {
+ const sem::Type* base_type = ary;
+ std::vector<uint32_t> sizes;
+ while (auto* arr = base_type->As<sem::Array>()) {
+ if (arr->IsRuntimeSized()) {
+ sizes.push_back(1);
+ } else {
+ sizes.push_back(arr->Count());
+ }
+ base_type = arr->ElemType();
+ }
+ if (!EmitType(out, base_type, "")) {
+ return false;
+ }
+ if (!name.empty()) {
+ out << " " << name;
+ if (name_printed) {
+ *name_printed = true;
+ }
+ }
+ for (uint32_t size : sizes) {
+ out << "[" << size << "]";
+ }
+ return true;
+ },
+ [&](const sem::Bool*) {
+ out << "bool";
+ return true;
+ },
+ [&](const sem::F32*) {
+ out << "float";
+ return true;
+ },
+ [&](const sem::I32*) {
+ out << "int";
+ return true;
+ },
+ [&](const sem::Matrix* mat) {
+ if (!EmitType(out, mat->type(), "")) {
+ return false;
+ }
+ out << mat->columns() << "x" << mat->rows();
+ return true;
+ },
+ [&](const sem::Pointer* ptr) {
+ if (ptr->Access() == ast::Access::kRead) {
+ out << "const ";
+ }
+ if (!EmitStorageClass(out, ptr->StorageClass())) {
+ return false;
+ }
+ out << " ";
+ if (ptr->StoreType()->Is<sem::Array>()) {
+ std::string inner = "(*" + name + ")";
+ if (!EmitType(out, ptr->StoreType(), inner)) {
+ return false;
+ }
+ if (name_printed) {
+ *name_printed = true;
+ }
+ } else {
+ if (!EmitType(out, ptr->StoreType(), "")) {
+ return false;
+ }
+ out << "* " << name;
+ if (name_printed) {
+ *name_printed = true;
+ }
+ }
+ return true;
+ },
+ [&](const sem::Sampler*) {
+ out << "sampler";
+ return true;
+ },
+ [&](const sem::Struct* str) {
+ // The struct type emits as just the name. The declaration would be
+ // emitted as part of emitting the declared types.
+ out << StructName(str);
+ return true;
+ },
+ [&](const sem::Texture* tex) {
+ if (tex->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
+ out << "depth";
+ } else {
+ out << "texture";
+ }
+
+ switch (tex->dim()) {
+ case ast::TextureDimension::k1d:
+ out << "1d";
+ break;
+ case ast::TextureDimension::k2d:
+ out << "2d";
+ break;
+ case ast::TextureDimension::k2dArray:
+ out << "2d_array";
+ break;
+ case ast::TextureDimension::k3d:
+ out << "3d";
+ break;
+ case ast::TextureDimension::kCube:
+ out << "cube";
+ break;
+ case ast::TextureDimension::kCubeArray:
+ out << "cube_array";
+ break;
+ default:
+ diagnostics_.add_error(diag::System::Writer,
+ "Invalid texture dimensions");
+ return false;
+ }
+ if (tex->IsAnyOf<sem::MultisampledTexture,
+ sem::DepthMultisampledTexture>()) {
+ out << "_ms";
+ }
+ out << "<";
+ TINT_DEFER(out << ">");
+
+ return Switch(
+ tex,
+ [&](const sem::DepthTexture*) {
+ out << "float, access::sample";
+ return true;
+ },
+ [&](const sem::DepthMultisampledTexture*) {
+ out << "float, access::read";
+ return true;
+ },
+ [&](const sem::StorageTexture* storage) {
+ if (!EmitType(out, storage->type(), "")) {
+ return false;
+ }
+
+ // Append the access qualifier; only read or write is accepted.
+ if (storage->access() == ast::Access::kRead) {
+ out << ", access::read";
+ } else if (storage->access() == ast::Access::kWrite) {
+ out << ", access::write";
+ } else {
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "Invalid access control for storage texture");
+ return false;
+ }
+ return true;
+ },
+ [&](const sem::MultisampledTexture* ms) {
+ if (!EmitType(out, ms->type(), "")) {
+ return false;
+ }
+ out << ", access::read";
+ return true;
+ },
+ [&](const sem::SampledTexture* sampled) {
+ if (!EmitType(out, sampled->type(), "")) {
+ return false;
+ }
+ out << ", access::sample";
+ return true;
+ },
+ [&](Default) {
+ diagnostics_.add_error(diag::System::Writer,
+ "invalid texture type");
+ return false;
+ });
+ },
+ [&](const sem::U32*) {
+ out << "uint";
+ return true;
+ },
+ [&](const sem::Vector* vec) {
+ if (!EmitType(out, vec->type(), "")) {
+ return false;
+ }
+ out << vec->Width();
+ return true;
+ },
+ [&](const sem::Void*) {
+ out << "void";
+ return true;
+ },
+ [&](Default) {
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "unknown type in EmitType: " + type->type_name());
return false;
- }
- out << "* " << name;
- if (name_printed) {
- *name_printed = true;
- }
- }
- return true;
- }
-
- if (type->Is<sem::Sampler>()) {
- out << "sampler";
- return true;
- }
-
- if (auto* str = type->As<sem::Struct>()) {
- // The struct type emits as just the name. The declaration would be emitted
- // as part of emitting the declared types.
- out << StructName(str);
- return true;
- }
-
- if (auto* tex = type->As<sem::Texture>()) {
- if (tex->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
- out << "depth";
- } else {
- out << "texture";
- }
-
- switch (tex->dim()) {
- case ast::TextureDimension::k1d:
- out << "1d";
- break;
- case ast::TextureDimension::k2d:
- out << "2d";
- break;
- case ast::TextureDimension::k2dArray:
- out << "2d_array";
- break;
- case ast::TextureDimension::k3d:
- out << "3d";
- break;
- case ast::TextureDimension::kCube:
- out << "cube";
- break;
- case ast::TextureDimension::kCubeArray:
- out << "cube_array";
- break;
- default:
- diagnostics_.add_error(diag::System::Writer,
- "Invalid texture dimensions");
- return false;
- }
- if (tex->IsAnyOf<sem::MultisampledTexture,
- sem::DepthMultisampledTexture>()) {
- out << "_ms";
- }
- out << "<";
- if (tex->Is<sem::DepthTexture>()) {
- out << "float, access::sample";
- } else if (tex->Is<sem::DepthMultisampledTexture>()) {
- out << "float, access::read";
- } else if (auto* storage = tex->As<sem::StorageTexture>()) {
- if (!EmitType(out, storage->type(), "")) {
- return false;
- }
-
- std::string access_str;
- if (storage->access() == ast::Access::kRead) {
- out << ", access::read";
- } else if (storage->access() == ast::Access::kWrite) {
- out << ", access::write";
- } else {
- diagnostics_.add_error(diag::System::Writer,
- "Invalid access control for storage texture");
- return false;
- }
- } else if (auto* ms = tex->As<sem::MultisampledTexture>()) {
- if (!EmitType(out, ms->type(), "")) {
- return false;
- }
- out << ", access::read";
- } else if (auto* sampled = tex->As<sem::SampledTexture>()) {
- if (!EmitType(out, sampled->type(), "")) {
- return false;
- }
- out << ", access::sample";
- } else {
- diagnostics_.add_error(diag::System::Writer, "invalid texture type");
- return false;
- }
- out << ">";
- return true;
- }
-
- if (type->Is<sem::U32>()) {
- out << "uint";
- return true;
- }
-
- if (auto* vec = type->As<sem::Vector>()) {
- if (!EmitType(out, vec->type(), "")) {
- return false;
- }
- out << vec->Width();
- return true;
- }
-
- if (type->Is<sem::Void>()) {
- out << "void";
- return true;
- }
-
- diagnostics_.add_error(diag::System::Writer,
- "unknown type in EmitType: " + type->type_name());
- return false;
+ });
}
bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
@@ -2542,55 +2576,72 @@
// Emit attributes
if (auto* decl = mem->Declaration()) {
for (auto* attr : decl->attributes) {
- if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
- auto name = builtin_to_attribute(builtin->builtin);
- if (name.empty()) {
- diagnostics_.add_error(diag::System::Writer, "unknown builtin");
- return false;
- }
- out << " [[" << name << "]]";
- } else if (auto* loc = attr->As<ast::LocationAttribute>()) {
- auto& pipeline_stage_uses = str->PipelineStageUses();
- if (pipeline_stage_uses.size() != 1) {
- TINT_ICE(Writer, diagnostics_)
- << "invalid entry point IO struct uses";
- }
+ bool ok = Switch(
+ attr,
+ [&](const ast::BuiltinAttribute* builtin) {
+ auto name = builtin_to_attribute(builtin->builtin);
+ if (name.empty()) {
+ diagnostics_.add_error(diag::System::Writer, "unknown builtin");
+ return false;
+ }
+ out << " [[" << name << "]]";
+ return true;
+ },
+ [&](const ast::LocationAttribute* loc) {
+ auto& pipeline_stage_uses = str->PipelineStageUses();
+ if (pipeline_stage_uses.size() != 1) {
+ TINT_ICE(Writer, diagnostics_)
+ << "invalid entry point IO struct uses";
+ return false;
+ }
- if (pipeline_stage_uses.count(
- sem::PipelineStageUsage::kVertexInput)) {
- out << " [[attribute(" + std::to_string(loc->value) + ")]]";
- } else if (pipeline_stage_uses.count(
- sem::PipelineStageUsage::kVertexOutput)) {
- out << " [[user(locn" + std::to_string(loc->value) + ")]]";
- } else if (pipeline_stage_uses.count(
- sem::PipelineStageUsage::kFragmentInput)) {
- out << " [[user(locn" + std::to_string(loc->value) + ")]]";
- } else if (pipeline_stage_uses.count(
- sem::PipelineStageUsage::kFragmentOutput)) {
- out << " [[color(" + std::to_string(loc->value) + ")]]";
- } else {
- TINT_ICE(Writer, diagnostics_)
- << "invalid use of location attribute";
- }
- } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
- auto name = interpolation_to_attribute(interpolate->type,
- interpolate->sampling);
- if (name.empty()) {
- diagnostics_.add_error(diag::System::Writer,
- "unknown interpolation attribute");
- return false;
- }
- out << " [[" << name << "]]";
- } else if (attr->Is<ast::InvariantAttribute>()) {
- if (invariant_define_name_.empty()) {
- invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
- }
- out << " " << invariant_define_name_;
- } else if (!attr->IsAnyOf<ast::StructMemberOffsetAttribute,
- ast::StructMemberAlignAttribute,
- ast::StructMemberSizeAttribute>()) {
- TINT_ICE(Writer, diagnostics_)
- << "unhandled struct member attribute: " << attr->Name();
+ if (pipeline_stage_uses.count(
+ sem::PipelineStageUsage::kVertexInput)) {
+ out << " [[attribute(" + std::to_string(loc->value) + ")]]";
+ } else if (pipeline_stage_uses.count(
+ sem::PipelineStageUsage::kVertexOutput)) {
+ out << " [[user(locn" + std::to_string(loc->value) + ")]]";
+ } else if (pipeline_stage_uses.count(
+ sem::PipelineStageUsage::kFragmentInput)) {
+ out << " [[user(locn" + std::to_string(loc->value) + ")]]";
+ } else if (pipeline_stage_uses.count(
+ sem::PipelineStageUsage::kFragmentOutput)) {
+ out << " [[color(" + std::to_string(loc->value) + ")]]";
+ } else {
+ TINT_ICE(Writer, diagnostics_)
+ << "invalid use of location attribute";
+ return false;
+ }
+ return true;
+ },
+ [&](const ast::InterpolateAttribute* interpolate) {
+ auto name = interpolation_to_attribute(interpolate->type,
+ interpolate->sampling);
+ if (name.empty()) {
+ diagnostics_.add_error(diag::System::Writer,
+ "unknown interpolation attribute");
+ return false;
+ }
+ out << " [[" << name << "]]";
+ return true;
+ },
+ [&](const ast::InvariantAttribute*) {
+ if (invariant_define_name_.empty()) {
+ invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
+ }
+ out << " " << invariant_define_name_;
+ return true;
+ },
+ [&](const ast::StructMemberOffsetAttribute*) { return true; },  // no output
+ [&](const ast::StructMemberAlignAttribute*) { return true; },  // no output
+ [&](const ast::StructMemberSizeAttribute*) { return true; },  // no output
+ [&](Default) {
+ TINT_ICE(Writer, diagnostics_)
+ << "unhandled struct member attribute: " << attr->Name();
+ return false;
+ });
+ if (!ok) {
+ return false;
}
}
}
@@ -2796,77 +2847,96 @@
GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
const sem::Type* ty) {
- if (ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // 2.1 Scalar Data Types
- return {4, 4};
- }
+ return Switch(  // MSL packed size/alignment for `ty`, chosen by its sem type
+ ty,
- if (auto* vec = ty->As<sem::Vector>()) {
- auto num_els = vec->Width();
- auto* el_ty = vec->type();
- if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
- // Use a packed_vec type for 3-element vectors only.
- if (num_els == 3) {
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // 2.1 Scalar Data Types
+ [&](const sem::U32*) {
+ return SizeAndAlign{4, 4};
+ },
+ [&](const sem::I32*) {
+ return SizeAndAlign{4, 4};
+ },
+ [&](const sem::F32*) {
+ return SizeAndAlign{4, 4};
+ },
+
+ [&](const sem::Vector* vec) {
+ auto num_els = vec->Width();
+ auto* el_ty = vec->type();
+ if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
+ // Use a packed_vec type for 3-element vectors only.
+ if (num_els == 3) {
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // 2.2.3 Packed Vector Types
+ return SizeAndAlign{num_els * 4, 4};
+ } else {
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // 2.2 Vector Data Types
+ return SizeAndAlign{num_els * 4, num_els * 4};
+ }
+ }
+ TINT_UNREACHABLE(Writer, diagnostics_)
+ << "Unhandled vector element type " << el_ty->TypeInfo().name;
+ return SizeAndAlign{};
+ },
+
+ [&](const sem::Matrix* mat) {
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // 2.2.3 Packed Vector Types
- return SizeAndAlign{num_els * 4, 4};
- } else {
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // 2.2 Vector Data Types
- return SizeAndAlign{num_els * 4, num_els * 4};
- }
- }
- }
+ // 2.3 Matrix Data Types
+ auto cols = mat->columns();
+ auto rows = mat->rows();
+ auto* el_ty = mat->type();
+ if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
+ static constexpr SizeAndAlign table[] = {
+ /* float2x2 */ {16, 8},
+ /* float2x3 */ {32, 16},
+ /* float2x4 */ {32, 16},
+ /* float3x2 */ {24, 8},
+ /* float3x3 */ {48, 16},
+ /* float3x4 */ {48, 16},
+ /* float4x2 */ {32, 8},
+ /* float4x3 */ {64, 16},
+ /* float4x4 */ {64, 16},
+ };
+ if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
+ return table[(3 * (cols - 2)) + (rows - 2)];
+ }
+ }
- if (auto* mat = ty->As<sem::Matrix>()) {
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // 2.3 Matrix Data Types
- auto cols = mat->columns();
- auto rows = mat->rows();
- auto* el_ty = mat->type();
- if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
- static constexpr SizeAndAlign table[] = {
- /* float2x2 */ {16, 8},
- /* float2x3 */ {32, 16},
- /* float2x4 */ {32, 16},
- /* float3x2 */ {24, 8},
- /* float3x3 */ {48, 16},
- /* float3x4 */ {48, 16},
- /* float4x2 */ {32, 8},
- /* float4x3 */ {64, 16},
- /* float4x4 */ {64, 16},
- };
- if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
- return table[(3 * (cols - 2)) + (rows - 2)];
- }
- }
- }
+ TINT_UNREACHABLE(Writer, diagnostics_)
+ << "Unhandled matrix element type " << el_ty->TypeInfo().name;
+ return SizeAndAlign{};
+ },
- if (auto* arr = ty->As<sem::Array>()) {
- if (!arr->IsStrideImplicit()) {
- TINT_ICE(Writer, diagnostics_)
- << "arrays with explicit strides should have "
- "removed with the PadArrayElements transform";
- return {};
- }
- auto num_els = std::max<uint32_t>(arr->Count(), 1);
- return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
- }
+ [&](const sem::Array* arr) {
+ if (!arr->IsStrideImplicit()) {
+ TINT_ICE(Writer, diagnostics_)
+ << "arrays with explicit strides should have "
+ "removed with the PadArrayElements transform";
+ return SizeAndAlign{};
+ }
+ auto num_els = std::max<uint32_t>(arr->Count(), 1);
+ return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
+ },
- if (auto* str = ty->As<sem::Struct>()) {
- // TODO(crbug.com/tint/650): There's an assumption here that MSL's default
- // structure size and alignment matches WGSL's. We need to confirm this.
- return SizeAndAlign{str->Size(), str->Align()};
- }
+ [&](const sem::Struct* str) {
+ // TODO(crbug.com/tint/650): There's an assumption here that MSL's
+ // default structure size and alignment matches WGSL's. We need to
+ // confirm this.
+ return SizeAndAlign{str->Size(), str->Align()};
+ },
- if (auto* atomic = ty->As<sem::Atomic>()) {
- return MslPackedTypeSizeAndAlign(atomic->Type());
- }
+ [&](const sem::Atomic* atomic) {
+ return MslPackedTypeSizeAndAlign(atomic->Type());
+ },
- TINT_UNREACHABLE(Writer, diagnostics_)
- << "Unhandled type " << ty->TypeInfo().name;
- return {};
+ [&](Default) {  // any type not handled by the cases above
+ TINT_UNREACHABLE(Writer, diagnostics_)
+ << "Unhandled type " << ty->TypeInfo().name;
+ return SizeAndAlign{};
+ });
}
template <typename F>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 19707c0..932ae8c 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -560,33 +560,37 @@
}
uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
- if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
- return GenerateAccessorExpression(a);
- }
- if (auto* b = expr->As<ast::BinaryExpression>()) {
- return GenerateBinaryExpression(b);
- }
- if (auto* b = expr->As<ast::BitcastExpression>()) {
- return GenerateBitcastExpression(b);
- }
- if (auto* c = expr->As<ast::CallExpression>()) {
- return GenerateCallExpression(c);
- }
- if (auto* i = expr->As<ast::IdentifierExpression>()) {
- return GenerateIdentifierExpression(i);
- }
- if (auto* l = expr->As<ast::LiteralExpression>()) {
- return GenerateLiteralIfNeeded(nullptr, l);
- }
- if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
- return GenerateAccessorExpression(m);
- }
- if (auto* u = expr->As<ast::UnaryOpExpression>()) {
- return GenerateUnaryOpExpression(u);
- }
-
- error_ = "unknown expression type: " + std::string(expr->TypeInfo().name);
- return 0;
+ return Switch(  // forward to the generator for the expression's node type
+ expr,
+ [&](const ast::IndexAccessorExpression* a) { //
+ return GenerateAccessorExpression(a);
+ },
+ [&](const ast::BinaryExpression* b) { //
+ return GenerateBinaryExpression(b);
+ },
+ [&](const ast::BitcastExpression* b) { //
+ return GenerateBitcastExpression(b);
+ },
+ [&](const ast::CallExpression* c) { //
+ return GenerateCallExpression(c);
+ },
+ [&](const ast::IdentifierExpression* i) { //
+ return GenerateIdentifierExpression(i);
+ },
+ [&](const ast::LiteralExpression* l) { //
+ return GenerateLiteralIfNeeded(nullptr, l);
+ },
+ [&](const ast::MemberAccessorExpression* m) { //
+ return GenerateAccessorExpression(m);
+ },
+ [&](const ast::UnaryOpExpression* u) { //
+ return GenerateUnaryOpExpression(u);
+ },
+ [&](Default) -> uint32_t {  // explicit: a bare `return 0` would deduce int
+ error_ =
+ "unknown expression type: " + std::string(expr->TypeInfo().name);
+ return 0;
+ });
}
bool Builder::GenerateFunction(const ast::Function* func_ast) {
@@ -861,33 +865,56 @@
push_type(spv::Op::OpVariable, std::move(ops));
for (auto* attr : var->attributes) {
- if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
- push_annot(spv::Op::OpDecorate,
- {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
- Operand::Int(
- ConvertBuiltin(builtin->builtin, sem->StorageClass()))});
- } else if (auto* location = attr->As<ast::LocationAttribute>()) {
- push_annot(spv::Op::OpDecorate,
- {Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
- Operand::Int(location->value)});
- } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
- AddInterpolationDecorations(var_id, interpolate->type,
- interpolate->sampling);
- } else if (attr->Is<ast::InvariantAttribute>()) {
- push_annot(spv::Op::OpDecorate,
- {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
- } else if (auto* binding = attr->As<ast::BindingAttribute>()) {
- push_annot(spv::Op::OpDecorate,
- {Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
- Operand::Int(binding->value)});
- } else if (auto* group = attr->As<ast::GroupAttribute>()) {
- push_annot(spv::Op::OpDecorate, {Operand::Int(var_id),
- Operand::Int(SpvDecorationDescriptorSet),
- Operand::Int(group->value)});
- } else if (attr->Is<ast::OverrideAttribute>()) {
- // Spec constants are handled elsewhere
- } else if (!attr->Is<ast::InternalAttribute>()) {
- error_ = "unknown attribute";
+ bool ok = Switch(  // false only for an unrecognized attribute (sets error_)
+ attr,
+ [&](const ast::BuiltinAttribute* builtin) {
+ push_annot(spv::Op::OpDecorate,
+ {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
+ Operand::Int(ConvertBuiltin(builtin->builtin,
+ sem->StorageClass()))});
+ return true;
+ },
+ [&](const ast::LocationAttribute* location) {
+ push_annot(spv::Op::OpDecorate,
+ {Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
+ Operand::Int(location->value)});
+ return true;
+ },
+ [&](const ast::InterpolateAttribute* interpolate) {
+ AddInterpolationDecorations(var_id, interpolate->type,
+ interpolate->sampling);
+ return true;
+ },
+ [&](const ast::InvariantAttribute*) {
+ push_annot(
+ spv::Op::OpDecorate,
+ {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
+ return true;
+ },
+ [&](const ast::BindingAttribute* binding) {
+ push_annot(spv::Op::OpDecorate,
+ {Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
+ Operand::Int(binding->value)});
+ return true;
+ },
+ [&](const ast::GroupAttribute* group) {
+ push_annot(
+ spv::Op::OpDecorate,
+ {Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet),
+ Operand::Int(group->value)});
+ return true;
+ },
+ [&](const ast::OverrideAttribute*) {
+ return true; // Spec constants are handled elsewhere
+ },
+ [&](const ast::InternalAttribute*) {
+ return true; // ignored
+ },
+ [&](Default) {  // unrecognized attribute kind
+ error_ = "unknown attribute";
+ return false;
+ });
+ if (!ok) {
return false;
}
}
@@ -1123,19 +1150,21 @@
// promoted to storage with the VarForDynamicIndex transform.
for (auto* accessor : accessors) {
- if (auto* array = accessor->As<ast::IndexAccessorExpression>()) {
- if (!GenerateIndexAccessor(array, &info)) {
- return 0;
- }
- } else if (auto* member = accessor->As<ast::MemberAccessorExpression>()) {
- if (!GenerateMemberAccessor(member, &info)) {
- return 0;
- }
-
- } else {
- error_ =
- "invalid accessor in list: " + std::string(accessor->TypeInfo().name);
- return 0;
+ bool ok = Switch(
+ accessor,
+ [&](const ast::IndexAccessorExpression* array) {
+ return GenerateIndexAccessor(array, &info);
+ },
+ [&](const ast::MemberAccessorExpression* member) {
+ return GenerateMemberAccessor(member, &info);
+ },
+ [&](Default) {
+ error_ = "invalid accessor in list: " +
+ std::string(accessor->TypeInfo().name);
+ return false;
+ });
+ if (!ok) {
+ return false;
}
}
@@ -1653,21 +1682,28 @@
constant.constant_id = global->ConstantId();
}
- if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
- constant.kind = ScalarConstant::Kind::kBool;
- constant.value.b = l->value;
- } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
- constant.kind = ScalarConstant::Kind::kI32;
- constant.value.i32 = sl->value;
- } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
- constant.kind = ScalarConstant::Kind::kU32;
- constant.value.u32 = ul->value;
- } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
- constant.kind = ScalarConstant::Kind::kF32;
- constant.value.f32 = fl->value;
- } else {
- error_ = "unknown literal type";
- return 0;
+ Switch(
+ lit,
+ [&](const ast::BoolLiteralExpression* l) {
+ constant.kind = ScalarConstant::Kind::kBool;
+ constant.value.b = l->value;
+ },
+ [&](const ast::SintLiteralExpression* sl) {
+ constant.kind = ScalarConstant::Kind::kI32;
+ constant.value.i32 = sl->value;
+ },
+ [&](const ast::UintLiteralExpression* ul) {
+ constant.kind = ScalarConstant::Kind::kU32;
+ constant.value.u32 = ul->value;
+ },
+ [&](const ast::FloatLiteralExpression* fl) {
+ constant.kind = ScalarConstant::Kind::kF32;
+ constant.value.f32 = fl->value;
+ },
+ [&](Default) { error_ = "unknown literal type"; });
+
+ if (!error_.empty()) {
+ return false;
}
return GenerateConstantIfNeeded(constant);
@@ -2209,19 +2245,25 @@
uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get(expr);
auto* target = call->Target();
-
- if (auto* func = target->As<sem::Function>()) {
- return GenerateFunctionCall(call, func);
- }
- if (auto* builtin = target->As<sem::Builtin>()) {
- return GenerateBuiltinCall(call, builtin);
- }
- if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
- return GenerateTypeConstructorOrConversion(call, nullptr);
- }
- TINT_ICE(Writer, builder_.Diagnostics())
- << "unhandled call target: " << target->TypeInfo().name;
- return false;
+ return Switch(
+ target,
+ [&](const sem::Function* func) {
+ return GenerateFunctionCall(call, func);
+ },
+ [&](const sem::Builtin* builtin) {
+ return GenerateBuiltinCall(call, builtin);
+ },
+ [&](const sem::TypeConversion*) {
+ return GenerateTypeConstructorOrConversion(call, nullptr);
+ },
+ [&](const sem::TypeConstructor*) {
+ return GenerateTypeConstructorOrConversion(call, nullptr);
+ },
+ [&](Default) -> uint32_t {
+ TINT_ICE(Writer, builder_.Diagnostics())
+ << "unhandled call target: " << target->TypeInfo().name;
+ return 0;
+ });
}
uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
@@ -3790,46 +3832,49 @@
}
bool Builder::GenerateStatement(const ast::Statement* stmt) {
- if (auto* a = stmt->As<ast::AssignmentStatement>()) {
- return GenerateAssignStatement(a);
- }
- if (auto* b = stmt->As<ast::BlockStatement>()) {
- return GenerateBlockStatement(b);
- }
- if (auto* b = stmt->As<ast::BreakStatement>()) {
- return GenerateBreakStatement(b);
- }
- if (auto* c = stmt->As<ast::CallStatement>()) {
- return GenerateCallExpression(c->expr) != 0;
- }
- if (auto* c = stmt->As<ast::ContinueStatement>()) {
- return GenerateContinueStatement(c);
- }
- if (auto* d = stmt->As<ast::DiscardStatement>()) {
- return GenerateDiscardStatement(d);
- }
- if (stmt->Is<ast::FallthroughStatement>()) {
- // Do nothing here, the fallthrough gets handled by the switch code.
- return true;
- }
- if (auto* i = stmt->As<ast::IfStatement>()) {
- return GenerateIfStatement(i);
- }
- if (auto* l = stmt->As<ast::LoopStatement>()) {
- return GenerateLoopStatement(l);
- }
- if (auto* r = stmt->As<ast::ReturnStatement>()) {
- return GenerateReturnStatement(r);
- }
- if (auto* s = stmt->As<ast::SwitchStatement>()) {
- return GenerateSwitchStatement(s);
- }
- if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
- return GenerateVariableDeclStatement(v);
- }
-
- error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
- return false;
+ return Switch(
+ stmt,
+ [&](const ast::AssignmentStatement* a) {
+ return GenerateAssignStatement(a);
+ },
+ [&](const ast::BlockStatement* b) { //
+ return GenerateBlockStatement(b);
+ },
+ [&](const ast::BreakStatement* b) { //
+ return GenerateBreakStatement(b);
+ },
+ [&](const ast::CallStatement* c) {
+ return GenerateCallExpression(c->expr) != 0;
+ },
+ [&](const ast::ContinueStatement* c) {
+ return GenerateContinueStatement(c);
+ },
+ [&](const ast::DiscardStatement* d) {
+ return GenerateDiscardStatement(d);
+ },
+ [&](const ast::FallthroughStatement*) {
+ // Do nothing here, the fallthrough gets handled by the switch code.
+ return true;
+ },
+ [&](const ast::IfStatement* i) { //
+ return GenerateIfStatement(i);
+ },
+ [&](const ast::LoopStatement* l) { //
+ return GenerateLoopStatement(l);
+ },
+ [&](const ast::ReturnStatement* r) { //
+ return GenerateReturnStatement(r);
+ },
+ [&](const ast::SwitchStatement* s) { //
+ return GenerateSwitchStatement(s);
+ },
+ [&](const ast::VariableDeclStatement* v) {
+ return GenerateVariableDeclStatement(v);
+ },
+ [&](Default) {
+ error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
+ return false;
+ });
}
bool Builder::GenerateVariableDeclStatement(
@@ -3872,78 +3917,91 @@
return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t {
auto result = result_op();
auto id = result.to_i();
- if (auto* arr = type->As<sem::Array>()) {
- if (!GenerateArrayType(arr, result)) {
- return 0;
- }
- } else if (type->Is<sem::Bool>()) {
- push_type(spv::Op::OpTypeBool, {result});
- } else if (type->Is<sem::F32>()) {
- push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
- } else if (type->Is<sem::I32>()) {
- push_type(spv::Op::OpTypeInt,
- {result, Operand::Int(32), Operand::Int(1)});
- } else if (auto* mat = type->As<sem::Matrix>()) {
- if (!GenerateMatrixType(mat, result)) {
- return 0;
- }
- } else if (auto* ptr = type->As<sem::Pointer>()) {
- if (!GeneratePointerType(ptr, result)) {
- return 0;
- }
- } else if (auto* ref = type->As<sem::Reference>()) {
- if (!GenerateReferenceType(ref, result)) {
- return 0;
- }
- } else if (auto* str = type->As<sem::Struct>()) {
- if (!GenerateStructType(str, result)) {
- return 0;
- }
- } else if (type->Is<sem::U32>()) {
- push_type(spv::Op::OpTypeInt,
- {result, Operand::Int(32), Operand::Int(0)});
- } else if (auto* vec = type->As<sem::Vector>()) {
- if (!GenerateVectorType(vec, result)) {
- return 0;
- }
- } else if (type->Is<sem::Void>()) {
- push_type(spv::Op::OpTypeVoid, {result});
- } else if (auto* tex = type->As<sem::Texture>()) {
- if (!GenerateTextureType(tex, result)) {
- return 0;
- }
+ bool ok = Switch(
+ type,
+ [&](const sem::Array* arr) { //
+ return GenerateArrayType(arr, result);
+ },
+ [&](const sem::Bool*) {
+ push_type(spv::Op::OpTypeBool, {result});
+ return true;
+ },
+ [&](const sem::F32*) {
+ push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
+ return true;
+ },
+ [&](const sem::I32*) {
+ push_type(spv::Op::OpTypeInt,
+ {result, Operand::Int(32), Operand::Int(1)});
+ return true;
+ },
+ [&](const sem::Matrix* mat) { //
+ return GenerateMatrixType(mat, result);
+ },
+ [&](const sem::Pointer* ptr) { //
+ return GeneratePointerType(ptr, result);
+ },
+ [&](const sem::Reference* ref) { //
+ return GenerateReferenceType(ref, result);
+ },
+ [&](const sem::Struct* str) { //
+ return GenerateStructType(str, result);
+ },
+ [&](const sem::U32*) {
+ push_type(spv::Op::OpTypeInt,
+ {result, Operand::Int(32), Operand::Int(0)});
+ return true;
+ },
+ [&](const sem::Vector* vec) { //
+ return GenerateVectorType(vec, result);
+ },
+ [&](const sem::Void*) {
+ push_type(spv::Op::OpTypeVoid, {result});
+ return true;
+ },
+ [&](const sem::StorageTexture* tex) {
+ if (!GenerateTextureType(tex, result)) {
+ return false;
+ }
- if (auto* st = tex->As<sem::StorageTexture>()) {
- // Register all three access types of StorageTexture names. In SPIR-V,
- // we must output a single type, while the variable is annotated with
- // the access type. Doing this ensures we de-dupe.
- type_name_to_id_[builder_
- .create<sem::StorageTexture>(
- st->dim(), st->texel_format(),
- ast::Access::kRead, st->type())
- ->type_name()] = id;
- type_name_to_id_[builder_
- .create<sem::StorageTexture>(
- st->dim(), st->texel_format(),
- ast::Access::kWrite, st->type())
- ->type_name()] = id;
- type_name_to_id_[builder_
- .create<sem::StorageTexture>(
- st->dim(), st->texel_format(),
- ast::Access::kReadWrite, st->type())
- ->type_name()] = id;
- }
+ // Register all three access types of StorageTexture names. In
+ // SPIR-V, we must output a single type, while the variable is
+ // annotated with the access type. Doing this ensures we de-dupe.
+ type_name_to_id_[builder_
+ .create<sem::StorageTexture>(
+ tex->dim(), tex->texel_format(),
+ ast::Access::kRead, tex->type())
+ ->type_name()] = id;
+ type_name_to_id_[builder_
+ .create<sem::StorageTexture>(
+ tex->dim(), tex->texel_format(),
+ ast::Access::kWrite, tex->type())
+ ->type_name()] = id;
+ type_name_to_id_[builder_
+ .create<sem::StorageTexture>(
+ tex->dim(), tex->texel_format(),
+ ast::Access::kReadWrite, tex->type())
+ ->type_name()] = id;
+ return true;
+ },
+ [&](const sem::Texture* tex) {
+ return GenerateTextureType(tex, result);
+ },
+ [&](const sem::Sampler*) {
+ push_type(spv::Op::OpTypeSampler, {result});
- } else if (type->Is<sem::Sampler>()) {
- push_type(spv::Op::OpTypeSampler, {result});
+ // Register both of the sampler type names. In SPIR-V they're the same
+ // sampler type, so we need to match that when we do the dedup check.
+ type_name_to_id_["__sampler_sampler"] = id;
+ type_name_to_id_["__sampler_comparison"] = id;
+ return true;
+ },
+ [&](Default) {
+ error_ = "unable to convert type: " + type->type_name();
+ return false;
+ });
- // Register both of the sampler type names. In SPIR-V they're the same
- // sampler type, so we need to match that when we do the dedup check.
- type_name_to_id_["__sampler_sampler"] = id;
- type_name_to_id_["__sampler_comparison"] = id;
-
- } else {
- error_ = "unable to convert type: " + type->type_name();
+ if (!ok) {
return 0;
}
@@ -3995,22 +4053,31 @@
}
if (dim == ast::TextureDimension::kCubeArray) {
- if (texture->Is<sem::SampledTexture>() ||
- texture->Is<sem::DepthTexture>()) {
+ if (texture->IsAnyOf<sem::SampledTexture, sem::DepthTexture>()) {
push_capability(SpvCapabilitySampledCubeArray);
}
}
- uint32_t type_id = 0u;
- if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
- type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
- } else if (auto* s = texture->As<sem::SampledTexture>()) {
- type_id = GenerateTypeIfNeeded(s->type());
- } else if (auto* ms = texture->As<sem::MultisampledTexture>()) {
- type_id = GenerateTypeIfNeeded(ms->type());
- } else if (auto* st = texture->As<sem::StorageTexture>()) {
- type_id = GenerateTypeIfNeeded(st->type());
- }
+ uint32_t type_id = Switch(
+ texture,
+ [&](const sem::DepthTexture*) {
+ return GenerateTypeIfNeeded(builder_.create<sem::F32>());
+ },
+ [&](const sem::DepthMultisampledTexture*) {
+ return GenerateTypeIfNeeded(builder_.create<sem::F32>());
+ },
+ [&](const sem::SampledTexture* t) {
+ return GenerateTypeIfNeeded(t->type());
+ },
+ [&](const sem::MultisampledTexture* t) {
+ return GenerateTypeIfNeeded(t->type());
+ },
+ [&](const sem::StorageTexture* t) {
+ return GenerateTypeIfNeeded(t->type());
+ },
+ [&](Default) -> uint32_t { //
+ return 0u;
+ });
if (type_id == 0u) {
return false;
}
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 094e889..96a459a 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -68,23 +68,17 @@
bool GeneratorImpl::Generate() {
// Generate global declarations in the order they appear in the module.
for (auto* decl : program_->AST().GlobalDeclarations()) {
- if (auto* td = decl->As<ast::TypeDecl>()) {
- if (!EmitTypeDecl(td)) {
- return false;
- }
- } else if (auto* func = decl->As<ast::Function>()) {
- if (!EmitFunction(func)) {
- return false;
- }
- } else if (auto* var = decl->As<ast::Variable>()) {
- if (!EmitVariable(line(), var)) {
- return false;
- }
- } else {
- TINT_UNREACHABLE(Writer, diagnostics_);
+ if (!Switch(
+ decl, //
+ [&](const ast::TypeDecl* td) { return EmitTypeDecl(td); },
+ [&](const ast::Function* func) { return EmitFunction(func); },
+ [&](const ast::Variable* var) { return EmitVariable(line(), var); },
+ [&](Default) {
+ TINT_UNREACHABLE(Writer, diagnostics_);
+ return false;
+ })) {
return false;
}
-
if (decl != program_->AST().GlobalDeclarations().back()) {
line();
}
@@ -94,59 +88,64 @@
}
bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) {
- if (auto* alias = ty->As<ast::Alias>()) {
- auto out = line();
- out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
- if (!EmitType(out, alias->type)) {
- return false;
- }
- out << ";";
- } else if (auto* str = ty->As<ast::Struct>()) {
- if (!EmitStructType(str)) {
- return false;
- }
- } else {
- diagnostics_.add_error(
- diag::System::Writer,
- "unknown declared type: " + std::string(ty->TypeInfo().name));
- return false;
- }
- return true;
+ return Switch(
+ ty,
+ [&](const ast::Alias* alias) { //
+ auto out = line();
+ out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
+ if (!EmitType(out, alias->type)) {
+ return false;
+ }
+ out << ";";
+ return true;
+ },
+ [&](const ast::Struct* str) { //
+ return EmitStructType(str);
+ },
+ [&](Default) { //
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "unknown declared type: " + std::string(ty->TypeInfo().name));
+ return false;
+ });
}
bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) {
- if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
- return EmitIndexAccessor(out, a);
- }
- if (auto* b = expr->As<ast::BinaryExpression>()) {
- return EmitBinary(out, b);
- }
- if (auto* b = expr->As<ast::BitcastExpression>()) {
- return EmitBitcast(out, b);
- }
- if (auto* c = expr->As<ast::CallExpression>()) {
- return EmitCall(out, c);
- }
- if (auto* i = expr->As<ast::IdentifierExpression>()) {
- return EmitIdentifier(out, i);
- }
- if (auto* l = expr->As<ast::LiteralExpression>()) {
- return EmitLiteral(out, l);
- }
- if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
- return EmitMemberAccessor(out, m);
- }
- if (expr->Is<ast::PhonyExpression>()) {
- out << "_";
- return true;
- }
- if (auto* u = expr->As<ast::UnaryOpExpression>()) {
- return EmitUnaryOp(out, u);
- }
-
- diagnostics_.add_error(diag::System::Writer, "unknown expression type");
- return false;
+ return Switch(
+ expr,
+ [&](const ast::IndexAccessorExpression* a) { //
+ return EmitIndexAccessor(out, a);
+ },
+ [&](const ast::BinaryExpression* b) { //
+ return EmitBinary(out, b);
+ },
+ [&](const ast::BitcastExpression* b) { //
+ return EmitBitcast(out, b);
+ },
+ [&](const ast::CallExpression* c) { //
+ return EmitCall(out, c);
+ },
+ [&](const ast::IdentifierExpression* i) { //
+ return EmitIdentifier(out, i);
+ },
+ [&](const ast::LiteralExpression* l) { //
+ return EmitLiteral(out, l);
+ },
+ [&](const ast::MemberAccessorExpression* m) { //
+ return EmitMemberAccessor(out, m);
+ },
+ [&](const ast::PhonyExpression*) { //
+ out << "_";
+ return true;
+ },
+ [&](const ast::UnaryOpExpression* u) { //
+ return EmitUnaryOp(out, u);
+ },
+ [&](Default) {
+ diagnostics_.add_error(diag::System::Writer, "unknown expression type");
+ return false;
+ });
}
bool GeneratorImpl::EmitIndexAccessor(
@@ -250,19 +249,28 @@
bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) {
- if (auto* bl = lit->As<ast::BoolLiteralExpression>()) {
- out << (bl->value ? "true" : "false");
- } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
- out << FloatToBitPreservingString(fl->value);
- } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
- out << sl->value;
- } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
- out << ul->value << "u";
- } else {
- diagnostics_.add_error(diag::System::Writer, "unknown literal type");
- return false;
- }
- return true;
+ return Switch(
+ lit,
+ [&](const ast::BoolLiteralExpression* bl) { //
+ out << (bl->value ? "true" : "false");
+ return true;
+ },
+ [&](const ast::FloatLiteralExpression* fl) { //
+ out << FloatToBitPreservingString(fl->value);
+ return true;
+ },
+ [&](const ast::SintLiteralExpression* sl) { //
+ out << sl->value;
+ return true;
+ },
+ [&](const ast::UintLiteralExpression* ul) { //
+ out << ul->value << "u";
+ return true;
+ },
+ [&](Default) { //
+ diagnostics_.add_error(diag::System::Writer, "unknown literal type");
+ return false;
+ });
}
bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@@ -366,155 +374,208 @@
}
bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
- if (auto* ary = ty->As<ast::Array>()) {
- for (auto* attr : ary->attributes) {
- if (auto* stride = attr->As<ast::StrideAttribute>()) {
- out << "@stride(" << stride->stride << ") ";
- }
- }
+ return Switch(
+ ty,
+ [&](const ast::Array* ary) {
+ for (auto* attr : ary->attributes) {
+ if (auto* stride = attr->As<ast::StrideAttribute>()) {
+ out << "@stride(" << stride->stride << ") ";
+ }
+ }
- out << "array<";
- if (!EmitType(out, ary->type)) {
- return false;
- }
+ out << "array<";
+ if (!EmitType(out, ary->type)) {
+ return false;
+ }
- if (!ary->IsRuntimeArray()) {
- out << ", ";
- if (!EmitExpression(out, ary->count)) {
- return false;
- }
- }
+ if (!ary->IsRuntimeArray()) {
+ out << ", ";
+ if (!EmitExpression(out, ary->count)) {
+ return false;
+ }
+ }
- out << ">";
- } else if (ty->Is<ast::Bool>()) {
- out << "bool";
- } else if (ty->Is<ast::F32>()) {
- out << "f32";
- } else if (ty->Is<ast::I32>()) {
- out << "i32";
- } else if (auto* mat = ty->As<ast::Matrix>()) {
- out << "mat" << mat->columns << "x" << mat->rows;
- if (auto* el_ty = mat->type) {
- out << "<";
- if (!EmitType(out, el_ty)) {
- return false;
- }
- out << ">";
- }
- } else if (auto* ptr = ty->As<ast::Pointer>()) {
- out << "ptr<" << ptr->storage_class << ", ";
- if (!EmitType(out, ptr->type)) {
- return false;
- }
- if (ptr->access != ast::Access::kUndefined) {
- out << ", ";
- if (!EmitAccess(out, ptr->access)) {
- return false;
- }
- }
- out << ">";
- } else if (auto* atomic = ty->As<ast::Atomic>()) {
- out << "atomic<";
- if (!EmitType(out, atomic->type)) {
- return false;
- }
- out << ">";
- } else if (auto* sampler = ty->As<ast::Sampler>()) {
- out << "sampler";
+ out << ">";
+ return true;
+ },
+ [&](const ast::Bool*) {
+ out << "bool";
+ return true;
+ },
+ [&](const ast::F32*) {
+ out << "f32";
+ return true;
+ },
+ [&](const ast::I32*) {
+ out << "i32";
+ return true;
+ },
+ [&](const ast::Matrix* mat) {
+ out << "mat" << mat->columns << "x" << mat->rows;
+ if (auto* el_ty = mat->type) {
+ out << "<";
+ if (!EmitType(out, el_ty)) {
+ return false;
+ }
+ out << ">";
+ }
+ return true;
+ },
+ [&](const ast::Pointer* ptr) {
+ out << "ptr<" << ptr->storage_class << ", ";
+ if (!EmitType(out, ptr->type)) {
+ return false;
+ }
+ if (ptr->access != ast::Access::kUndefined) {
+ out << ", ";
+ if (!EmitAccess(out, ptr->access)) {
+ return false;
+ }
+ }
+ out << ">";
+ return true;
+ },
+ [&](const ast::Atomic* atomic) {
+ out << "atomic<";
+ if (!EmitType(out, atomic->type)) {
+ return false;
+ }
+ out << ">";
+ return true;
+ },
+ [&](const ast::Sampler* sampler) {
+ out << "sampler";
- if (sampler->IsComparison()) {
- out << "_comparison";
- }
- } else if (ty->Is<ast::ExternalTexture>()) {
- out << "texture_external";
- } else if (auto* texture = ty->As<ast::Texture>()) {
- out << "texture_";
- if (texture->Is<ast::DepthTexture>()) {
- out << "depth_";
- } else if (texture->Is<ast::DepthMultisampledTexture>()) {
- out << "depth_multisampled_";
- } else if (texture->Is<ast::SampledTexture>()) {
- /* nothing to emit */
- } else if (texture->Is<ast::MultisampledTexture>()) {
- out << "multisampled_";
- } else if (texture->Is<ast::StorageTexture>()) {
- out << "storage_";
- } else {
- diagnostics_.add_error(diag::System::Writer, "unknown texture type");
- return false;
- }
+ if (sampler->IsComparison()) {
+ out << "_comparison";
+ }
+ return true;
+ },
+ [&](const ast::ExternalTexture*) {
+ out << "texture_external";
+ return true;
+ },
+ [&](const ast::Texture* texture) {
+ out << "texture_";
+ bool ok = Switch(
+ texture,
+ [&](const ast::DepthTexture*) { //
+ out << "depth_";
+ return true;
+ },
+ [&](const ast::DepthMultisampledTexture*) { //
+ out << "depth_multisampled_";
+ return true;
+ },
+ [&](const ast::SampledTexture*) { //
+ /* nothing to emit */
+ return true;
+ },
+ [&](const ast::MultisampledTexture*) { //
+ out << "multisampled_";
+ return true;
+ },
+ [&](const ast::StorageTexture*) { //
+ out << "storage_";
+ return true;
+ },
+ [&](Default) { //
+ diagnostics_.add_error(diag::System::Writer,
+ "unknown texture type");
+ return false;
+ });
+ if (!ok) {
+ return false;
+ }
- switch (texture->dim) {
- case ast::TextureDimension::k1d:
- out << "1d";
- break;
- case ast::TextureDimension::k2d:
- out << "2d";
- break;
- case ast::TextureDimension::k2dArray:
- out << "2d_array";
- break;
- case ast::TextureDimension::k3d:
- out << "3d";
- break;
- case ast::TextureDimension::kCube:
- out << "cube";
- break;
- case ast::TextureDimension::kCubeArray:
- out << "cube_array";
- break;
- default:
- diagnostics_.add_error(diag::System::Writer,
- "unknown texture dimension");
- return false;
- }
+ switch (texture->dim) {
+ case ast::TextureDimension::k1d:
+ out << "1d";
+ break;
+ case ast::TextureDimension::k2d:
+ out << "2d";
+ break;
+ case ast::TextureDimension::k2dArray:
+ out << "2d_array";
+ break;
+ case ast::TextureDimension::k3d:
+ out << "3d";
+ break;
+ case ast::TextureDimension::kCube:
+ out << "cube";
+ break;
+ case ast::TextureDimension::kCubeArray:
+ out << "cube_array";
+ break;
+ default:
+ diagnostics_.add_error(diag::System::Writer,
+ "unknown texture dimension");
+ return false;
+ }
- if (auto* sampled = texture->As<ast::SampledTexture>()) {
- out << "<";
- if (!EmitType(out, sampled->type)) {
+ return Switch(
+ texture,
+ [&](const ast::SampledTexture* sampled) { //
+ out << "<";
+ if (!EmitType(out, sampled->type)) {
+ return false;
+ }
+ out << ">";
+ return true;
+ },
+ [&](const ast::MultisampledTexture* ms) { //
+ out << "<";
+ if (!EmitType(out, ms->type)) {
+ return false;
+ }
+ out << ">";
+ return true;
+ },
+ [&](const ast::StorageTexture* storage) { //
+ out << "<";
+ if (!EmitImageFormat(out, storage->format)) {
+ return false;
+ }
+ out << ", ";
+ if (!EmitAccess(out, storage->access)) {
+ return false;
+ }
+ out << ">";
+ return true;
+ },
+ [&](Default) { //
+ return true;
+ });
+ },
+ [&](const ast::U32*) {
+ out << "u32";
+ return true;
+ },
+ [&](const ast::Vector* vec) {
+ out << "vec" << vec->width;
+ if (auto* el_ty = vec->type) {
+ out << "<";
+ if (!EmitType(out, el_ty)) {
+ return false;
+ }
+ out << ">";
+ }
+ return true;
+ },
+ [&](const ast::Void*) {
+ out << "void";
+ return true;
+ },
+ [&](const ast::TypeName* tn) {
+ out << program_->Symbols().NameFor(tn->name);
+ return true;
+ },
+ [&](Default) {
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "unknown type in EmitType: " + std::string(ty->TypeInfo().name));
return false;
- }
- out << ">";
- } else if (auto* ms = texture->As<ast::MultisampledTexture>()) {
- out << "<";
- if (!EmitType(out, ms->type)) {
- return false;
- }
- out << ">";
- } else if (auto* storage = texture->As<ast::StorageTexture>()) {
- out << "<";
- if (!EmitImageFormat(out, storage->format)) {
- return false;
- }
- out << ", ";
- if (!EmitAccess(out, storage->access)) {
- return false;
- }
- out << ">";
- }
-
- } else if (ty->Is<ast::U32>()) {
- out << "u32";
- } else if (auto* vec = ty->As<ast::Vector>()) {
- out << "vec" << vec->width;
- if (auto* el_ty = vec->type) {
- out << "<";
- if (!EmitType(out, el_ty)) {
- return false;
- }
- out << ">";
- }
- } else if (ty->Is<ast::Void>()) {
- out << "void";
- } else if (auto* tn = ty->As<ast::TypeName>()) {
- out << program_->Symbols().NameFor(tn->name);
- } else {
- diagnostics_.add_error(
- diag::System::Writer,
- "unknown type in EmitType: " + std::string(ty->TypeInfo().name));
- return false;
- }
- return true;
+ });
}
bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
@@ -632,56 +693,90 @@
}
first = false;
out << "@";
- if (auto* workgroup = attr->As<ast::WorkgroupAttribute>()) {
- auto values = workgroup->Values();
- out << "workgroup_size(";
- for (int i = 0; i < 3; i++) {
- if (values[i]) {
- if (i > 0) {
- out << ", ";
+ bool ok = Switch(
+ attr,
+ [&](const ast::WorkgroupAttribute* workgroup) {
+ auto values = workgroup->Values();
+ out << "workgroup_size(";
+ for (int i = 0; i < 3; i++) {
+ if (values[i]) {
+ if (i > 0) {
+ out << ", ";
+ }
+ if (!EmitExpression(out, values[i])) {
+ return false;
+ }
+ }
}
- if (!EmitExpression(out, values[i])) {
- return false;
+ out << ")";
+ return true;
+ },
+ [&](const ast::StructBlockAttribute*) { //
+ out << "block";
+ return true;
+ },
+ [&](const ast::StageAttribute* stage) {
+ out << "stage(" << stage->stage << ")";
+ return true;
+ },
+ [&](const ast::BindingAttribute* binding) {
+ out << "binding(" << binding->value << ")";
+ return true;
+ },
+ [&](const ast::GroupAttribute* group) {
+ out << "group(" << group->value << ")";
+ return true;
+ },
+ [&](const ast::LocationAttribute* location) {
+ out << "location(" << location->value << ")";
+ return true;
+ },
+ [&](const ast::BuiltinAttribute* builtin) {
+ out << "builtin(" << builtin->builtin << ")";
+ return true;
+ },
+ [&](const ast::InterpolateAttribute* interpolate) {
+ out << "interpolate(" << interpolate->type;
+ if (interpolate->sampling != ast::InterpolationSampling::kNone) {
+ out << ", " << interpolate->sampling;
}
- }
- }
- out << ")";
- } else if (attr->Is<ast::StructBlockAttribute>()) {
- out << "block";
- } else if (auto* stage = attr->As<ast::StageAttribute>()) {
- out << "stage(" << stage->stage << ")";
- } else if (auto* binding = attr->As<ast::BindingAttribute>()) {
- out << "binding(" << binding->value << ")";
- } else if (auto* group = attr->As<ast::GroupAttribute>()) {
- out << "group(" << group->value << ")";
- } else if (auto* location = attr->As<ast::LocationAttribute>()) {
- out << "location(" << location->value << ")";
- } else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
- out << "builtin(" << builtin->builtin << ")";
- } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
- out << "interpolate(" << interpolate->type;
- if (interpolate->sampling != ast::InterpolationSampling::kNone) {
- out << ", " << interpolate->sampling;
- }
- out << ")";
- } else if (attr->Is<ast::InvariantAttribute>()) {
- out << "invariant";
- } else if (auto* override_attr = attr->As<ast::OverrideAttribute>()) {
- out << "override";
- if (override_attr->has_value) {
- out << "(" << override_attr->value << ")";
- }
- } else if (auto* size = attr->As<ast::StructMemberSizeAttribute>()) {
- out << "size(" << size->size << ")";
- } else if (auto* align = attr->As<ast::StructMemberAlignAttribute>()) {
- out << "align(" << align->align << ")";
- } else if (auto* stride = attr->As<ast::StrideAttribute>()) {
- out << "stride(" << stride->stride << ")";
- } else if (auto* internal = attr->As<ast::InternalAttribute>()) {
- out << "internal(" << internal->InternalName() << ")";
- } else {
- TINT_ICE(Writer, diagnostics_)
- << "Unsupported attribute '" << attr->TypeInfo().name << "'";
+ out << ")";
+ return true;
+ },
+ [&](const ast::InvariantAttribute*) {
+ out << "invariant";
+ return true;
+ },
+ [&](const ast::OverrideAttribute* override_deco) {
+ out << "override";
+ if (override_deco->has_value) {
+ out << "(" << override_deco->value << ")";
+ }
+ return true;
+ },
+ [&](const ast::StructMemberSizeAttribute* size) {
+ out << "size(" << size->size << ")";
+ return true;
+ },
+ [&](const ast::StructMemberAlignAttribute* align) {
+ out << "align(" << align->align << ")";
+ return true;
+ },
+ [&](const ast::StrideAttribute* stride) {
+ out << "stride(" << stride->stride << ")";
+ return true;
+ },
+ [&](const ast::InternalAttribute* internal) {
+ out << "internal(" << internal->InternalName() << ")";
+ return true;
+ },
+ [&](Default) {
+ TINT_ICE(Writer, diagnostics_)
+ << "Unsupported attribute '" << attr->TypeInfo().name << "'";
+ return false;
+ });
+
+ if (!ok) {
return false;
}
}
@@ -809,55 +904,36 @@
}
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
- if (auto* a = stmt->As<ast::AssignmentStatement>()) {
- return EmitAssign(a);
- }
- if (auto* b = stmt->As<ast::BlockStatement>()) {
- return EmitBlock(b);
- }
- if (auto* b = stmt->As<ast::BreakStatement>()) {
- return EmitBreak(b);
- }
- if (auto* c = stmt->As<ast::CallStatement>()) {
- auto out = line();
- if (!EmitCall(out, c->expr)) {
- return false;
- }
- out << ";";
- return true;
- }
- if (auto* c = stmt->As<ast::ContinueStatement>()) {
- return EmitContinue(c);
- }
- if (auto* d = stmt->As<ast::DiscardStatement>()) {
- return EmitDiscard(d);
- }
- if (auto* f = stmt->As<ast::FallthroughStatement>()) {
- return EmitFallthrough(f);
- }
- if (auto* i = stmt->As<ast::IfStatement>()) {
- return EmitIf(i);
- }
- if (auto* l = stmt->As<ast::LoopStatement>()) {
- return EmitLoop(l);
- }
- if (auto* l = stmt->As<ast::ForLoopStatement>()) {
- return EmitForLoop(l);
- }
- if (auto* r = stmt->As<ast::ReturnStatement>()) {
- return EmitReturn(r);
- }
- if (auto* s = stmt->As<ast::SwitchStatement>()) {
- return EmitSwitch(s);
- }
- if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
- return EmitVariable(line(), v->variable);
- }
-
- diagnostics_.add_error(
- diag::System::Writer,
- "unknown statement type: " + std::string(stmt->TypeInfo().name));
- return false;
+ return Switch(
+ stmt, //
+ [&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
+ [&](const ast::BlockStatement* b) { return EmitBlock(b); },
+ [&](const ast::BreakStatement* b) { return EmitBreak(b); },
+ [&](const ast::CallStatement* c) {
+ auto out = line();
+ if (!EmitCall(out, c->expr)) {
+ return false;
+ }
+ out << ";";
+ return true;
+ },
+ [&](const ast::ContinueStatement* c) { return EmitContinue(c); },
+ [&](const ast::DiscardStatement* d) { return EmitDiscard(d); },
+ [&](const ast::FallthroughStatement* f) { return EmitFallthrough(f); },
+ [&](const ast::IfStatement* i) { return EmitIf(i); },
+ [&](const ast::LoopStatement* l) { return EmitLoop(l); },
+ [&](const ast::ForLoopStatement* l) { return EmitForLoop(l); },
+ [&](const ast::ReturnStatement* r) { return EmitReturn(r); },
+ [&](const ast::SwitchStatement* s) { return EmitSwitch(s); },
+ [&](const ast::VariableDeclStatement* v) {
+ return EmitVariable(line(), v->variable);
+ },
+ [&](Default) {
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "unknown statement type: " + std::string(stmt->TypeInfo().name));
+ return false;
+ });
}
bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {