Use the new Switch() inferred types
Change-Id: I48ecd18957101631caa27480e7b1937a10791118
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/81106
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index fe32d35..9c16094 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -1125,7 +1125,7 @@
// Recursively flatten matrices, arrays, and structures.
return Switch(
tip_type,
- [&](const Matrix* matrix_type) -> bool {
+ [&](const Matrix* matrix_type) {
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 42f8cb1..d7c13a6 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -121,20 +121,14 @@
for (auto* decl : dependencies_.ordered_globals) {
Mark(decl);
if (!Switch(
- decl, //
- [&](const ast::TypeDecl* td) { //
- return TypeDecl(td) != nullptr;
- },
- [&](const ast::Function* func) {
- return Function(func) != nullptr;
- },
- [&](const ast::Variable* var) {
- return GlobalVariable(var) != nullptr;
- },
+ decl, //
+ [&](const ast::TypeDecl* td) { return TypeDecl(td); },
+ [&](const ast::Function* func) { return Function(func); },
+ [&](const ast::Variable* var) { return GlobalVariable(var); },
[&](Default) {
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "unhandled global declaration: " << decl->TypeInfo().name;
- return false;
+ return nullptr;
})) {
return false;
}
@@ -165,23 +159,13 @@
sem::Type* Resolver::Type(const ast::Type* ty) {
Mark(ty);
auto* s = Switch(
- ty,
- [&](const ast::Void*) -> sem::Type* {
- return builder_->create<sem::Void>();
- },
- [&](const ast::Bool*) -> sem::Type* {
- return builder_->create<sem::Bool>();
- },
- [&](const ast::I32*) -> sem::Type* {
- return builder_->create<sem::I32>();
- },
- [&](const ast::U32*) -> sem::Type* {
- return builder_->create<sem::U32>();
- },
- [&](const ast::F32*) -> sem::Type* {
- return builder_->create<sem::F32>();
- },
- [&](const ast::Vector* t) -> sem::Type* {
+ ty, //
+ [&](const ast::Void*) { return builder_->create<sem::Void>(); },
+ [&](const ast::Bool*) { return builder_->create<sem::Bool>(); },
+ [&](const ast::I32*) { return builder_->create<sem::I32>(); },
+ [&](const ast::U32*) { return builder_->create<sem::U32>(); },
+ [&](const ast::F32*) { return builder_->create<sem::F32>(); },
+ [&](const ast::Vector* t) -> sem::Vector* {
if (!t->type) {
AddError("missing vector element type", t->source.End());
return nullptr;
@@ -195,7 +179,7 @@
}
return nullptr;
},
- [&](const ast::Matrix* t) -> sem::Type* {
+ [&](const ast::Matrix* t) -> sem::Matrix* {
if (!t->type) {
AddError("missing matrix element type", t->source.End());
return nullptr;
@@ -212,8 +196,8 @@
}
return nullptr;
},
- [&](const ast::Array* t) -> sem::Type* { return Array(t); },
- [&](const ast::Atomic* t) -> sem::Type* {
+ [&](const ast::Array* t) { return Array(t); },
+ [&](const ast::Atomic* t) -> sem::Atomic* {
if (auto* el = Type(t->type)) {
auto* a = builder_->create<sem::Atomic>(el);
if (!ValidateAtomic(t, a)) {
@@ -223,7 +207,7 @@
}
return nullptr;
},
- [&](const ast::Pointer* t) -> sem::Type* {
+ [&](const ast::Pointer* t) -> sem::Pointer* {
if (auto* el = Type(t->type)) {
auto access = t->access;
if (access == ast::kUndefined) {
@@ -233,28 +217,28 @@
}
return nullptr;
},
- [&](const ast::Sampler* t) -> sem::Type* {
+ [&](const ast::Sampler* t) {
return builder_->create<sem::Sampler>(t->kind);
},
- [&](const ast::SampledTexture* t) -> sem::Type* {
+ [&](const ast::SampledTexture* t) -> sem::SampledTexture* {
if (auto* el = Type(t->type)) {
return builder_->create<sem::SampledTexture>(t->dim, el);
}
return nullptr;
},
- [&](const ast::MultisampledTexture* t) -> sem::Type* {
+ [&](const ast::MultisampledTexture* t) -> sem::MultisampledTexture* {
if (auto* el = Type(t->type)) {
return builder_->create<sem::MultisampledTexture>(t->dim, el);
}
return nullptr;
},
- [&](const ast::DepthTexture* t) -> sem::Type* {
+ [&](const ast::DepthTexture* t) {
return builder_->create<sem::DepthTexture>(t->dim);
},
- [&](const ast::DepthMultisampledTexture* t) -> sem::Type* {
+ [&](const ast::DepthMultisampledTexture* t) {
return builder_->create<sem::DepthMultisampledTexture>(t->dim);
},
- [&](const ast::StorageTexture* t) -> sem::Type* {
+ [&](const ast::StorageTexture* t) -> sem::StorageTexture* {
if (auto* el = Type(t->type)) {
if (!ValidateStorageTexture(t)) {
return nullptr;
@@ -264,10 +248,10 @@
}
return nullptr;
},
- [&](const ast::ExternalTexture*) -> sem::Type* {
+ [&](const ast::ExternalTexture*) {
return builder_->create<sem::ExternalTexture>();
},
- [&](Default) -> sem::Type* {
+ [&](Default) {
auto* resolved = ResolvedSymbol(ty);
return Switch(
resolved, //
@@ -858,62 +842,40 @@
stmt,
// Compound statements. These create their own sem::CompoundStatement
// bindings.
- [&](const ast::BlockStatement* b) -> sem::Statement* {
- return BlockStatement(b);
- },
- [&](const ast::ForLoopStatement* l) -> sem::Statement* {
- return ForLoopStatement(l);
- },
- [&](const ast::LoopStatement* l) -> sem::Statement* {
- return LoopStatement(l);
- },
- [&](const ast::IfStatement* i) -> sem::Statement* {
- return IfStatement(i);
- },
- [&](const ast::SwitchStatement* s) -> sem::Statement* {
- return SwitchStatement(s);
- },
+ [&](const ast::BlockStatement* b) { return BlockStatement(b); },
+ [&](const ast::ForLoopStatement* l) { return ForLoopStatement(l); },
+ [&](const ast::LoopStatement* l) { return LoopStatement(l); },
+ [&](const ast::IfStatement* i) { return IfStatement(i); },
+ [&](const ast::SwitchStatement* s) { return SwitchStatement(s); },
// Non-Compound statements
- [&](const ast::AssignmentStatement* a) -> sem::Statement* {
- return AssignmentStatement(a);
- },
- [&](const ast::BreakStatement* b) -> sem::Statement* {
- return BreakStatement(b);
- },
- [&](const ast::CallStatement* c) -> sem::Statement* {
- return CallStatement(c);
- },
- [&](const ast::ContinueStatement* c) -> sem::Statement* {
- return ContinueStatement(c);
- },
- [&](const ast::DiscardStatement* d) -> sem::Statement* {
- return DiscardStatement(d);
- },
- [&](const ast::FallthroughStatement* f) -> sem::Statement* {
+ [&](const ast::AssignmentStatement* a) { return AssignmentStatement(a); },
+ [&](const ast::BreakStatement* b) { return BreakStatement(b); },
+ [&](const ast::CallStatement* c) { return CallStatement(c); },
+ [&](const ast::ContinueStatement* c) { return ContinueStatement(c); },
+ [&](const ast::DiscardStatement* d) { return DiscardStatement(d); },
+ [&](const ast::FallthroughStatement* f) {
return FallthroughStatement(f);
},
- [&](const ast::ReturnStatement* r) -> sem::Statement* {
- return ReturnStatement(r);
- },
- [&](const ast::VariableDeclStatement* v) -> sem::Statement* {
+ [&](const ast::ReturnStatement* r) { return ReturnStatement(r); },
+ [&](const ast::VariableDeclStatement* v) {
return VariableDeclStatement(v);
},
// Error cases
- [&](const ast::CaseStatement*) -> sem::Statement* {
+ [&](const ast::CaseStatement*) {
AddError("case statement can only be used inside a switch statement",
stmt->source);
return nullptr;
},
- [&](const ast::ElseStatement*) -> sem::Statement* {
+ [&](const ast::ElseStatement*) {
TINT_ICE(Resolver, diagnostics_)
<< "Resolver::Statement() encountered an Else statement. Else "
"statements are embedded in If statements, so should never be "
"encountered as top-level statements";
return nullptr;
},
- [&](Default) -> sem::Statement* {
+ [&](Default) {
AddError(
"unknown statement type: " + std::string(stmt->TypeInfo().name),
stmt->source);
@@ -1196,16 +1158,12 @@
auto* obj_ty = obj_raw_ty->UnwrapRef();
auto* ty = Switch(
obj_ty, //
- [&](const sem::Array* arr) -> const sem::Type* {
- return arr->ElemType();
- },
- [&](const sem::Vector* vec) -> const sem::Type* { //
- return vec->type();
- },
- [&](const sem::Matrix* mat) -> const sem::Type* {
+ [&](const sem::Array* arr) { return arr->ElemType(); },
+ [&](const sem::Vector* vec) { return vec->type(); },
+ [&](const sem::Matrix* mat) {
return builder_->create<sem::Vector>(mat->type(), mat->rows());
},
- [&](Default) -> const sem::Type* {
+ [&](Default) {
AddError("cannot index type '" + TypeNameOf(obj_ty) + "'",
expr->source);
return nullptr;
@@ -2188,19 +2146,19 @@
sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
return Switch(
lit,
- [&](const ast::SintLiteralExpression*) -> sem::Type* {
+ [&](const ast::SintLiteralExpression*) {
return builder_->create<sem::I32>();
},
- [&](const ast::UintLiteralExpression*) -> sem::Type* {
+ [&](const ast::UintLiteralExpression*) {
return builder_->create<sem::U32>();
},
- [&](const ast::FloatLiteralExpression*) -> sem::Type* {
+ [&](const ast::FloatLiteralExpression*) {
return builder_->create<sem::F32>();
},
- [&](const ast::BoolLiteralExpression*) -> sem::Type* {
+ [&](const ast::BoolLiteralExpression*) {
return builder_->create<sem::Bool>();
},
- [&](Default) -> sem::Type* {
+ [&](Default) {
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "Unhandled literal type: " << lit->TypeInfo().name;
return nullptr;
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 83f2666..0f0623e 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -575,31 +575,29 @@
uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
return Switch(
expr,
- [&](const ast::IndexAccessorExpression* a) { //
+ [&](const ast::IndexAccessorExpression* a) {
return GenerateAccessorExpression(a);
},
- [&](const ast::BinaryExpression* b) { //
+ [&](const ast::BinaryExpression* b) {
return GenerateBinaryExpression(b);
},
- [&](const ast::BitcastExpression* b) { //
+ [&](const ast::BitcastExpression* b) {
return GenerateBitcastExpression(b);
},
- [&](const ast::CallExpression* c) { //
- return GenerateCallExpression(c);
- },
- [&](const ast::IdentifierExpression* i) { //
+ [&](const ast::CallExpression* c) { return GenerateCallExpression(c); },
+ [&](const ast::IdentifierExpression* i) {
return GenerateIdentifierExpression(i);
},
- [&](const ast::LiteralExpression* l) { //
+ [&](const ast::LiteralExpression* l) {
return GenerateLiteralIfNeeded(nullptr, l);
},
- [&](const ast::MemberAccessorExpression* m) { //
+ [&](const ast::MemberAccessorExpression* m) {
return GenerateAccessorExpression(m);
},
- [&](const ast::UnaryOpExpression* u) { //
+ [&](const ast::UnaryOpExpression* u) {
return GenerateUnaryOpExpression(u);
},
- [&](Default) -> uint32_t {
+ [&](Default) {
error_ =
"unknown expression type: " + std::string(expr->TypeInfo().name);
return 0;
@@ -2271,7 +2269,7 @@
[&](const sem::TypeConstructor*) {
return GenerateTypeConstructorOrConversion(call, nullptr);
},
- [&](Default) -> uint32_t {
+ [&](Default) {
TINT_ICE(Writer, builder_.Diagnostics())
<< "unhandled call target: " << target->TypeInfo().name;
return 0;
@@ -4101,7 +4099,7 @@
[&](const sem::StorageTexture* t) {
return GenerateTypeIfNeeded(t->type());
},
- [&](Default) -> uint32_t { //
+ [&](Default) {
return 0u;
});
if (type_id == 0u) {