reader/spirv - add type hierarchy
Don't create disjoint AST type nodes.
Instead use a new bespoke type hierarchy that can Build() the required
AST nodes.
Change-Id: I523f97054de2c553095056c0bafc17c48064cf53
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49966
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 0be18bd..4cb6563 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -590,6 +590,8 @@
"reader/spirv/function.h",
"reader/spirv/namer.cc",
"reader/spirv/namer.h",
+ "reader/spirv/parser_type.cc",
+ "reader/spirv/parser_type.h",
"reader/spirv/parser.cc",
"reader/spirv/parser.h",
"reader/spirv/parser_impl.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b067715..4b2467d 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -360,6 +360,8 @@
reader/spirv/function.h
reader/spirv/namer.cc
reader/spirv/namer.h
+ reader/spirv/parser_type.cc
+ reader/spirv/parser_type.h
reader/spirv/parser.cc
reader/spirv/parser.h
reader/spirv/parser_impl.cc
@@ -627,6 +629,7 @@
reader/spirv/parser_impl_test_helper.h
reader/spirv/parser_impl_test.cc
reader/spirv/parser_impl_user_name_test.cc
+ reader/spirv/parser_type_test.cc
reader/spirv/parser_test.cc
reader/spirv/spirv_tools_helpers_test.cc
reader/spirv/spirv_tools_helpers_test.h
diff --git a/src/program_builder.h b/src/program_builder.h
index b013ffb..41d3842 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -617,26 +617,32 @@
/// @param subtype the array element type
/// @param n the array size. 0 represents a runtime-array
- /// @param stride the array stride
+ /// @param stride the array stride. 0 represents implicit stride
/// @return the tint AST type for a array of size `n` of type `T`
ast::Array* array(typ::Type subtype, uint32_t n, uint32_t stride) const {
subtype = MaybeCreateTypename(subtype);
- return array(subtype, n,
- {builder->create<ast::StrideDecoration>(stride)});
+ ast::DecorationList decos;
+ if (stride) {
+ decos.emplace_back(builder->create<ast::StrideDecoration>(stride));
+ }
+ return array(subtype, n, std::move(decos));
}
/// @param source the Source of the node
/// @param subtype the array element type
/// @param n the array size. 0 represents a runtime-array
- /// @param stride the array stride
+ /// @param stride the array stride. 0 represents implicit stride
/// @return the tint AST type for a array of size `n` of type `T`
ast::Array* array(const Source& source,
typ::Type subtype,
uint32_t n,
uint32_t stride) const {
subtype = MaybeCreateTypename(subtype);
- return array(source, subtype, n,
- {builder->create<ast::StrideDecoration>(stride)});
+ ast::DecorationList decos;
+ if (stride) {
+ decos.emplace_back(builder->create<ast::StrideDecoration>(stride));
+ }
+ return array(source, subtype, n, std::move(decos));
}
/// @return the tint AST type for an array of size `N` of type `T`
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 1cbbc34..e9c15c9 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -713,6 +713,7 @@
const spvtools::opt::Function& function,
const EntryPointInfo* ep_info)
: parser_impl_(*pi),
+ ty_(pi->type_manager()),
builder_(pi->builder()),
ir_context_(*(pi->ir_context())),
def_use_mgr_(ir_context_.get_def_use_mgr()),
@@ -733,6 +734,7 @@
FunctionEmitter::FunctionEmitter(FunctionEmitter&& other)
: parser_impl_(other.parser_impl_),
+ ty_(other.ty_),
builder_(other.builder_),
ir_context_(other.ir_context_),
def_use_mgr_(ir_context_.get_def_use_mgr()),
@@ -875,7 +877,7 @@
auto* body = create<ast::BlockStatement>(Source{}, statements);
builder_.AST().AddFunction(create<ast::Function>(
decl.source, builder_.Symbols().Register(decl.name),
- std::move(decl.params), decl.return_type, body,
+ std::move(decl.params), decl.return_type->Build(builder_), body,
std::move(decl.decorations), ast::DecorationList{}));
// Maintain the invariant by repopulating the one and only element.
@@ -943,7 +945,7 @@
return success();
}
-ast::Type* FunctionEmitter::GetVariableStoreType(
+const Type* FunctionEmitter::GetVariableStoreType(
const spvtools::opt::Instruction& var_decl_inst) {
const auto type_id = var_decl_inst.type_id();
auto* var_ref_type = type_mgr_->GetType(type_id);
@@ -2042,7 +2044,7 @@
<< id;
return {};
case SkipReason::kPointSizeBuiltinValue: {
- return {create<ast::F32>(),
+ return {ty_.F32(),
create<ast::ScalarConstructorExpression>(
Source{}, create<ast::FloatLiteral>(Source{}, 1.0f))};
}
@@ -2072,7 +2074,7 @@
case SkipReason::kSampleMaskOutBuiltinPointer:
// The result type is always u32.
auto name = namer_.Name(sample_mask_out_id);
- return TypedExpression{builder_.ty.u32(),
+ return TypedExpression{ty_.U32(),
create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(name))};
}
@@ -2348,14 +2350,10 @@
const std::string guard_name = block_info.flow_guard_name;
if (!guard_name.empty()) {
// Declare the guard variable just before the "if", initialized to true.
- auto* guard_var = create<ast::Variable>(
- Source{}, // source
- builder_.Symbols().Register(guard_name), // symbol
- ast::StorageClass::kFunction, // storage_class
- builder_.ty.bool_(), // type
- false, // is_const
- MakeTrue(Source{}), // constructor
- ast::DecorationList{}); // decorations
+ auto* guard_var =
+ create<ast::Variable>(Source{}, builder_.Symbols().Register(guard_name),
+ ast::StorageClass::kFunction, builder_.ty.bool_(),
+ false, MakeTrue(Source{}), ast::DecorationList{});
auto* guard_decl = create<ast::VariableDeclStatement>(Source{}, guard_var);
AddStatement(guard_decl);
}
@@ -2557,7 +2555,7 @@
// The rest of this module can handle up to 64 bit switch values.
// The Tint AST handles 32-bit values.
const uint32_t value32 = uint32_t(value & 0xFFFFFFFF);
- if (selector.type->is_unsigned_scalar_or_vector()) {
+ if (selector.type->IsUnsignedScalarOrVector()) {
selectors.emplace_back(create<ast::UintLiteral>(Source{}, value32));
} else {
selectors.emplace_back(create<ast::SintLiteral>(Source{}, value32));
@@ -2914,13 +2912,10 @@
const auto phi_var_name = GetDefInfo(id)->phi_var;
TINT_ASSERT(!phi_var_name.empty());
auto* var = create<ast::Variable>(
- Source{}, // source
- builder_.Symbols().Register(phi_var_name), // symbol
- ast::StorageClass::kFunction, // storage_class
- parser_impl_.ConvertType(def_inst->type_id()), // type
- false, // is_const
- nullptr, // constructor
- ast::DecorationList{}); // decorations
+ Source{}, builder_.Symbols().Register(phi_var_name),
+ ast::StorageClass::kFunction,
+ parser_impl_.ConvertType(def_inst->type_id())->Build(builder_), false,
+ nullptr, ast::DecorationList{});
AddStatement(create<ast::VariableDeclStatement>(Source{}, var));
}
@@ -3102,9 +3097,9 @@
case SkipReason::kSampleMaskOutBuiltinPointer:
ptr_id = sample_mask_out_id;
- if (!rhs.type->Is<ast::U32>()) {
+ if (!rhs.type->Is<U32>()) {
// WGSL requires sample_mask_out to be signed.
- rhs = TypedExpression{builder_.ty.u32(),
+ rhs = TypedExpression{ty_.U32(),
create<ast::TypeConstructorExpression>(
Source{}, builder_.ty.u32(),
ast::ExpressionList{rhs.expr})};
@@ -3148,7 +3143,7 @@
ast::Expression* id_expr = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(name));
auto expr = TypedExpression{
- builder_.ty.i32(),
+ ty_.I32(),
create<ast::TypeConstructorExpression>(
Source{}, builder_.ty.i32(), ast::ExpressionList{id_expr})};
return EmitConstDefinition(inst, expr);
@@ -3159,10 +3154,10 @@
Source{}, builder_.Symbols().Register(name));
auto* load_result_type = parser_impl_.ConvertType(inst.type_id());
ast::Expression* ast_expr = nullptr;
- if (load_result_type->Is<ast::I32>()) {
+ if (load_result_type->Is<I32>()) {
ast_expr = create<ast::TypeConstructorExpression>(
Source{}, builder_.ty.i32(), ast::ExpressionList{id_expr});
- } else if (load_result_type->Is<ast::U32>()) {
+ } else if (load_result_type->Is<U32>()) {
ast_expr = id_expr;
} else {
return Fail() << "loading the whole SampleMask input array is not "
@@ -3181,8 +3176,8 @@
}
// The load result type is the pointee type of its operand.
- TINT_ASSERT(expr.type->Is<ast::Pointer>());
- expr.type = expr.type->As<ast::Pointer>()->type();
+ TINT_ASSERT(expr.type->Is<Pointer>());
+ expr.type = expr.type->As<Pointer>()->type;
return EmitConstDefOrWriteToHoistedVar(inst, expr);
}
@@ -3284,7 +3279,7 @@
const auto opcode = inst.opcode();
- ast::Type* ast_type =
+ const Type* ast_type =
inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr;
auto binary_op = ConvertBinaryOp(opcode);
@@ -3329,8 +3324,9 @@
}
if (opcode == SpvOpBitcast) {
- return {ast_type, create<ast::BitcastExpression>(
- Source{}, ast_type, MakeOperand(inst, 0).expr)};
+ return {ast_type,
+ create<ast::BitcastExpression>(Source{}, ast_type->Build(builder_),
+ MakeOperand(inst, 0).expr)};
}
if (opcode == SpvOpShiftLeftLogical || opcode == SpvOpShiftRightLogical ||
@@ -3390,9 +3386,9 @@
for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
operands.emplace_back(MakeOperand(inst, iarg).expr);
}
- return {ast_type, create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(ast_type),
- std::move(operands))};
+ return {ast_type,
+ create<ast::TypeConstructorExpression>(
+ Source{}, ast_type->Build(builder_), std::move(operands))};
}
if (opcode == SpvOpCompositeExtract) {
@@ -3457,7 +3453,7 @@
auto* func = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(name));
ast::ExpressionList operands;
- ast::Type* first_operand_type = nullptr;
+ const Type* first_operand_type = nullptr;
// All parameters to GLSL.std.450 extended instructions are IDs.
for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) {
TypedExpression operand = MakeOperand(inst, iarg);
@@ -3703,7 +3699,7 @@
type_mgr_->FindPointerToType(pointee_type_id, storage_class);
auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id);
TINT_ASSERT(ast_pointer_type);
- TINT_ASSERT(ast_pointer_type->Is<ast::Pointer>());
+ TINT_ASSERT(ast_pointer_type->Is<Pointer>());
current_expr = TypedExpression{ast_pointer_type, next_expr};
}
return current_expr;
@@ -3887,8 +3883,8 @@
// Generate an ast::TypeConstructor expression.
// Assume the literal indices are valid, and there is a valid number of them.
auto source = GetSourceForInst(inst);
- ast::Vector* result_type =
- parser_impl_.ConvertType(inst.type_id())->As<ast::Vector>();
+ const Vector* result_type =
+ As<Vector>(parser_impl_.ConvertType(inst.type_id()));
ast::ExpressionList values;
for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
const auto index = inst.GetSingleWordInOperand(i);
@@ -3910,16 +3906,15 @@
source, expr.expr, Swizzle(sub_index)));
} else if (index == 0xFFFFFFFF) {
// By rule, this maps to OpUndef. Instead, make it zero.
- values.emplace_back(parser_impl_.MakeNullValue(result_type->type()));
+ values.emplace_back(parser_impl_.MakeNullValue(result_type->type));
} else {
Fail() << "invalid vectorshuffle ID %" << inst.result_id()
<< ": index too large: " << index;
return {};
}
}
- return {result_type,
- create<ast::TypeConstructorExpression>(
- source, builder_.ty.MaybeCreateTypename(result_type), values)};
+ return {result_type, create<ast::TypeConstructorExpression>(
+ source, result_type->Build(builder_), values)};
}
bool FunctionEmitter::RegisterSpecialBuiltInVariables() {
@@ -3988,8 +3983,8 @@
if (type) {
if (type->AsPointer()) {
if (auto* ast_type = parser_impl_.ConvertType(inst.type_id())) {
- if (auto* ptr = ast_type->As<ast::Pointer>()) {
- info->storage_class = ptr->storage_class();
+ if (auto* ptr = ast_type->As<Pointer>()) {
+ info->storage_class = ptr->storage_class;
}
}
switch (inst.opcode()) {
@@ -4033,21 +4028,21 @@
const auto type_id = def_use_mgr_->GetDef(id)->type_id();
if (type_id) {
auto* ast_type = parser_impl_.ConvertType(type_id);
- if (auto* ptr = As<ast::Pointer>(ast_type)) {
- return ptr->storage_class();
+ if (auto* ptr = As<Pointer>(ast_type)) {
+ return ptr->storage_class;
}
}
return ast::StorageClass::kNone;
}
-ast::Type* FunctionEmitter::RemapStorageClass(ast::Type* type,
- uint32_t result_id) {
- if (auto* ast_ptr_type = type->As<ast::Pointer>()) {
+const Type* FunctionEmitter::RemapStorageClass(const Type* type,
+ uint32_t result_id) {
+ if (auto* ast_ptr_type = As<Pointer>(type)) {
// Remap an old-style storage buffer pointer to a new-style storage
// buffer pointer.
const auto sc = GetStorageClassForPointerValue(result_id);
- if (ast_ptr_type->storage_class() != sc) {
- return builder_.ty.pointer(ast_ptr_type->type(), sc);
+ if (ast_ptr_type->storage_class != sc) {
+ return ty_.Pointer(ast_ptr_type->type, sc);
}
}
return type;
@@ -4230,30 +4225,27 @@
return {};
}
- ast::Type* expr_type = nullptr;
+ const Type* expr_type = nullptr;
if ((opcode == SpvOpConvertSToF) || (opcode == SpvOpConvertUToF)) {
- if (arg_expr.type->is_integer_scalar_or_vector()) {
+ if (arg_expr.type->IsIntegerScalarOrVector()) {
expr_type = requested_type;
} else {
Fail() << "operand for conversion to floating point must be integral "
- "scalar or vector, but got: "
- << arg_expr.type->type_name();
+ "scalar or vector";
}
} else if (inst.opcode() == SpvOpConvertFToU) {
- if (arg_expr.type->is_float_scalar_or_vector()) {
+ if (arg_expr.type->IsFloatScalarOrVector()) {
expr_type = parser_impl_.GetUnsignedIntMatchingShape(arg_expr.type);
} else {
Fail() << "operand for conversion to unsigned integer must be floating "
- "point scalar or vector, but got: "
- << arg_expr.type->type_name();
+ "point scalar or vector";
}
} else if (inst.opcode() == SpvOpConvertFToS) {
- if (arg_expr.type->is_float_scalar_or_vector()) {
+ if (arg_expr.type->IsFloatScalarOrVector()) {
expr_type = parser_impl_.GetSignedIntMatchingShape(arg_expr.type);
} else {
Fail() << "operand for conversion to signed integer must be floating "
- "point scalar or vector, but got: "
- << arg_expr.type->type_name();
+ "point scalar or vector";
}
}
if (expr_type == nullptr) {
@@ -4265,14 +4257,14 @@
params.push_back(arg_expr.expr);
TypedExpression result{
expr_type, create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(expr_type),
- std::move(params))};
+ Source{}, expr_type->Build(builder_), std::move(params))};
- if (AstTypesEquivalent(requested_type, expr_type)) {
+ if (requested_type == expr_type) {
return result;
}
- return {requested_type, create<ast::BitcastExpression>(
- Source{}, requested_type, result.expr)};
+ return {requested_type,
+ create<ast::BitcastExpression>(
+ Source{}, requested_type->Build(builder_), result.expr)};
}
bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
@@ -4296,7 +4288,7 @@
<< inst.PrettyPrint();
}
- if (result_type->Is<ast::Void>()) {
+ if (result_type->Is<Void>()) {
return nullptr !=
AddStatement(create<ast::CallStatement>(Source{}, call_expr));
}
@@ -4359,7 +4351,7 @@
Source{}, builder_.Symbols().Register(name));
ast::ExpressionList params;
- ast::Type* first_operand_type = nullptr;
+ const Type* first_operand_type = nullptr;
for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
TypedExpression operand = MakeOperand(inst, iarg);
if (first_operand_type == nullptr) {
@@ -4391,8 +4383,8 @@
// - you can't select over pointers or pointer vectors, unless you also have
// a VariablePointers* capability, which is not allowed in by WebGPU.
auto* op_ty = operand1.type;
- if (op_ty->Is<ast::Vector>() || op_ty->is_float_scalar() ||
- op_ty->is_integer_scalar() || op_ty->Is<ast::Bool>()) {
+ if (op_ty->Is<Vector>() || op_ty->IsFloatScalar() ||
+ op_ty->IsIntegerScalar() || op_ty->Is<Bool>()) {
ast::ExpressionList params;
params.push_back(operand1.expr);
params.push_back(operand2.expr);
@@ -4430,9 +4422,9 @@
return image;
}
-ast::Texture* FunctionEmitter::GetImageType(
+const Texture* FunctionEmitter::GetImageType(
const spvtools::opt::Instruction& image) {
- ast::Pointer* ptr_type = parser_impl_.GetTypeForHandleVar(image);
+ const Pointer* ptr_type = parser_impl_.GetTypeForHandleVar(image);
if (!parser_impl_.success()) {
Fail();
return {};
@@ -4441,7 +4433,7 @@
Fail() << "invalid texture type for " << image.PrettyPrint();
return {};
}
- auto* result = ptr_type->type()->UnwrapAll()->As<ast::Texture>();
+ auto* result = ptr_type->type->UnwrapAll()->As<Texture>();
if (!result) {
Fail() << "invalid texture type for " << image.PrettyPrint();
return {};
@@ -4496,12 +4488,12 @@
}
}
- ast::Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image);
+ const Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image);
if (!texture_ptr_type) {
return Fail();
}
- ast::Texture* texture_type =
- texture_ptr_type->type()->UnwrapAll()->As<ast::Texture>();
+ const Texture* texture_type =
+ texture_ptr_type->type->UnwrapAll()->As<Texture>();
if (!texture_type) {
return Fail();
@@ -4604,7 +4596,7 @@
}
TypedExpression lod = MakeOperand(inst, arg_index);
// When sampling from a depth texture, the Lod operand must be an I32.
- if (texture_type->Is<ast::DepthTexture>()) {
+ if (texture_type->Is<DepthTexture>()) {
// Convert it to a signed integer type.
lod = ToI32(lod);
}
@@ -4612,11 +4604,11 @@
image_operands_mask ^= SpvImageOperandsLodMask;
arg_index++;
} else if ((opcode == SpvOpImageFetch) &&
- (texture_type->Is<ast::SampledTexture>() ||
- texture_type->Is<ast::DepthTexture>())) {
+ (texture_type->Is<SampledTexture>() ||
+ texture_type->Is<DepthTexture>())) {
// textureLoad on sampled texture and depth texture requires an explicit
// level-of-detail parameter.
- params.push_back(parser_impl_.MakeNullValue(builder_.ty.i32()));
+ params.push_back(parser_impl_.MakeNullValue(ty_.I32()));
}
if (arg_index + 1 < num_args &&
(image_operands_mask & SpvImageOperandsGradMask)) {
@@ -4637,7 +4629,7 @@
return Fail() << "ConstOffset is only permitted for sampling operations: "
<< inst.PrettyPrint();
}
- switch (texture_type->dim()) {
+ switch (texture_type->dims) {
case ast::TextureDimension::k2d:
case ast::TextureDimension::k2dArray:
case ast::TextureDimension::k3d:
@@ -4676,8 +4668,8 @@
// The result type, derived from the SPIR-V instruction.
auto* result_type = parser_impl_.ConvertType(inst.type_id());
auto* result_component_type = result_type;
- if (auto* result_vector_type = result_type->As<ast::Vector>()) {
- result_component_type = result_vector_type->type();
+ if (auto* result_vector_type = As<Vector>(result_type)) {
+ result_component_type = result_vector_type->type;
}
// For depth textures, the arity might mot match WGSL:
@@ -4691,11 +4683,11 @@
// dref gather vec4 ImageFetch vec4 TODO(dneto)
// Construct a 4-element vector with the result from the builtin in the
// first component.
- if (texture_type->Is<ast::DepthTexture>()) {
+ if (texture_type->Is<DepthTexture>()) {
if (is_non_dref_sample || (opcode == SpvOpImageFetch)) {
value = create<ast::TypeConstructorExpression>(
Source{},
- builder_.ty.MaybeCreateTypename(result_type), // a vec4
+ result_type->Build(builder_), // a vec4
ast::ExpressionList{
value, parser_impl_.MakeNullValue(result_component_type),
parser_impl_.MakeNullValue(result_component_type),
@@ -4714,13 +4706,13 @@
}
auto* expected_component_type =
parser_impl_.ConvertType(spirv_image_type->GetSingleWordInOperand(0));
- if (!AstTypesEquivalent(expected_component_type, result_component_type)) {
+ if (expected_component_type != result_component_type) {
// This occurs if one is signed integer and the other is unsigned integer,
// or vice versa. Perform a bitcast.
- value = create<ast::BitcastExpression>(Source{}, result_type, call_expr);
+ value = create<ast::BitcastExpression>(
+ Source{}, result_type->Build(builder_), call_expr);
}
- if (!expected_component_type->Is<ast::F32>() &&
- IsSampledImageAccess(opcode)) {
+ if (!expected_component_type->Is<F32>() && IsSampledImageAccess(opcode)) {
// WGSL permits sampled image access only on float textures.
// Reject this case in the SPIR-V reader, at least until SPIR-V validation
// catches up with this rule and can reject it earlier in the workflow.
@@ -4763,7 +4755,7 @@
}
exprs.push_back(
create<ast::CallExpression>(Source{}, dims_ident, dims_args));
- if (ast::IsTextureArray(texture_type->dim())) {
+ if (ast::IsTextureArray(texture_type->dims)) {
auto* layers_ident = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register("textureNumLayers"));
exprs.push_back(create<ast::CallExpression>(
@@ -4772,9 +4764,8 @@
}
auto* result_type = parser_impl_.ConvertType(inst.type_id());
TypedExpression expr = {
- result_type,
- create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(result_type), exprs)};
+ result_type, create<ast::TypeConstructorExpression>(
+ Source{}, result_type->Build(builder_), exprs)};
return EmitConstDefOrWriteToHoistedVar(inst, expr);
}
case SpvOpImageQueryLod:
@@ -4794,9 +4785,9 @@
auto* result_type = parser_impl_.ConvertType(inst.type_id());
// The SPIR-V result type must be integer scalar. The WGSL bulitin
// returns i32. If they aren't the same then convert the result.
- if (!result_type->Is<ast::I32>()) {
+ if (!result_type->Is<I32>()) {
ast_expr = create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(result_type),
+ Source{}, result_type->Build(builder_),
ast::ExpressionList{ast_expr});
}
TypedExpression expr{result_type, ast_expr};
@@ -4840,28 +4831,27 @@
if (!raw_coords.type) {
return {};
}
- ast::Texture* texture_type = GetImageType(*image);
+ const Texture* texture_type = GetImageType(*image);
if (!texture_type) {
return {};
}
- ast::TextureDimension dim = texture_type->dim();
+ ast::TextureDimension dim = texture_type->dims;
// Number of regular coordinates.
uint32_t num_axes = ast::NumCoordinateAxes(dim);
bool is_arrayed = ast::IsTextureArray(dim);
if ((num_axes == 0) || (num_axes > 3)) {
Fail() << "unsupported image dimensionality for "
- << texture_type->type_name() << " prompted by "
+ << texture_type->TypeInfo().name << " prompted by "
<< inst.PrettyPrint();
}
const auto num_coords_required = num_axes + (is_arrayed ? 1 : 0);
uint32_t num_coords_supplied = 0;
auto* component_type = raw_coords.type;
- if (component_type->is_float_scalar() ||
- component_type->is_integer_scalar()) {
+ if (component_type->IsFloatScalar() || component_type->IsIntegerScalar()) {
num_coords_supplied = 1;
- } else if (auto* vec_type = raw_coords.type->As<ast::Vector>()) {
- component_type = vec_type->type();
- num_coords_supplied = vec_type->size();
+ } else if (auto* vec_type = As<Vector>(raw_coords.type)) {
+ component_type = vec_type->type;
+ num_coords_supplied = vec_type->size;
}
if (num_coords_supplied == 0) {
Fail() << "bad or unsupported coordinate type for image access: "
@@ -4884,10 +4874,8 @@
// will actually use them.
auto prefix_swizzle_expr = [this, num_axes, component_type,
raw_coords]() -> ast::Expression* {
- auto* swizzle_type = (num_axes == 1)
- ? component_type
- : static_cast<ast::Type*>(
- builder_.ty.vec(component_type, num_axes));
+ auto* swizzle_type =
+ (num_axes == 1) ? component_type : ty_.Vector(component_type, num_axes);
auto* swizzle = create<ast::MemberAccessorExpression>(
Source{}, raw_coords.expr, PrefixSwizzle(num_axes));
return ToSignedIfUnsigned({swizzle_type, swizzle}).expr;
@@ -4921,32 +4909,32 @@
ast::Expression* FunctionEmitter::ConvertTexelForStorage(
const spvtools::opt::Instruction& inst,
TypedExpression texel,
- ast::Texture* texture_type) {
- auto* storage_texture_type = texture_type->As<ast::StorageTexture>();
+ const Texture* texture_type) {
+ auto* storage_texture_type = As<StorageTexture>(texture_type);
auto* src_type = texel.type;
if (!storage_texture_type) {
Fail() << "writing to other than storage texture: " << inst.PrettyPrint();
return nullptr;
}
- const auto format = storage_texture_type->image_format();
+ const auto format = storage_texture_type->format;
auto* dest_type = parser_impl_.GetTexelTypeForFormat(format);
if (!dest_type) {
Fail();
return nullptr;
}
- if (AstTypesEquivalent(src_type, dest_type)) {
+ if (src_type == dest_type) {
return texel.expr;
}
const uint32_t dest_count =
- dest_type->is_scalar() ? 1 : dest_type->As<ast::Vector>()->size();
+ dest_type->IsScalar() ? 1 : dest_type->As<Vector>()->size;
if (dest_count == 3) {
Fail() << "3-channel storage textures are not supported: "
<< inst.PrettyPrint();
return nullptr;
}
const uint32_t src_count =
- src_type->is_scalar() ? 1 : src_type->As<ast::Vector>()->size();
+ src_type->IsScalar() ? 1 : src_type->As<Vector>()->size;
if (src_count < dest_count) {
Fail() << "texel has too few components for storage texture: " << src_count
<< " provided but " << dest_count
@@ -4961,29 +4949,29 @@
: create<ast::MemberAccessorExpression>(Source{}, texel.expr,
PrefixSwizzle(dest_count));
- if (!(dest_type->is_float_scalar_or_vector() ||
- dest_type->is_unsigned_scalar_or_vector() ||
- dest_type->is_signed_scalar_or_vector())) {
+ if (!(dest_type->IsFloatScalarOrVector() ||
+ dest_type->IsUnsignedScalarOrVector() ||
+ dest_type->IsSignedScalarOrVector())) {
Fail() << "invalid destination type for storage texture write: "
- << dest_type->type_name();
+ << dest_type->TypeInfo().name;
return nullptr;
}
- if (!(src_type->is_float_scalar_or_vector() ||
- src_type->is_unsigned_scalar_or_vector() ||
- src_type->is_signed_scalar_or_vector())) {
+ if (!(src_type->IsFloatScalarOrVector() ||
+ src_type->IsUnsignedScalarOrVector() ||
+ src_type->IsSignedScalarOrVector())) {
Fail() << "invalid texel type for storage texture write: "
<< inst.PrettyPrint();
return nullptr;
}
- if (dest_type->is_float_scalar_or_vector() &&
- !src_type->is_float_scalar_or_vector()) {
+ if (dest_type->IsFloatScalarOrVector() &&
+ !src_type->IsFloatScalarOrVector()) {
Fail() << "can only write float or float vector to a storage image with "
"floating texel format: "
<< inst.PrettyPrint();
return nullptr;
}
- if (!dest_type->is_float_scalar_or_vector() &&
- src_type->is_float_scalar_or_vector()) {
+ if (!dest_type->IsFloatScalarOrVector() &&
+ src_type->IsFloatScalarOrVector()) {
Fail()
<< "float or float vector can only be written to a storage image with "
"floating texel format: "
@@ -4991,36 +4979,37 @@
return nullptr;
}
- if (dest_type->is_float_scalar_or_vector()) {
+ if (dest_type->IsFloatScalarOrVector()) {
return texel_prefix;
}
// The only remaining cases are signed/unsigned source, and signed/unsigned
// destination.
- if (dest_type->is_unsigned_scalar_or_vector() ==
- src_type->is_unsigned_scalar_or_vector()) {
+ if (dest_type->IsUnsignedScalarOrVector() ==
+ src_type->IsUnsignedScalarOrVector()) {
return texel_prefix;
}
// We must do a bitcast conversion.
- return create<ast::BitcastExpression>(Source{}, dest_type, texel_prefix);
+ return create<ast::BitcastExpression>(Source{}, dest_type->Build(builder_),
+ texel_prefix);
}
TypedExpression FunctionEmitter::ToI32(TypedExpression value) {
- if (!value.type || value.type->Is<ast::I32>()) {
+ if (!value.type || value.type->Is<I32>()) {
return value;
}
- return {builder_.ty.i32(),
+ return {ty_.I32(),
create<ast::TypeConstructorExpression>(
Source{}, builder_.ty.i32(), ast::ExpressionList{value.expr})};
}
TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) {
- if (!value.type || !value.type->is_unsigned_scalar_or_vector()) {
+ if (!value.type || !value.type->IsUnsignedScalarOrVector()) {
return value;
}
- if (auto* vec_type = value.type->As<ast::Vector>()) {
- auto new_type = builder_.ty.vec(builder_.ty.i32(), vec_type->size());
- return {new_type,
- builder_.Construct(new_type, ast::ExpressionList{value.expr})};
+ if (auto* vec_type = value.type->As<Vector>()) {
+ auto* new_type = ty_.Vector(ty_.I32(), vec_type->size);
+ return {new_type, builder_.Construct(new_type->Build(builder_),
+ ast::ExpressionList{value.expr})};
}
return ToI32(value);
}
@@ -5073,14 +5062,12 @@
// Synthesize the result.
auto col = MakeOperand(inst, 0);
auto row = MakeOperand(inst, 1);
- auto* col_ty = col.type->As<ast::Vector>();
- auto* row_ty = row.type->As<ast::Vector>();
- auto* result_ty = parser_impl_.ConvertType(inst.type_id())->As<ast::Matrix>();
- if (!col_ty || !col_ty || !result_ty ||
- !AstTypesEquivalent(result_ty->type(), col_ty->type()) ||
- !AstTypesEquivalent(result_ty->type(), row_ty->type()) ||
- result_ty->columns() != row_ty->size() ||
- result_ty->rows() != col_ty->size()) {
+ auto* col_ty = As<Vector>(col.type);
+ auto* row_ty = As<Vector>(row.type);
+ auto* result_ty = As<Matrix>(parser_impl_.ConvertType(inst.type_id()));
+ if (!col_ty || !col_ty || !result_ty || result_ty->type != col_ty->type ||
+ result_ty->type != row_ty->type || result_ty->columns != row_ty->size ||
+ result_ty->rows != col_ty->size) {
Fail() << "invalid outer product instruction: bad types "
<< inst.PrettyPrint();
return {};
@@ -5096,11 +5083,11 @@
// | c.z * r.x c.z * r.y |
ast::ExpressionList result_columns;
- for (uint32_t icol = 0; icol < result_ty->columns(); icol++) {
+ for (uint32_t icol = 0; icol < result_ty->columns; icol++) {
ast::ExpressionList result_row;
auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr,
Swizzle(icol));
- for (uint32_t irow = 0; irow < result_ty->rows(); irow++) {
+ for (uint32_t irow = 0; irow < result_ty->rows; irow++) {
auto* column_factor = create<ast::MemberAccessorExpression>(
Source{}, col.expr, Swizzle(irow));
auto* elem = create<ast::BinaryExpression>(
@@ -5108,11 +5095,10 @@
result_row.push_back(elem);
}
result_columns.push_back(create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(col_ty), result_row));
+ Source{}, col_ty->Build(builder_), result_row));
}
return {result_ty, create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(result_ty),
- result_columns)};
+ Source{}, result_ty->Build(builder_), result_columns)};
}
bool FunctionEmitter::MakeVectorInsertDynamic(
@@ -5142,8 +5128,7 @@
auto* temp_var = create<ast::Variable>(
Source{}, registered_temp_name, ast::StorageClass::kFunction,
- builder_.ty.MaybeCreateTypename(ast_type), false, src_vector.expr,
- ast::DecorationList{});
+ ast_type->Build(builder_), false, src_vector.expr, ast::DecorationList{});
AddStatement(create<ast::VariableDeclStatement>(Source{}, temp_var));
auto* lhs = create<ast::ArrayAccessorExpression>(
@@ -5189,7 +5174,7 @@
auto* temp_var = create<ast::Variable>(
Source{}, registered_temp_name, ast::StorageClass::kFunction,
- builder_.ty.MaybeCreateTypename(ast_type), false, src_composite.expr,
+ ast_type->Build(builder_), false, src_composite.expr,
ast::DecorationList{});
AddStatement(create<ast::VariableDeclStatement>(Source{}, temp_var));
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 1f41bd1..64c2a93 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -515,7 +515,7 @@
/// @param type the AST type
/// @param result_id the SPIR-V ID for the locally defined value
/// @returns an possibly updated type
- ast::Type* RemapStorageClass(ast::Type* type, uint32_t result_id);
+ const Type* RemapStorageClass(const Type* type, uint32_t result_id);
/// Marks locally defined values when they should get a 'const'
/// definition in WGSL, or a 'var' definition at an outer scope.
@@ -856,7 +856,7 @@
/// Function parameters
ast::VariableList params;
/// Function return type
- ast::Type* return_type;
+ const Type* return_type;
/// Function decorations
ast::DecorationList decorations;
};
@@ -869,7 +869,7 @@
/// @returns the store type for the OpVariable instruction, or
/// null on failure.
- ast::Type* GetVariableStoreType(
+ const Type* GetVariableStoreType(
const spvtools::opt::Instruction& var_decl_inst);
/// Returns an expression for an instruction operand. Signedness conversion is
@@ -937,7 +937,7 @@
/// Get the AST texture the SPIR-V image memory object declaration.
/// @param inst the SPIR-V memory object declaration for the image.
/// @returns a texture type, or null on error
- ast::Texture* GetImageType(const spvtools::opt::Instruction& inst);
+ const Texture* GetImageType(const spvtools::opt::Instruction& inst);
/// Get the expression for the image operand from the first operand to the
/// given instruction.
@@ -974,7 +974,7 @@
ast::Expression* ConvertTexelForStorage(
const spvtools::opt::Instruction& inst,
TypedExpression texel,
- ast::Texture* texture_type);
+ const Texture* texture_type);
/// Returns an expression for an OpSelect, if its operands are scalars
/// or vectors. These translate directly to WGSL select. Otherwise, return
@@ -1128,6 +1128,7 @@
using StatementsStack = std::vector<StatementBlock>;
ParserImpl& parser_impl_;
+ TypeManager& ty_;
ProgramBuilder& builder_;
spvtools::opt::IRContext& ir_context_;
spvtools::opt::analysis::DefUseManager* def_use_mgr_;
diff --git a/src/reader/spirv/function_conversion_test.cc b/src/reader/spirv/function_conversion_test.cc
index 0d18996..8dc4999 100644
--- a/src/reader/spirv/function_conversion_test.cc
+++ b/src/reader/spirv/function_conversion_test.cc
@@ -202,7 +202,7 @@
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(),
HasSubstr("operand for conversion to floating point must be "
- "integral scalar or vector, but got: __bool"));
+ "integral scalar or vector"));
}
TEST_F(SpvUnaryConversionTest, ConvertSToF_Vector_BadArgType) {
@@ -220,7 +220,7 @@
EXPECT_THAT(
p->error(),
HasSubstr("operand for conversion to floating point must be integral "
- "scalar or vector, but got: __vec_2__bool"));
+ "scalar or vector"));
}
TEST_F(SpvUnaryConversionTest, ConvertSToF_Scalar_FromSigned) {
@@ -344,7 +344,7 @@
auto fe = p->function_emitter(100);
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(), Eq("operand for conversion to floating point must be "
- "integral scalar or vector, but got: __bool"));
+ "integral scalar or vector"));
}
TEST_F(SpvUnaryConversionTest, ConvertUToF_Vector_BadArgType) {
@@ -361,7 +361,7 @@
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(),
Eq("operand for conversion to floating point must be integral "
- "scalar or vector, but got: __vec_2__bool"));
+ "scalar or vector"));
}
TEST_F(SpvUnaryConversionTest, ConvertUToF_Scalar_FromSigned) {
@@ -486,7 +486,7 @@
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(),
Eq("operand for conversion to signed integer must be floating "
- "point scalar or vector, but got: __u32"));
+ "point scalar or vector"));
}
TEST_F(SpvUnaryConversionTest, ConvertFToS_Vector_BadArgType) {
@@ -503,7 +503,7 @@
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(),
Eq("operand for conversion to signed integer must be floating "
- "point scalar or vector, but got: __vec_2__bool"));
+ "point scalar or vector"));
}
TEST_F(SpvUnaryConversionTest, ConvertFToS_Scalar_ToSigned) {
@@ -628,7 +628,7 @@
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(),
Eq("operand for conversion to unsigned integer must be floating "
- "point scalar or vector, but got: __u32"));
+ "point scalar or vector"));
}
TEST_F(SpvUnaryConversionTest, ConvertFToU_Vector_BadArgType) {
@@ -645,7 +645,7 @@
EXPECT_FALSE(fe.EmitBody());
EXPECT_THAT(p->error(),
Eq("operand for conversion to unsigned integer must be floating "
- "point scalar or vector, but got: __vec_2__bool"));
+ "point scalar or vector"));
}
TEST_F(SpvUnaryConversionTest, ConvertFToU_Scalar_ToSigned_IsError) {
diff --git a/src/reader/spirv/function_memory_test.cc b/src/reader/spirv/function_memory_test.cc
index 7121502..696e09f 100644
--- a/src/reader/spirv/function_memory_test.cc
+++ b/src/reader/spirv/function_memory_test.cc
@@ -836,7 +836,7 @@
Struct S {
[[block]]
StructMember{[[ offset 0 ]] field0: __u32}
- StructMember{[[ offset 4 ]] field1: __alias_RTArr__array__u32_stride_4}
+ StructMember{[[ offset 4 ]] field1: __type_name_RTArr}
}
Variable{
Decorations{
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index e1f6484..368b257 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -240,7 +240,7 @@
TypedExpression& TypedExpression::operator=(const TypedExpression&) = default;
-TypedExpression::TypedExpression(ast::Type* type_in, ast::Expression* expr_in)
+TypedExpression::TypedExpression(const Type* type_in, ast::Expression* expr_in)
: type(type_in), expr(expr_in) {}
ParserImpl::ParserImpl(const std::vector<uint32_t>& spv_binary)
@@ -305,7 +305,7 @@
return tint::Program(std::move(builder_));
}
-ast::Type* ParserImpl::ConvertType(uint32_t type_id) {
+const Type* ParserImpl::ConvertType(uint32_t type_id) {
if (!success_) {
return nullptr;
}
@@ -322,18 +322,18 @@
}
auto maybe_generate_alias = [this, type_id,
- spirv_type](ast::Type* type) -> ast::Type* {
+ spirv_type](const Type* type) -> const Type* {
if (type != nullptr) {
return MaybeGenerateAlias(type_id, spirv_type, type);
}
- return {};
+ return type;
};
switch (spirv_type->kind()) {
case spvtools::opt::analysis::Type::kVoid:
- return maybe_generate_alias(builder_.ty.void_());
+ return maybe_generate_alias(ty_.Void());
case spvtools::opt::analysis::Type::kBool:
- return maybe_generate_alias(builder_.ty.bool_());
+ return maybe_generate_alias(ty_.Bool());
case spvtools::opt::analysis::Type::kInteger:
return maybe_generate_alias(ConvertType(spirv_type->AsInteger()));
case spvtools::opt::analysis::Type::kFloat:
@@ -362,7 +362,7 @@
case spvtools::opt::analysis::Type::kImage:
// Fake it for sampler and texture types. These are handled in an
// entirely different way.
- return maybe_generate_alias(builder_.ty.void_());
+ return maybe_generate_alias(ty_.Void());
default:
break;
}
@@ -774,36 +774,36 @@
return success_;
}
-ast::Type* ParserImpl::ConvertType(
+const Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Integer* int_ty) {
if (int_ty->width() == 32) {
- return int_ty->IsSigned() ? static_cast<ast::Type*>(builder_.ty.i32())
- : static_cast<ast::Type*>(builder_.ty.u32());
+ return int_ty->IsSigned() ? static_cast<const Type*>(ty_.I32())
+ : static_cast<const Type*>(ty_.U32());
}
Fail() << "unhandled integer width: " << int_ty->width();
return nullptr;
}
-ast::Type* ParserImpl::ConvertType(
+const Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Float* float_ty) {
if (float_ty->width() == 32) {
- return builder_.ty.f32();
+ return ty_.F32();
}
Fail() << "unhandled float width: " << float_ty->width();
return nullptr;
}
-ast::Type* ParserImpl::ConvertType(
+const Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Vector* vec_ty) {
const auto num_elem = vec_ty->element_count();
auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type()));
if (ast_elem_ty == nullptr) {
- return nullptr;
+ return ast_elem_ty;
}
- return builder_.ty.vec(ast_elem_ty, num_elem);
+ return ty_.Vector(ast_elem_ty, num_elem);
}
-ast::Type* ParserImpl::ConvertType(
+const Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Matrix* mat_ty) {
const auto* vec_ty = mat_ty->element_type()->AsVector();
const auto* scalar_ty = vec_ty->element_type();
@@ -813,23 +813,23 @@
if (ast_scalar_ty == nullptr) {
return nullptr;
}
- return builder_.ty.mat(ast_scalar_ty, num_columns, num_rows);
+ return ty_.Matrix(ast_scalar_ty, num_columns, num_rows);
}
-ast::Type* ParserImpl::ConvertType(
+const Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::RuntimeArray* rtarr_ty) {
auto* ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type()));
if (ast_elem_ty == nullptr) {
return nullptr;
}
- ast::DecorationList decorations;
- if (!ParseArrayDecorations(rtarr_ty, &decorations)) {
+ uint32_t array_stride = 0;
+ if (!ParseArrayDecorations(rtarr_ty, &array_stride)) {
return nullptr;
}
- return builder_.ty.array(ast_elem_ty, 0, std::move(decorations));
+ return ty_.Array(ast_elem_ty, 0, array_stride);
}
-ast::Type* ParserImpl::ConvertType(
+const Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Array* arr_ty) {
const auto elem_type_id = type_mgr_->GetId(arr_ty->element_type());
auto* ast_elem_ty = ConvertType(elem_type_id);
@@ -863,21 +863,19 @@
<< num_elem;
return nullptr;
}
- ast::DecorationList decorations;
- if (!ParseArrayDecorations(arr_ty, &decorations)) {
+ uint32_t array_stride = 0;
+ if (!ParseArrayDecorations(arr_ty, &array_stride)) {
return nullptr;
}
-
if (remap_buffer_block_type_.count(elem_type_id)) {
remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty));
}
- return builder_.ty.array(ast_elem_ty, static_cast<uint32_t>(num_elem),
- std::move(decorations));
+ return ty_.Array(ast_elem_ty, static_cast<uint32_t>(num_elem), array_stride);
}
bool ParserImpl::ParseArrayDecorations(
const spvtools::opt::analysis::Type* spv_type,
- ast::DecorationList* decorations) {
+ uint32_t* array_stride) {
bool has_array_stride = false;
const auto type_id = type_mgr_->GetId(spv_type);
for (auto& decoration : this->GetDecorationsFor(type_id)) {
@@ -892,7 +890,7 @@
<< ": multiple ArrayStride decorations";
}
has_array_stride = true;
- decorations->push_back(create<ast::StrideDecoration>(Source{}, stride));
+ *array_stride = stride;
} else {
return Fail() << "invalid array type ID " << type_id
<< ": unknown decoration "
@@ -904,7 +902,7 @@
return true;
}
-ast::Type* ParserImpl::ConvertType(
+const Type* ParserImpl::ConvertType(
uint32_t type_id,
const spvtools::opt::analysis::Struct* struct_ty) {
// Compute the struct decoration.
@@ -930,6 +928,7 @@
// Compute members
ast::StructMemberList ast_members;
const auto members = struct_ty->element_types();
+ TypeList ast_member_types;
unsigned num_non_writable_members = 0;
for (uint32_t member_index = 0; member_index < members.size();
++member_index) {
@@ -940,6 +939,8 @@
return nullptr;
}
+ ast_member_types.emplace_back(ast_member_ty);
+
// Scan member for built-in decorations. Some vertex built-ins are handled
// specially, and should not generate a structure member.
bool create_ast_member = true;
@@ -1003,8 +1004,8 @@
}
const auto member_name = namer_.GetMemberName(type_id, member_index);
auto* ast_struct_member = create<ast::StructMember>(
- Source{}, builder_.Symbols().Register(member_name), ast_member_ty,
- std::move(ast_member_decorations));
+ Source{}, builder_.Symbols().Register(member_name),
+ ast_member_ty->Build(builder_), std::move(ast_member_decorations));
ast_members.push_back(ast_struct_member);
}
@@ -1030,7 +1031,7 @@
read_only_struct_types_.insert(ast_struct->name());
}
AddConstructedType(sym, ast_struct);
- return ast_struct;
+ return ty_.Struct(sym, std::move(ast_member_types));
}
void ParserImpl::AddConstructedType(Symbol name, ast::NamedType* type) {
@@ -1040,8 +1041,8 @@
}
}
-ast::Type* ParserImpl::ConvertType(uint32_t type_id,
- const spvtools::opt::analysis::Pointer*) {
+const Type* ParserImpl::ConvertType(uint32_t type_id,
+ const spvtools::opt::analysis::Pointer*) {
const auto* inst = def_use_mgr_->GetDef(type_id);
const auto pointee_type_id = inst->GetSingleWordInOperand(1);
const auto storage_class = SpvStorageClass(inst->GetSingleWordInOperand(0));
@@ -1080,8 +1081,7 @@
}
}
- ast_elem_ty = builder_.ty.MaybeCreateTypename(ast_elem_ty);
- return builder_.ty.pointer(ast_elem_ty, ast_storage_class);
+ return ty_.Pointer(ast_elem_ty, ast_storage_class);
}
bool ParserImpl::RegisterTypes() {
@@ -1114,7 +1114,7 @@
// that is OpSpecConstantTrue, OpSpecConstantFalse, or OpSpecConstant.
for (auto& inst : module_->types_values()) {
// These will be populated for a valid scalar spec constant.
- ast::Type* ast_type = nullptr;
+ const Type* ast_type = nullptr;
ast::ScalarConstructorExpression* ast_expr = nullptr;
switch (inst.opcode()) {
@@ -1129,15 +1129,15 @@
case SpvOpSpecConstant: {
ast_type = ConvertType(inst.type_id());
const uint32_t literal_value = inst.GetSingleWordInOperand(0);
- if (ast_type->Is<ast::I32>()) {
+ if (ast_type->Is<I32>()) {
ast_expr = create<ast::ScalarConstructorExpression>(
Source{}, create<ast::SintLiteral>(
Source{}, static_cast<int32_t>(literal_value)));
- } else if (ast_type->Is<ast::U32>()) {
+ } else if (ast_type->Is<U32>()) {
ast_expr = create<ast::ScalarConstructorExpression>(
Source{}, create<ast::UintLiteral>(
Source{}, static_cast<uint32_t>(literal_value)));
- } else if (ast_type->Is<ast::F32>()) {
+ } else if (ast_type->Is<F32>()) {
float float_value;
// Copy the bits so we can read them as a float.
std::memcpy(&float_value, &literal_value, sizeof(float_value));
@@ -1173,12 +1173,12 @@
return success_;
}
-ast::Type* ParserImpl::MaybeGenerateAlias(
+const Type* ParserImpl::MaybeGenerateAlias(
uint32_t type_id,
const spvtools::opt::analysis::Type* type,
- ast::Type* ast_type) {
+ const Type* ast_type) {
if (!success_) {
- return {};
+ return nullptr;
}
// We only care about arrays, and runtime arrays.
@@ -1202,16 +1202,17 @@
auto* ast_underlying_type = ast_type;
if (ast_underlying_type == nullptr) {
Fail() << "internal error: no type registered for SPIR-V ID: " << type_id;
- return {};
+ return nullptr;
}
const auto name = namer_.GetName(type_id);
const auto sym = builder_.Symbols().Register(name);
- auto ast_alias_type = builder_.ty.alias(sym, ast_underlying_type);
+ auto ast_alias_type =
+ builder_.ty.alias(sym, ast_underlying_type->Build(builder_));
// Record this new alias as the AST type for this SPIR-V ID.
AddConstructedType(sym, ast_alias_type);
- return ast_alias_type;
+ return ty_.Alias(sym, ast_underlying_type);
}
bool ParserImpl::EmitModuleScopeVariables() {
@@ -1252,7 +1253,7 @@
if (!success_) {
return false;
}
- ast::Type* ast_type;
+ const Type* ast_type = nullptr;
if (spirv_storage_class == SpvStorageClassUniformConstant) {
// These are opaque handles: samplers or textures
ast_type = GetTypeForHandleVar(var);
@@ -1266,14 +1267,14 @@
"SPIR-V type with ID: "
<< var.type_id();
}
- if (!ast_type->Is<ast::Pointer>()) {
+ if (!ast_type->Is<Pointer>()) {
return Fail() << "variable with ID " << var.result_id()
<< " has non-pointer type " << var.type_id();
}
}
- auto* ast_store_type = ast_type->As<ast::Pointer>()->type();
- auto ast_storage_class = ast_type->As<ast::Pointer>()->storage_class();
+ auto* ast_store_type = ast_type->As<Pointer>()->type;
+ auto ast_storage_class = ast_type->As<Pointer>()->storage_class;
ast::Expression* ast_constructor = nullptr;
if (var.NumInOperands() > 1) {
// SPIR-V initializers are always constants.
@@ -1336,7 +1337,7 @@
ast::Variable* ParserImpl::MakeVariable(uint32_t id,
ast::StorageClass sc,
- ast::Type* type,
+ const Type* type,
bool is_const,
ast::Expression* constructor,
ast::DecorationList decorations) {
@@ -1347,14 +1348,14 @@
if (sc == ast::StorageClass::kStorage) {
bool read_only = false;
- if (auto* tn = type->As<ast::TypeName>()) {
- read_only = read_only_struct_types_.count(tn->name()) > 0;
+ if (auto* tn = type->As<Named>()) {
+ read_only = read_only_struct_types_.count(tn->name) > 0;
}
// Apply the access(read) or access(read_write) modifier.
auto access = read_only ? ast::AccessControl::kReadOnly
: ast::AccessControl::kReadWrite;
- type = builder_.ty.access(access, type);
+ type = ty_.AccessControl(type, access);
}
for (auto& deco : GetDecorationsFor(id)) {
@@ -1396,7 +1397,7 @@
"SampleMask must be an array of 1 element.";
}
special_builtins_[id] = spv_builtin;
- type = builder_.ty.u32();
+ type = ty_.U32();
break;
}
default:
@@ -1439,8 +1440,8 @@
std::string name = namer_.Name(id);
return create<ast::Variable>(Source{}, builder_.Symbols().Register(name), sc,
- builder_.ty.MaybeCreateTypename(type), is_const,
- constructor, decorations);
+ type->Build(builder_), is_const, constructor,
+ decorations);
}
TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
@@ -1476,27 +1477,27 @@
// So canonicalization should map that way too.
// Currently "null<type>" is missing from the WGSL parser.
// See https://bugs.chromium.org/p/tint/issues/detail?id=34
- if (ast_type->Is<ast::U32>()) {
- return {ast_type, create<ast::ScalarConstructorExpression>(
- Source{}, create<ast::UintLiteral>(
- source, spirv_const->GetU32()))};
+ if (ast_type->Is<U32>()) {
+ return {ty_.U32(), create<ast::ScalarConstructorExpression>(
+ Source{}, create<ast::UintLiteral>(
+ source, spirv_const->GetU32()))};
}
- if (ast_type->Is<ast::I32>()) {
- return {ast_type, create<ast::ScalarConstructorExpression>(
- Source{}, create<ast::SintLiteral>(
- source, spirv_const->GetS32()))};
+ if (ast_type->Is<I32>()) {
+ return {ty_.I32(), create<ast::ScalarConstructorExpression>(
+ Source{}, create<ast::SintLiteral>(
+ source, spirv_const->GetS32()))};
}
- if (ast_type->Is<ast::F32>()) {
- return {ast_type, create<ast::ScalarConstructorExpression>(
- Source{}, create<ast::FloatLiteral>(
- source, spirv_const->GetFloat()))};
+ if (ast_type->Is<F32>()) {
+ return {ty_.F32(), create<ast::ScalarConstructorExpression>(
+ Source{}, create<ast::FloatLiteral>(
+ source, spirv_const->GetFloat()))};
}
- if (ast_type->Is<ast::Bool>()) {
+ if (ast_type->Is<Bool>()) {
const bool value = spirv_const->AsNullConstant()
? false
: spirv_const->AsBoolConstant()->value();
- return {ast_type, create<ast::ScalarConstructorExpression>(
- Source{}, create<ast::BoolLiteral>(source, value))};
+ return {ty_.Bool(), create<ast::ScalarConstructorExpression>(
+ Source{}, create<ast::BoolLiteral>(source, value))};
}
auto* spirv_composite_const = spirv_const->AsCompositeConstant();
if (spirv_composite_const != nullptr) {
@@ -1521,10 +1522,9 @@
}
ast_components.emplace_back(ast_component.expr);
}
- return {original_ast_type,
- create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(original_ast_type),
- std::move(ast_components))};
+ return {original_ast_type, create<ast::TypeConstructorExpression>(
+ Source{}, original_ast_type->Build(builder_),
+ std::move(ast_components))};
}
auto* spirv_null_const = spirv_const->AsNullConstant();
if (spirv_null_const != nullptr) {
@@ -1535,7 +1535,7 @@
return {};
}
-ast::Expression* ParserImpl::MakeNullValue(ast::Type* type) {
+ast::Expression* ParserImpl::MakeNullValue(const Type* type) {
// TODO(dneto): Use the no-operands constructor syntax when it becomes
// available in Tint.
// https://github.com/gpuweb/gpuweb/issues/685
@@ -1549,93 +1549,89 @@
auto* original_type = type;
type = type->UnwrapIfNeeded();
- if (type->Is<ast::Bool>()) {
+ if (type->Is<Bool>()) {
return create<ast::ScalarConstructorExpression>(
Source{}, create<ast::BoolLiteral>(Source{}, false));
}
- if (type->Is<ast::U32>()) {
+ if (type->Is<U32>()) {
return create<ast::ScalarConstructorExpression>(
Source{}, create<ast::UintLiteral>(Source{}, 0u));
}
- if (type->Is<ast::I32>()) {
+ if (type->Is<I32>()) {
return create<ast::ScalarConstructorExpression>(
Source{}, create<ast::SintLiteral>(Source{}, 0));
}
- if (type->Is<ast::F32>()) {
+ if (type->Is<F32>()) {
return create<ast::ScalarConstructorExpression>(
Source{}, create<ast::FloatLiteral>(Source{}, 0.0f));
}
- if (type->Is<ast::TypeName>()) {
+ if (type->Is<Alias>()) {
// TODO(amaiorano): No type constructor for TypeName (yet?)
ast::ExpressionList ast_components;
- return create<ast::TypeConstructorExpression>(Source{}, original_type,
- std::move(ast_components));
+ return create<ast::TypeConstructorExpression>(
+ Source{}, original_type->Build(builder_), std::move(ast_components));
}
- if (auto* vec_ty = type->As<ast::Vector>()) {
+ if (auto* vec_ty = type->As<Vector>()) {
ast::ExpressionList ast_components;
- for (size_t i = 0; i < vec_ty->size(); ++i) {
- ast_components.emplace_back(MakeNullValue(vec_ty->type()));
+ for (size_t i = 0; i < vec_ty->size; ++i) {
+ ast_components.emplace_back(MakeNullValue(vec_ty->type));
}
return create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(type),
- std::move(ast_components));
+ Source{}, type->Build(builder_), std::move(ast_components));
}
- if (auto* mat_ty = type->As<ast::Matrix>()) {
+ if (auto* mat_ty = type->As<Matrix>()) {
// Matrix components are columns
- auto column_ty = builder_.ty.vec(mat_ty->type(), mat_ty->rows());
+ auto* column_ty = ty_.Vector(mat_ty->type, mat_ty->rows);
ast::ExpressionList ast_components;
- for (size_t i = 0; i < mat_ty->columns(); ++i) {
+ for (size_t i = 0; i < mat_ty->columns; ++i) {
ast_components.emplace_back(MakeNullValue(column_ty));
}
return create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(type),
- std::move(ast_components));
+ Source{}, type->Build(builder_), std::move(ast_components));
}
- if (auto* arr_ty = type->As<ast::Array>()) {
+ if (auto* arr_ty = type->As<Array>()) {
ast::ExpressionList ast_components;
- for (size_t i = 0; i < arr_ty->size(); ++i) {
- ast_components.emplace_back(MakeNullValue(arr_ty->type()));
+ for (size_t i = 0; i < arr_ty->size; ++i) {
+ ast_components.emplace_back(MakeNullValue(arr_ty->type));
}
return create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(original_type),
- std::move(ast_components));
+ Source{}, original_type->Build(builder_), std::move(ast_components));
}
- if (auto* struct_ty = type->As<ast::Struct>()) {
+ if (auto* struct_ty = type->As<Struct>()) {
ast::ExpressionList ast_components;
- for (auto* member : struct_ty->members()) {
- ast_components.emplace_back(MakeNullValue(member->type()));
+ for (auto* member : struct_ty->members) {
+ ast_components.emplace_back(MakeNullValue(member));
}
return create<ast::TypeConstructorExpression>(
- Source{}, builder_.ty.MaybeCreateTypename(original_type),
- std::move(ast_components));
+ Source{}, original_type->Build(builder_), std::move(ast_components));
}
- Fail() << "can't make null value for type: " << type->type_name();
+ Fail() << "can't make null value for type: " << type->TypeInfo().name;
return nullptr;
}
-TypedExpression ParserImpl::MakeNullExpression(ast::Type* type) {
+TypedExpression ParserImpl::MakeNullExpression(const Type* type) {
return {type, MakeNullValue(type)};
}
-ast::Type* ParserImpl::UnsignedTypeFor(ast::Type* type) {
- if (type->Is<ast::I32>()) {
- return builder_.ty.u32();
+const Type* ParserImpl::UnsignedTypeFor(const Type* type) {
+ if (type->Is<I32>()) {
+ return ty_.U32();
}
- if (auto* v = type->As<ast::Vector>()) {
- if (v->type()->Is<ast::I32>()) {
- return builder_.ty.vec(builder_.ty.u32(), v->size());
+ if (auto* v = type->As<Vector>()) {
+ if (v->type->Is<I32>()) {
+ return ty_.Vector(ty_.U32(), v->size);
}
}
return {};
}
-ast::Type* ParserImpl::SignedTypeFor(ast::Type* type) {
- if (type->Is<ast::U32>()) {
- return builder_.ty.i32();
+const Type* ParserImpl::SignedTypeFor(const Type* type) {
+ if (type->Is<U32>()) {
+ return ty_.I32();
}
- if (auto* v = type->As<ast::Vector>()) {
- if (v->type()->Is<ast::U32>()) {
- return builder_.ty.vec(builder_.ty.i32(), v->size());
+ if (auto* v = type->As<Vector>()) {
+ if (v->type->Is<U32>()) {
+ return ty_.Vector(ty_.I32(), v->size);
}
}
return {};
@@ -1674,13 +1670,14 @@
if (auto* unsigned_ty = UnsignedTypeFor(type)) {
// Conversion is required.
return {unsigned_ty,
- create<ast::BitcastExpression>(Source{}, unsigned_ty, expr.expr)};
+ create<ast::BitcastExpression>(
+ Source{}, unsigned_ty->Build(builder_), expr.expr)};
}
} else if (requires_signed) {
if (auto* signed_ty = SignedTypeFor(type)) {
// Conversion is required.
- return {signed_ty,
- create<ast::BitcastExpression>(Source{}, signed_ty, expr.expr)};
+ return {signed_ty, create<ast::BitcastExpression>(
+ Source{}, signed_ty->Build(builder_), expr.expr)};
}
}
// We should not reach here.
@@ -1689,21 +1686,22 @@
TypedExpression ParserImpl::RectifySecondOperandSignedness(
const spvtools::opt::Instruction& inst,
- ast::Type* first_operand_type,
+ const Type* first_operand_type,
TypedExpression&& second_operand_expr) {
- if (!AstTypesEquivalent(first_operand_type, second_operand_expr.type) &&
+ if ((first_operand_type != second_operand_expr.type) &&
AssumesSecondOperandSignednessMatchesFirstOperand(inst.opcode())) {
// Conversion is required.
return {first_operand_type,
- create<ast::BitcastExpression>(Source{}, first_operand_type,
+ create<ast::BitcastExpression>(Source{},
+ first_operand_type->Build(builder_),
second_operand_expr.expr)};
}
// No conversion necessary.
return std::move(second_operand_expr);
}
-ast::Type* ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst,
- ast::Type* first_operand_type) {
+const Type* ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst,
+ const Type* first_operand_type) {
const auto opcode = inst.opcode();
if (AssumesResultSignednessMatchesFirstOperand(opcode)) {
return first_operand_type;
@@ -1718,66 +1716,63 @@
return nullptr;
}
-ast::Type* ParserImpl::GetSignedIntMatchingShape(ast::Type* other) {
+const Type* ParserImpl::GetSignedIntMatchingShape(const Type* other) {
if (other == nullptr) {
Fail() << "no type provided";
}
- auto i32 = builder_.ty.i32();
- if (other->Is<ast::F32>() || other->Is<ast::U32>() || other->Is<ast::I32>()) {
- return i32;
+ if (other->Is<F32>() || other->Is<U32>() || other->Is<I32>()) {
+ return ty_.I32();
}
- auto* vec_ty = other->As<ast::Vector>();
- if (vec_ty) {
- return builder_.ty.vec(i32, vec_ty->size());
+ if (auto* vec_ty = other->As<Vector>()) {
+ return ty_.Vector(ty_.I32(), vec_ty->size);
}
- Fail() << "required numeric scalar or vector, but got " << other->type_name();
+ Fail() << "required numeric scalar or vector, but got "
+ << other->TypeInfo().name;
return nullptr;
}
-ast::Type* ParserImpl::GetUnsignedIntMatchingShape(ast::Type* other) {
+const Type* ParserImpl::GetUnsignedIntMatchingShape(const Type* other) {
if (other == nullptr) {
Fail() << "no type provided";
return nullptr;
}
- auto u32 = builder_.ty.u32();
- if (other->Is<ast::F32>() || other->Is<ast::U32>() || other->Is<ast::I32>()) {
- return u32;
+ if (other->Is<F32>() || other->Is<U32>() || other->Is<I32>()) {
+ return ty_.U32();
}
- auto* vec_ty = other->As<ast::Vector>();
- if (vec_ty) {
- return builder_.ty.vec(u32, vec_ty->size());
+ if (auto* vec_ty = other->As<Vector>()) {
+ return ty_.Vector(ty_.U32(), vec_ty->size);
}
- Fail() << "required numeric scalar or vector, but got " << other->type_name();
+ Fail() << "required numeric scalar or vector, but got "
+ << other->TypeInfo().name;
return nullptr;
}
TypedExpression ParserImpl::RectifyForcedResultType(
TypedExpression expr,
const spvtools::opt::Instruction& inst,
- ast::Type* first_operand_type) {
+ const Type* first_operand_type) {
auto* forced_result_ty = ForcedResultType(inst, first_operand_type);
- if ((forced_result_ty == nullptr) ||
- AstTypesEquivalent(forced_result_ty, expr.type)) {
+ if ((!forced_result_ty) || (forced_result_ty == expr.type)) {
return expr;
}
- return {expr.type,
- create<ast::BitcastExpression>(Source{}, expr.type, expr.expr)};
+ return {expr.type, create<ast::BitcastExpression>(
+ Source{}, expr.type->Build(builder_), expr.expr)};
}
TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) {
- if (expr.type && expr.type->is_signed_scalar_or_vector()) {
+ if (expr.type && expr.type->IsSignedScalarOrVector()) {
auto* new_type = GetUnsignedIntMatchingShape(expr.type);
- return {new_type,
- create<ast::BitcastExpression>(Source{}, new_type, expr.expr)};
+ return {new_type, create<ast::BitcastExpression>(
+ Source{}, new_type->Build(builder_), expr.expr)};
}
return expr;
}
TypedExpression ParserImpl::AsSigned(TypedExpression expr) {
- if (expr.type && expr.type->is_unsigned_scalar_or_vector()) {
+ if (expr.type && expr.type->IsUnsignedScalarOrVector()) {
auto* new_type = GetSignedIntMatchingShape(expr.type);
- return {new_type,
- create<ast::BitcastExpression>(Source{}, new_type, expr.expr)};
+ return {new_type, create<ast::BitcastExpression>(
+ Source{}, new_type->Build(builder_), expr.expr)};
}
return expr;
}
@@ -1952,7 +1947,7 @@
return raw_handle_type;
}
-ast::Pointer* ParserImpl::GetTypeForHandleVar(
+const Pointer* ParserImpl::GetTypeForHandleVar(
const spvtools::opt::Instruction& var) {
auto where = handle_type_.find(&var);
if (where != handle_type_.end()) {
@@ -2036,11 +2031,11 @@
}
// Construct the Tint handle type.
- ast::Type* ast_store_type;
+ const Type* ast_store_type = nullptr;
if (usage.IsSampler()) {
- ast_store_type = builder_.ty.sampler(
- usage.IsComparisonSampler() ? ast::SamplerKind::kComparisonSampler
- : ast::SamplerKind::kSampler);
+ ast_store_type = ty_.Sampler(usage.IsComparisonSampler()
+ ? ast::SamplerKind::kComparisonSampler
+ : ast::SamplerKind::kSampler);
} else if (usage.IsTexture()) {
const spvtools::opt::analysis::Image* image_type =
type_mgr_->GetType(raw_handle_type->result_id())->AsImage();
@@ -2069,14 +2064,13 @@
// OpImage variable with an OpImage*Dref* instruction. In WGSL we must
// treat that as a depth texture.
if (image_type->depth() || usage.IsDepthTexture()) {
- ast_store_type = builder_.ty.depth_texture(dim);
+ ast_store_type = ty_.DepthTexture(dim);
} else if (image_type->is_multisampled()) {
// Multisampled textures are never depth textures.
ast_store_type =
- builder_.ty.multisampled_texture(dim, ast_sampled_component_type);
+ ty_.MultisampledTexture(dim, ast_sampled_component_type);
} else {
- ast_store_type =
- builder_.ty.sampled_texture(dim, ast_sampled_component_type);
+ ast_store_type = ty_.SampledTexture(dim, ast_sampled_component_type);
}
} else {
const auto access = usage.IsStorageReadTexture()
@@ -2087,7 +2081,7 @@
return nullptr;
}
ast_store_type =
- builder_.ty.access(access, builder_.ty.storage_texture(dim, format));
+ ty_.AccessControl(ty_.StorageTexture(dim, format), access);
}
} else {
Fail() << "unsupported: UniformConstant variable is not a recognized "
@@ -2097,14 +2091,14 @@
}
// Form the pointer type.
- auto result =
- builder_.ty.pointer(ast_store_type, ast::StorageClass::kUniformConstant);
+ auto* result =
+ ty_.Pointer(ast_store_type, ast::StorageClass::kUniformConstant);
// Remember it for later.
handle_type_[&var] = result;
return result;
}
-ast::Type* ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) {
+const Type* ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) {
switch (format) {
case ast::ImageFormat::kR8Uint:
case ast::ImageFormat::kR16Uint:
@@ -2115,7 +2109,7 @@
case ast::ImageFormat::kRg32Uint:
case ast::ImageFormat::kRgba16Uint:
case ast::ImageFormat::kRgba32Uint:
- return builder_.ty.u32();
+ return ty_.U32();
case ast::ImageFormat::kR8Sint:
case ast::ImageFormat::kR16Sint:
@@ -2126,7 +2120,7 @@
case ast::ImageFormat::kRg32Sint:
case ast::ImageFormat::kRgba16Sint:
case ast::ImageFormat::kRgba32Sint:
- return builder_.ty.i32();
+ return ty_.I32();
case ast::ImageFormat::kR8Unorm:
case ast::ImageFormat::kRg8Unorm:
@@ -2145,7 +2139,7 @@
case ast::ImageFormat::kRg32Float:
case ast::ImageFormat::kRgba16Float:
case ast::ImageFormat::kRgba32Float:
- return builder_.ty.f32();
+ return ty_.F32();
default:
break;
}
@@ -2153,7 +2147,7 @@
return nullptr;
}
-ast::Type* ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) {
+const Type* ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) {
auto* component_type = GetComponentTypeForFormat(format);
if (!component_type) {
return nullptr;
@@ -2185,7 +2179,7 @@
case ast::ImageFormat::kRg8Uint:
case ast::ImageFormat::kRg8Unorm:
// Two channels
- return builder_.ty.vec(component_type, 2);
+ return ty_.Vector(component_type, 2);
case ast::ImageFormat::kBgra8Unorm:
case ast::ImageFormat::kBgra8UnormSrgb:
@@ -2202,7 +2196,7 @@
case ast::ImageFormat::kRgba8Unorm:
case ast::ImageFormat::kRgba8UnormSrgb:
// Four channels
- return builder_.ty.vec(component_type, 4);
+ return ty_.Vector(component_type, 4);
default:
break;
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 900165b..7253cdd 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -28,6 +28,7 @@
#include "src/reader/spirv/entry_point_info.h"
#include "src/reader/spirv/enum_converter.h"
#include "src/reader/spirv/namer.h"
+#include "src/reader/spirv/parser_type.h"
#include "src/reader/spirv/usage.h"
/// This is the implementation of the SPIR-V parser for Tint.
@@ -51,14 +52,6 @@
namespace reader {
namespace spirv {
-/// Returns true of the two input ast types are semantically equivalent
-/// @param lhs first type to compare
-/// @param rhs other type to compare
-/// @returns true if both types are semantically equivalent
-inline bool AstTypesEquivalent(ast::Type* lhs, ast::Type* rhs) {
- return lhs->type_name() == rhs->type_name();
-}
-
/// The binary representation of a SPIR-V decoration enum followed by its
/// operands, if any.
/// Example: { SpvDecorationBlock }
@@ -81,10 +74,10 @@
/// Constructor
/// @param type_in the type of the expression
/// @param expr_in the expression
- TypedExpression(ast::Type* type_in, ast::Expression* expr_in);
+ TypedExpression(const Type* type_in, ast::Expression* expr_in);
/// The type
- ast::Type* type;
+ Type const* type = nullptr;
/// The expression
ast::Expression* expr = nullptr;
};
@@ -110,6 +103,9 @@
/// program. To be used only for testing.
ProgramBuilder& builder() { return builder_; }
+ /// @returns the type manager
+ TypeManager& type_manager() { return ty_; }
+
/// Logs failure, ands return a failure stream to accumulate diagnostic
/// messages. By convention, a failure should only be logged along with
/// a non-empty string diagnostic.
@@ -163,7 +159,7 @@
/// after the internal representation of the module has been built.
/// @param type_id the SPIR-V ID of a type.
/// @returns a Tint type, or nullptr
- ast::Type* ConvertType(uint32_t type_id);
+ const Type* ConvertType(uint32_t type_id);
/// Emits an alias type declaration for the given type, if necessary, and
/// also updates the mapping of the SPIR-V type ID to the alias type.
@@ -176,9 +172,9 @@
/// @param type the type that might get an alias
/// @param ast_type the ast type that might get an alias
/// @returns an alias type or `ast_type` if no alias was created
- ast::Type* MaybeGenerateAlias(uint32_t type_id,
- const spvtools::opt::analysis::Type* type,
- ast::Type* ast_type);
+ const Type* MaybeGenerateAlias(uint32_t type_id,
+ const spvtools::opt::analysis::Type* type,
+ const Type* ast_type);
/// @returns the fail stream object
FailStream& fail_stream() { return fail_stream_; }
@@ -328,7 +324,7 @@
/// in the error case
ast::Variable* MakeVariable(uint32_t id,
ast::StorageClass sc,
- ast::Type* type,
+ const Type* type,
bool is_const,
ast::Expression* constructor,
ast::DecorationList decorations);
@@ -341,12 +337,12 @@
/// Creates an AST expression node for the null value for the given type.
/// @param type the AST type
/// @returns a new expression
- ast::Expression* MakeNullValue(ast::Type* type);
+ ast::Expression* MakeNullValue(const Type* type);
/// Make a typed expression for the null value for the given type.
/// @param type the AST type
/// @returns a new typed expression
- TypedExpression MakeNullExpression(ast::Type* type);
+ TypedExpression MakeNullExpression(const Type* type);
/// Converts a given expression to the signedness demanded for an operand
/// of the given SPIR-V instruction, if required. If the instruction assumes
@@ -371,7 +367,7 @@
/// @returns second_operand_expr, or a cast of it
TypedExpression RectifySecondOperandSignedness(
const spvtools::opt::Instruction& inst,
- ast::Type* first_operand_type,
+ const Type* first_operand_type,
TypedExpression&& second_operand_expr);
/// Returns the "forced" result type for the given SPIR-V instruction.
@@ -382,8 +378,8 @@
/// @param inst the SPIR-V instruction
/// @param first_operand_type the AST type for the first operand.
/// @returns the forced AST result type, or nullptr if no forcing is required.
- ast::Type* ForcedResultType(const spvtools::opt::Instruction& inst,
- ast::Type* first_operand_type);
+ const Type* ForcedResultType(const spvtools::opt::Instruction& inst,
+ const Type* first_operand_type);
/// Returns a signed integer scalar or vector type matching the shape (scalar,
/// vector, and component bit width) of another type, which itself is a
@@ -391,7 +387,7 @@
/// requirement.
/// @param other the type whose shape must be matched
/// @returns the signed scalar or vector type
- ast::Type* GetSignedIntMatchingShape(ast::Type* other);
+ const Type* GetSignedIntMatchingShape(const Type* other);
/// Returns a signed integer scalar or vector type matching the shape (scalar,
/// vector, and component bit width) of another type, which itself is a
@@ -399,7 +395,7 @@
/// requirement.
/// @param other the type whose shape must be matched
/// @returns the unsigned scalar or vector type
- ast::Type* GetUnsignedIntMatchingShape(ast::Type* other);
+ const Type* GetUnsignedIntMatchingShape(const Type* other);
/// Wraps the given expression in an as-cast to the given expression's type,
/// when the underlying operation produces a forced result type different
@@ -412,18 +408,20 @@
TypedExpression RectifyForcedResultType(
TypedExpression expr,
const spvtools::opt::Instruction& inst,
- ast::Type* first_operand_type);
+ const Type* first_operand_type);
- /// @returns the given expression, but ensuring it's an unsigned type of the
- /// same shape as the operand. Wraps the expresion with a bitcast if needed.
+ /// Returns the given expression, but ensuring it's an unsigned type of the
+ /// same shape as the operand. Wraps the expression with a bitcast if needed.
/// Assumes the given expresion is a integer scalar or vector.
/// @param expr an integer scalar or integer vector expression.
+ /// @return the potentially cast TypedExpression
TypedExpression AsUnsigned(TypedExpression expr);
- /// @returns the given expression, but ensuring it's a signed type of the
- /// same shape as the operand. Wraps the expresion with a bitcast if needed.
+ /// Returns the given expression, but ensuring it's a signed type of the
+ /// same shape as the operand. Wraps the expression with a bitcast if needed.
/// Assumes the given expresion is a integer scalar or vector.
/// @param expr an integer scalar or integer vector expression.
+ /// @return the potentially cast TypedExpression
TypedExpression AsSigned(TypedExpression expr);
/// Bookkeeping used for tracking the "position" builtin variable.
@@ -512,18 +510,18 @@
/// @param var the OpVariable instruction
/// @returns the Tint AST type for the poiner-to-{sampler|texture} or null on
/// error
- ast::Pointer* GetTypeForHandleVar(const spvtools::opt::Instruction& var);
+ const Pointer* GetTypeForHandleVar(const spvtools::opt::Instruction& var);
/// Returns the channel component type corresponding to the given image
/// format.
/// @param format image texel format
/// @returns the component type, one of f32, i32, u32
- ast::Type* GetComponentTypeForFormat(ast::ImageFormat format);
+ const Type* GetComponentTypeForFormat(ast::ImageFormat format);
/// Returns texel type corresponding to the given image format.
/// @param format image texel format
/// @returns the texel format
- ast::Type* GetTexelTypeForFormat(ast::ImageFormat format);
+ const Type* GetTexelTypeForFormat(ast::ImageFormat format);
/// Returns the SPIR-V instruction with the given ID, or nullptr.
/// @param id the SPIR-V result ID
@@ -561,19 +559,20 @@
private:
/// Converts a specific SPIR-V type to a Tint type. Integer case
- ast::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty);
+ const Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty);
/// Converts a specific SPIR-V type to a Tint type. Float case
- ast::Type* ConvertType(const spvtools::opt::analysis::Float* float_ty);
+ const Type* ConvertType(const spvtools::opt::analysis::Float* float_ty);
/// Converts a specific SPIR-V type to a Tint type. Vector case
- ast::Type* ConvertType(const spvtools::opt::analysis::Vector* vec_ty);
+ const Type* ConvertType(const spvtools::opt::analysis::Vector* vec_ty);
/// Converts a specific SPIR-V type to a Tint type. Matrix case
- ast::Type* ConvertType(const spvtools::opt::analysis::Matrix* mat_ty);
+ const Type* ConvertType(const spvtools::opt::analysis::Matrix* mat_ty);
/// Converts a specific SPIR-V type to a Tint type. RuntimeArray case
/// @param rtarr_ty the Tint type
- ast::Type* ConvertType(const spvtools::opt::analysis::RuntimeArray* rtarr_ty);
+ const Type* ConvertType(
+ const spvtools::opt::analysis::RuntimeArray* rtarr_ty);
/// Converts a specific SPIR-V type to a Tint type. Array case
/// @param arr_ty the Tint type
- ast::Type* ConvertType(const spvtools::opt::analysis::Array* arr_ty);
+ const Type* ConvertType(const spvtools::opt::analysis::Array* arr_ty);
/// Converts a specific SPIR-V type to a Tint type. Struct case.
/// SPIR-V allows distinct struct type definitions for two OpTypeStruct
/// that otherwise have the same set of members (and struct and member
@@ -585,34 +584,34 @@
/// not significant to the optimizer's module representation.
/// @param type_id the SPIR-V ID for the type.
/// @param struct_ty the Tint type
- ast::Type* ConvertType(uint32_t type_id,
- const spvtools::opt::analysis::Struct* struct_ty);
+ const Type* ConvertType(uint32_t type_id,
+ const spvtools::opt::analysis::Struct* struct_ty);
/// Converts a specific SPIR-V type to a Tint type. Pointer case
/// The pointer to gl_PerVertex maps to nullptr, and instead is recorded
/// in member #builtin_position_.
/// @param type_id the SPIR-V ID for the type.
/// @param ptr_ty the Tint type
- ast::Type* ConvertType(uint32_t type_id,
- const spvtools::opt::analysis::Pointer* ptr_ty);
+ const Type* ConvertType(uint32_t type_id,
+ const spvtools::opt::analysis::Pointer* ptr_ty);
/// If `type` is a signed integral, or vector of signed integral,
/// returns the unsigned type, otherwise returns `type`.
/// @param type the possibly signed type
/// @returns the unsigned type
- ast::Type* UnsignedTypeFor(ast::Type* type);
+ const Type* UnsignedTypeFor(const Type* type);
/// If `type` is a unsigned integral, or vector of unsigned integral,
/// returns the signed type, otherwise returns `type`.
/// @param type the possibly unsigned type
/// @returns the signed type
- ast::Type* SignedTypeFor(ast::Type* type);
+ const Type* SignedTypeFor(const Type* type);
/// Parses the array or runtime-array decorations.
/// @param spv_type the SPIR-V array or runtime-array type.
- /// @param decorations the populated decoration list
+ /// @param array_stride pointer to the array stride
/// @returns true on success.
bool ParseArrayDecorations(const spvtools::opt::analysis::Type* spv_type,
- ast::DecorationList* decorations);
+ uint32_t* array_stride);
/// Adds `type` as a constructed type if it hasn't been added yet.
/// @param name the type's unique name
@@ -633,6 +632,9 @@
// The program builder.
ProgramBuilder builder_;
+ // The type manager.
+ TypeManager ty_;
+
// Is the parse successful?
bool success_ = true;
// Collector for diagnostic messages.
@@ -716,7 +718,7 @@
// usages implied by usages of the memory-object-declaration.
std::unordered_map<const spvtools::opt::Instruction*, Usage> handle_usage_;
// The inferred pointer type for the given handle variable.
- std::unordered_map<const spvtools::opt::Instruction*, ast::Pointer*>
+ std::unordered_map<const spvtools::opt::Instruction*, const Pointer*>
handle_type_;
// Set of symbols of constructed types that have been added, used to avoid
diff --git a/src/reader/spirv/parser_impl_convert_type_test.cc b/src/reader/spirv/parser_impl_convert_type_test.cc
index 180f119..3b32374 100644
--- a/src/reader/spirv/parser_impl_convert_type_test.cc
+++ b/src/reader/spirv/parser_impl_convert_type_test.cc
@@ -75,7 +75,7 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(1);
- EXPECT_TRUE(type->Is<ast::Void>());
+ EXPECT_TRUE(type->Is<Void>());
EXPECT_TRUE(p->error().empty());
}
@@ -84,7 +84,7 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(100);
- EXPECT_TRUE(type->Is<ast::Bool>());
+ EXPECT_TRUE(type->Is<Bool>());
EXPECT_TRUE(p->error().empty());
}
@@ -93,7 +93,7 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(2);
- EXPECT_TRUE(type->Is<ast::I32>());
+ EXPECT_TRUE(type->Is<I32>());
EXPECT_TRUE(p->error().empty());
}
@@ -102,7 +102,7 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::U32>());
+ EXPECT_TRUE(type->Is<U32>());
EXPECT_TRUE(p->error().empty());
}
@@ -111,7 +111,7 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(4);
- EXPECT_TRUE(type->Is<ast::F32>());
+ EXPECT_TRUE(type->Is<F32>());
EXPECT_TRUE(p->error().empty());
}
@@ -155,19 +155,19 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* v2xf32 = p->ConvertType(20);
- EXPECT_TRUE(v2xf32->Is<ast::Vector>());
- EXPECT_TRUE(v2xf32->As<ast::Vector>()->type()->Is<ast::F32>());
- EXPECT_EQ(v2xf32->As<ast::Vector>()->size(), 2u);
+ EXPECT_TRUE(v2xf32->Is<Vector>());
+ EXPECT_TRUE(v2xf32->As<Vector>()->type->Is<F32>());
+ EXPECT_EQ(v2xf32->As<Vector>()->size, 2u);
auto* v3xf32 = p->ConvertType(30);
- EXPECT_TRUE(v3xf32->Is<ast::Vector>());
- EXPECT_TRUE(v3xf32->As<ast::Vector>()->type()->Is<ast::F32>());
- EXPECT_EQ(v3xf32->As<ast::Vector>()->size(), 3u);
+ EXPECT_TRUE(v3xf32->Is<Vector>());
+ EXPECT_TRUE(v3xf32->As<Vector>()->type->Is<F32>());
+ EXPECT_EQ(v3xf32->As<Vector>()->size, 3u);
auto* v4xf32 = p->ConvertType(40);
- EXPECT_TRUE(v4xf32->Is<ast::Vector>());
- EXPECT_TRUE(v4xf32->As<ast::Vector>()->type()->Is<ast::F32>());
- EXPECT_EQ(v4xf32->As<ast::Vector>()->size(), 4u);
+ EXPECT_TRUE(v4xf32->Is<Vector>());
+ EXPECT_TRUE(v4xf32->As<Vector>()->type->Is<F32>());
+ EXPECT_EQ(v4xf32->As<Vector>()->size, 4u);
EXPECT_TRUE(p->error().empty());
}
@@ -182,19 +182,19 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* v2xi32 = p->ConvertType(20);
- EXPECT_TRUE(v2xi32->Is<ast::Vector>());
- EXPECT_TRUE(v2xi32->As<ast::Vector>()->type()->Is<ast::I32>());
- EXPECT_EQ(v2xi32->As<ast::Vector>()->size(), 2u);
+ EXPECT_TRUE(v2xi32->Is<Vector>());
+ EXPECT_TRUE(v2xi32->As<Vector>()->type->Is<I32>());
+ EXPECT_EQ(v2xi32->As<Vector>()->size, 2u);
auto* v3xi32 = p->ConvertType(30);
- EXPECT_TRUE(v3xi32->Is<ast::Vector>());
- EXPECT_TRUE(v3xi32->As<ast::Vector>()->type()->Is<ast::I32>());
- EXPECT_EQ(v3xi32->As<ast::Vector>()->size(), 3u);
+ EXPECT_TRUE(v3xi32->Is<Vector>());
+ EXPECT_TRUE(v3xi32->As<Vector>()->type->Is<I32>());
+ EXPECT_EQ(v3xi32->As<Vector>()->size, 3u);
auto* v4xi32 = p->ConvertType(40);
- EXPECT_TRUE(v4xi32->Is<ast::Vector>());
- EXPECT_TRUE(v4xi32->As<ast::Vector>()->type()->Is<ast::I32>());
- EXPECT_EQ(v4xi32->As<ast::Vector>()->size(), 4u);
+ EXPECT_TRUE(v4xi32->Is<Vector>());
+ EXPECT_TRUE(v4xi32->As<Vector>()->type->Is<I32>());
+ EXPECT_EQ(v4xi32->As<Vector>()->size, 4u);
EXPECT_TRUE(p->error().empty());
}
@@ -209,19 +209,19 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* v2xu32 = p->ConvertType(20);
- EXPECT_TRUE(v2xu32->Is<ast::Vector>());
- EXPECT_TRUE(v2xu32->As<ast::Vector>()->type()->Is<ast::U32>());
- EXPECT_EQ(v2xu32->As<ast::Vector>()->size(), 2u);
+ EXPECT_TRUE(v2xu32->Is<Vector>());
+ EXPECT_TRUE(v2xu32->As<Vector>()->type->Is<U32>());
+ EXPECT_EQ(v2xu32->As<Vector>()->size, 2u);
auto* v3xu32 = p->ConvertType(30);
- EXPECT_TRUE(v3xu32->Is<ast::Vector>());
- EXPECT_TRUE(v3xu32->As<ast::Vector>()->type()->Is<ast::U32>());
- EXPECT_EQ(v3xu32->As<ast::Vector>()->size(), 3u);
+ EXPECT_TRUE(v3xu32->Is<Vector>());
+ EXPECT_TRUE(v3xu32->As<Vector>()->type->Is<U32>());
+ EXPECT_EQ(v3xu32->As<Vector>()->size, 3u);
auto* v4xu32 = p->ConvertType(40);
- EXPECT_TRUE(v4xu32->Is<ast::Vector>());
- EXPECT_TRUE(v4xu32->As<ast::Vector>()->type()->Is<ast::U32>());
- EXPECT_EQ(v4xu32->As<ast::Vector>()->size(), 4u);
+ EXPECT_TRUE(v4xu32->Is<Vector>());
+ EXPECT_TRUE(v4xu32->As<Vector>()->type->Is<U32>());
+ EXPECT_EQ(v4xu32->As<Vector>()->size, 4u);
EXPECT_TRUE(p->error().empty());
}
@@ -261,58 +261,58 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* m22 = p->ConvertType(22);
- EXPECT_TRUE(m22->Is<ast::Matrix>());
- EXPECT_TRUE(m22->As<ast::Matrix>()->type()->Is<ast::F32>());
- EXPECT_EQ(m22->As<ast::Matrix>()->rows(), 2u);
- EXPECT_EQ(m22->As<ast::Matrix>()->columns(), 2u);
+ EXPECT_TRUE(m22->Is<Matrix>());
+ EXPECT_TRUE(m22->As<Matrix>()->type->Is<F32>());
+ EXPECT_EQ(m22->As<Matrix>()->rows, 2u);
+ EXPECT_EQ(m22->As<Matrix>()->columns, 2u);
auto* m23 = p->ConvertType(23);
- EXPECT_TRUE(m23->Is<ast::Matrix>());
- EXPECT_TRUE(m23->As<ast::Matrix>()->type()->Is<ast::F32>());
- EXPECT_EQ(m23->As<ast::Matrix>()->rows(), 2u);
- EXPECT_EQ(m23->As<ast::Matrix>()->columns(), 3u);
+ EXPECT_TRUE(m23->Is<Matrix>());
+ EXPECT_TRUE(m23->As<Matrix>()->type->Is<F32>());
+ EXPECT_EQ(m23->As<Matrix>()->rows, 2u);
+ EXPECT_EQ(m23->As<Matrix>()->columns, 3u);
auto* m24 = p->ConvertType(24);
- EXPECT_TRUE(m24->Is<ast::Matrix>());
- EXPECT_TRUE(m24->As<ast::Matrix>()->type()->Is<ast::F32>());
- EXPECT_EQ(m24->As<ast::Matrix>()->rows(), 2u);
- EXPECT_EQ(m24->As<ast::Matrix>()->columns(), 4u);
+ EXPECT_TRUE(m24->Is<Matrix>());
+ EXPECT_TRUE(m24->As<Matrix>()->type->Is<F32>());
+ EXPECT_EQ(m24->As<Matrix>()->rows, 2u);
+ EXPECT_EQ(m24->As<Matrix>()->columns, 4u);
auto* m32 = p->ConvertType(32);
- EXPECT_TRUE(m32->Is<ast::Matrix>());
- EXPECT_TRUE(m32->As<ast::Matrix>()->type()->Is<ast::F32>());
- EXPECT_EQ(m32->As<ast::Matrix>()->rows(), 3u);
- EXPECT_EQ(m32->As<ast::Matrix>()->columns(), 2u);
+ EXPECT_TRUE(m32->Is<Matrix>());
+ EXPECT_TRUE(m32->As<Matrix>()->type->Is<F32>());
+ EXPECT_EQ(m32->As<Matrix>()->rows, 3u);
+ EXPECT_EQ(m32->As<Matrix>()->columns, 2u);
auto* m33 = p->ConvertType(33);
- EXPECT_TRUE(m33->Is<ast::Matrix>());
- EXPECT_TRUE(m33->As<ast::Matrix>()->type()->Is<ast::F32>());
- EXPECT_EQ(m33->As<ast::Matrix>()->rows(), 3u);
- EXPECT_EQ(m33->As<ast::Matrix>()->columns(), 3u);
+ EXPECT_TRUE(m33->Is<Matrix>());
+ EXPECT_TRUE(m33->As<Matrix>()->type->Is<F32>());
+ EXPECT_EQ(m33->As<Matrix>()->rows, 3u);
+ EXPECT_EQ(m33->As<Matrix>()->columns, 3u);
auto* m34 = p->ConvertType(34);
- EXPECT_TRUE(m34->Is<ast::Matrix>());
- EXPECT_TRUE(m34->As<ast::Matrix>()->type()->Is<ast::F32>());
- EXPECT_EQ(m34->As<ast::Matrix>()->rows(), 3u);
- EXPECT_EQ(m34->As<ast::Matrix>()->columns(), 4u);
+ EXPECT_TRUE(m34->Is<Matrix>());
+ EXPECT_TRUE(m34->As<Matrix>()->type->Is<F32>());
+ EXPECT_EQ(m34->As<Matrix>()->rows, 3u);
+ EXPECT_EQ(m34->As<Matrix>()->columns, 4u);
auto* m42 = p->ConvertType(42);
- EXPECT_TRUE(m42->Is<ast::Matrix>());
- EXPECT_TRUE(m42->As<ast::Matrix>()->type()->Is<ast::F32>());
- EXPECT_EQ(m42->As<ast::Matrix>()->rows(), 4u);
- EXPECT_EQ(m42->As<ast::Matrix>()->columns(), 2u);
+ EXPECT_TRUE(m42->Is<Matrix>());
+ EXPECT_TRUE(m42->As<Matrix>()->type->Is<F32>());
+ EXPECT_EQ(m42->As<Matrix>()->rows, 4u);
+ EXPECT_EQ(m42->As<Matrix>()->columns, 2u);
auto* m43 = p->ConvertType(43);
- EXPECT_TRUE(m43->Is<ast::Matrix>());
- EXPECT_TRUE(m43->As<ast::Matrix>()->type()->Is<ast::F32>());
- EXPECT_EQ(m43->As<ast::Matrix>()->rows(), 4u);
- EXPECT_EQ(m43->As<ast::Matrix>()->columns(), 3u);
+ EXPECT_TRUE(m43->Is<Matrix>());
+ EXPECT_TRUE(m43->As<Matrix>()->type->Is<F32>());
+ EXPECT_EQ(m43->As<Matrix>()->rows, 4u);
+ EXPECT_EQ(m43->As<Matrix>()->columns, 3u);
auto* m44 = p->ConvertType(44);
- EXPECT_TRUE(m44->Is<ast::Matrix>());
- EXPECT_TRUE(m44->As<ast::Matrix>()->type()->Is<ast::F32>());
- EXPECT_EQ(m44->As<ast::Matrix>()->rows(), 4u);
- EXPECT_EQ(m44->As<ast::Matrix>()->columns(), 4u);
+ EXPECT_TRUE(m44->Is<Matrix>());
+ EXPECT_TRUE(m44->As<Matrix>()->type->Is<F32>());
+ EXPECT_EQ(m44->As<Matrix>()->rows, 4u);
+ EXPECT_EQ(m44->As<Matrix>()->columns, 4u);
EXPECT_TRUE(p->error().empty());
}
@@ -326,15 +326,14 @@
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr);
- EXPECT_TRUE(type->UnwrapAliasIfNeeded()->Is<ast::Array>());
- auto* arr_type = type->UnwrapAliasIfNeeded()->As<ast::Array>();
- EXPECT_TRUE(arr_type->IsRuntimeArray());
+ EXPECT_TRUE(type->UnwrapAll()->Is<Array>());
+ auto* arr_type = type->UnwrapAll()->As<Array>();
ASSERT_NE(arr_type, nullptr);
- EXPECT_EQ(arr_type->size(), 0u);
- EXPECT_EQ(arr_type->decorations().size(), 0u);
- auto* elem_type = arr_type->type();
+ EXPECT_EQ(arr_type->size, 0u);
+ EXPECT_EQ(arr_type->stride, 0u);
+ auto* elem_type = arr_type->type;
ASSERT_NE(elem_type, nullptr);
- EXPECT_TRUE(elem_type->Is<ast::U32>());
+ EXPECT_TRUE(elem_type->Is<U32>());
EXPECT_TRUE(p->error().empty());
}
@@ -361,14 +360,10 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr);
- auto* arr_type = type->UnwrapAliasIfNeeded()->As<ast::Array>();
- EXPECT_TRUE(arr_type->IsRuntimeArray());
+ auto* arr_type = type->UnwrapAll()->As<Array>();
+ EXPECT_EQ(arr_type->size, 0u);
ASSERT_NE(arr_type, nullptr);
- ASSERT_EQ(arr_type->decorations().size(), 1u);
- auto* stride = arr_type->decorations()[0];
- ASSERT_TRUE(stride->Is<ast::StrideDecoration>());
- ASSERT_EQ(stride->As<ast::StrideDecoration>()->stride(), 64u);
- EXPECT_TRUE(p->error().empty());
+ EXPECT_EQ(arr_type->stride, 64u);
}
TEST_F(SpvParserTest, ConvertType_RuntimeArray_ArrayStride_ZeroIsError) {
@@ -409,15 +404,14 @@
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr);
- EXPECT_TRUE(type->Is<ast::Array>());
- auto* arr_type = type->As<ast::Array>();
- EXPECT_FALSE(arr_type->IsRuntimeArray());
+ EXPECT_TRUE(type->Is<Array>());
+ auto* arr_type = type->As<Array>();
ASSERT_NE(arr_type, nullptr);
- EXPECT_EQ(arr_type->size(), 42u);
- EXPECT_EQ(arr_type->decorations().size(), 0u);
- auto* elem_type = arr_type->type();
+ EXPECT_EQ(arr_type->size, 42u);
+ EXPECT_EQ(arr_type->stride, 0u);
+ auto* elem_type = arr_type->type;
ASSERT_NE(elem_type, nullptr);
- EXPECT_TRUE(elem_type->Is<ast::U32>());
+ EXPECT_TRUE(elem_type->Is<U32>());
EXPECT_TRUE(p->error().empty());
}
@@ -496,15 +490,10 @@
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr);
- EXPECT_TRUE(type->UnwrapAliasIfNeeded()->Is<ast::Array>());
- auto* arr_type = type->UnwrapAliasIfNeeded()->As<ast::Array>();
+ EXPECT_TRUE(type->UnwrapAll()->Is<Array>());
+ auto* arr_type = type->UnwrapAll()->As<Array>();
ASSERT_NE(arr_type, nullptr);
-
- ASSERT_EQ(arr_type->decorations().size(), 1u);
- auto* stride = arr_type->decorations()[0];
- ASSERT_TRUE(stride->Is<ast::StrideDecoration>());
- ASSERT_EQ(stride->As<ast::StrideDecoration>()->stride(), 8u);
-
+ EXPECT_EQ(arr_type->stride, 8u);
EXPECT_TRUE(p->error().empty());
}
@@ -550,14 +539,11 @@
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr);
- EXPECT_TRUE(type->Is<ast::Struct>());
+ EXPECT_TRUE(type->Is<Struct>());
+ auto* str = type->Build(p->builder());
Program program = p->program();
- EXPECT_THAT(program.str(type->As<ast::Struct>()), Eq(R"(Struct S {
- StructMember{field0: __u32}
- StructMember{field1: __f32}
-}
-)"));
+ EXPECT_THAT(program.str(str), Eq(R"(__type_name_S)"));
}
TEST_F(SpvParserTest, ConvertType_StructWithBlockDecoration) {
@@ -571,14 +557,11 @@
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr);
- EXPECT_TRUE(type->Is<ast::Struct>());
+ EXPECT_TRUE(type->Is<Struct>());
+ auto* str = type->Build(p->builder());
Program program = p->program();
- EXPECT_THAT(program.str(type->As<ast::Struct>()), Eq(R"(Struct S {
- [[block]]
- StructMember{field0: __u32}
-}
-)"));
+ EXPECT_THAT(program.str(str), Eq(R"(__type_name_S)"));
}
TEST_F(SpvParserTest, ConvertType_StructWithMemberDecorations) {
@@ -596,15 +579,11 @@
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr);
- EXPECT_TRUE(type->Is<ast::Struct>());
+ EXPECT_TRUE(type->Is<Struct>());
+ auto* str = type->Build(p->builder());
Program program = p->program();
- EXPECT_THAT(program.str(type->As<ast::Struct>()), Eq(R"(Struct S {
- StructMember{[[ offset 0 ]] field0: __f32}
- StructMember{[[ offset 8 ]] field1: __vec_2__f32}
- StructMember{[[ offset 16 ]] field2: __mat_2_2__f32}
-}
-)"));
+ EXPECT_THAT(program.str(str), Eq(R"(__type_name_S)"));
}
// TODO(dneto): Demonstrate other member decorations. Blocked on
@@ -645,11 +624,11 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ EXPECT_TRUE(type->Is<Pointer>());
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput);
+ EXPECT_TRUE(ptr_ty->type->Is<F32>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kInput);
EXPECT_TRUE(p->error().empty());
}
@@ -661,11 +640,11 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ EXPECT_TRUE(type->Is<Pointer>());
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kOutput);
+ EXPECT_TRUE(ptr_ty->type->Is<F32>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kOutput);
EXPECT_TRUE(p->error().empty());
}
@@ -677,11 +656,11 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ EXPECT_TRUE(type->Is<Pointer>());
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniform);
+ EXPECT_TRUE(ptr_ty->type->Is<F32>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kUniform);
EXPECT_TRUE(p->error().empty());
}
@@ -693,11 +672,11 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ EXPECT_TRUE(type->Is<Pointer>());
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kWorkgroup);
+ EXPECT_TRUE(ptr_ty->type->Is<F32>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kWorkgroup);
EXPECT_TRUE(p->error().empty());
}
@@ -709,11 +688,11 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ EXPECT_TRUE(type->Is<Pointer>());
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniformConstant);
+ EXPECT_TRUE(ptr_ty->type->Is<F32>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kUniformConstant);
EXPECT_TRUE(p->error().empty());
}
@@ -725,11 +704,11 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ EXPECT_TRUE(type->Is<Pointer>());
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kStorage);
+ EXPECT_TRUE(ptr_ty->type->Is<F32>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kStorage);
EXPECT_TRUE(p->error().empty());
}
@@ -741,11 +720,11 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ EXPECT_TRUE(type->Is<Pointer>());
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kImage);
+ EXPECT_TRUE(ptr_ty->type->Is<F32>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kImage);
EXPECT_TRUE(p->error().empty());
}
@@ -757,11 +736,11 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ EXPECT_TRUE(type->Is<Pointer>());
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kPrivate);
+ EXPECT_TRUE(ptr_ty->type->Is<F32>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kPrivate);
EXPECT_TRUE(p->error().empty());
}
@@ -773,11 +752,11 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(3);
- EXPECT_TRUE(type->Is<ast::Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ EXPECT_TRUE(type->Is<Pointer>());
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>());
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kFunction);
+ EXPECT_TRUE(ptr_ty->type->Is<F32>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kFunction);
EXPECT_TRUE(p->error().empty());
}
@@ -792,17 +771,17 @@
auto* type = p->ConvertType(3);
EXPECT_NE(type, nullptr);
- EXPECT_TRUE(type->Is<ast::Pointer>());
+ EXPECT_TRUE(type->Is<Pointer>());
- auto* ptr_ty = type->As<ast::Pointer>();
+ auto* ptr_ty = type->As<Pointer>();
EXPECT_NE(ptr_ty, nullptr);
- EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput);
- EXPECT_TRUE(ptr_ty->type()->Is<ast::Pointer>());
+ EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kInput);
+ EXPECT_TRUE(ptr_ty->type->Is<Pointer>());
- auto* ptr_ptr_ty = ptr_ty->type()->As<ast::Pointer>();
+ auto* ptr_ptr_ty = ptr_ty->type->As<Pointer>();
EXPECT_NE(ptr_ptr_ty, nullptr);
- EXPECT_EQ(ptr_ptr_ty->storage_class(), ast::StorageClass::kOutput);
- EXPECT_TRUE(ptr_ptr_ty->type()->Is<ast::F32>());
+ EXPECT_EQ(ptr_ptr_ty->storage_class, ast::StorageClass::kOutput);
+ EXPECT_TRUE(ptr_ptr_ty->type->Is<F32>());
EXPECT_TRUE(p->error().empty());
}
@@ -815,7 +794,7 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(1);
- EXPECT_TRUE(type->Is<ast::Void>());
+ EXPECT_TRUE(type->Is<Void>());
EXPECT_TRUE(p->error().empty());
}
@@ -828,7 +807,7 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(1);
- EXPECT_TRUE(type->Is<ast::Void>());
+ EXPECT_TRUE(type->Is<Void>());
EXPECT_TRUE(p->error().empty());
}
@@ -841,7 +820,7 @@
EXPECT_TRUE(p->BuildInternalModule());
auto* type = p->ConvertType(1);
- EXPECT_TRUE(type->Is<ast::Void>());
+ EXPECT_TRUE(type->Is<Void>());
EXPECT_TRUE(p->error().empty());
}
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index 91002a0..c8afb2c 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -1772,7 +1772,7 @@
[[block]]
StructMember{[[ offset 0 ]] field0: __u32}
StructMember{[[ offset 4 ]] field1: __f32}
- StructMember{[[ offset 8 ]] field2: __alias_Arr__array__u32_2_stride_4}
+ StructMember{[[ offset 8 ]] field2: __type_name_Arr}
}
Variable{
x_1
diff --git a/src/reader/spirv/parser_impl_test_helper.h b/src/reader/spirv/parser_impl_test_helper.h
index bfa5ca6..dd83ca6 100644
--- a/src/reader/spirv/parser_impl_test_helper.h
+++ b/src/reader/spirv/parser_impl_test_helper.h
@@ -158,7 +158,7 @@
/// after the internal representation of the module has been built.
/// @param id the SPIR-V ID of a type.
/// @returns a Tint type, or nullptr
- ast::Type* ConvertType(uint32_t id) { return impl_.ConvertType(id); }
+ const Type* ConvertType(uint32_t id) { return impl_.ConvertType(id); }
/// Gets the list of decorations for a SPIR-V result ID. Returns an empty
/// vector if the ID is not a result ID, or if no decorations target that ID.
diff --git a/src/reader/spirv/parser_type.cc b/src/reader/spirv/parser_type.cc
new file mode 100644
index 0000000..4f2ca5a
--- /dev/null
+++ b/src/reader/spirv/parser_type.cc
@@ -0,0 +1,514 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or stateied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/reader/spirv/parser_type.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "src/program_builder.h"
+#include "src/utils/get_or_create.h"
+#include "src/utils/hash.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Type);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Void);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Bool);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::U32);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::F32);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::I32);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Pointer);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Vector);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Matrix);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Array);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::AccessControl);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Sampler);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Texture);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::DepthTexture);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::MultisampledTexture);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::SampledTexture);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::StorageTexture);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Named);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Alias);
+TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Struct);
+
+namespace tint {
+namespace reader {
+namespace spirv {
+
+namespace {
+struct PointerHasher {
+ size_t operator()(const Pointer& t) const {
+ return utils::Hash(t.type, t.storage_class);
+ }
+};
+
+struct VectorHasher {
+ size_t operator()(const Vector& t) const {
+ return utils::Hash(t.type, t.size);
+ }
+};
+
+struct MatrixHasher {
+ size_t operator()(const Matrix& t) const {
+ return utils::Hash(t.type, t.columns, t.rows);
+ }
+};
+
+struct ArrayHasher {
+ size_t operator()(const Array& t) const {
+ return utils::Hash(t.type, t.size, t.stride);
+ }
+};
+
+struct AccessControlHasher {
+ size_t operator()(const AccessControl& t) const {
+ return utils::Hash(t.type, t.access);
+ }
+};
+
+struct MultisampledTextureHasher {
+ size_t operator()(const MultisampledTexture& t) const {
+ return utils::Hash(t.dims, t.type);
+ }
+};
+
+struct SampledTextureHasher {
+ size_t operator()(const SampledTexture& t) const {
+ return utils::Hash(t.dims, t.type);
+ }
+};
+
+struct StorageTextureHasher {
+ size_t operator()(const StorageTexture& t) const {
+ return utils::Hash(t.dims, t.format);
+ }
+};
+} // namespace
+
+static bool operator==(const Pointer& a, const Pointer& b) {
+ return a.type == b.type && a.storage_class == b.storage_class;
+}
+
+static bool operator==(const Vector& a, const Vector& b) {
+ return a.type == b.type && a.size == b.size;
+}
+
+static bool operator==(const Matrix& a, const Matrix& b) {
+ return a.type == b.type && a.columns == b.columns && a.rows == b.rows;
+}
+
+static bool operator==(const Array& a, const Array& b) {
+ return a.type == b.type && a.size == b.size && a.stride == b.stride;
+}
+
+static bool operator==(const AccessControl& a, const AccessControl& b) {
+ return a.type == b.type && a.access == b.access;
+}
+
+static bool operator==(const MultisampledTexture& a,
+ const MultisampledTexture& b) {
+ return a.dims == b.dims && a.type == b.type;
+}
+
+static bool operator==(const SampledTexture& a, const SampledTexture& b) {
+ return a.dims == b.dims && a.type == b.type;
+}
+
+static bool operator==(const StorageTexture& a, const StorageTexture& b) {
+ return a.dims == b.dims && a.format == b.format;
+}
+
+ast::Type* Void::Build(ProgramBuilder& b) const {
+ return b.ty.void_();
+}
+
+ast::Type* Bool::Build(ProgramBuilder& b) const {
+ return b.ty.bool_();
+}
+
+ast::Type* U32::Build(ProgramBuilder& b) const {
+ return b.ty.u32();
+}
+
+ast::Type* F32::Build(ProgramBuilder& b) const {
+ return b.ty.f32();
+}
+
+ast::Type* I32::Build(ProgramBuilder& b) const {
+ return b.ty.i32();
+}
+
+Pointer::Pointer(const Type* t, ast::StorageClass s)
+ : type(t), storage_class(s) {}
+Pointer::Pointer(const Pointer&) = default;
+
+ast::Type* Pointer::Build(ProgramBuilder& b) const {
+ return b.ty.pointer(type->Build(b), storage_class);
+}
+
+Vector::Vector(const Type* t, uint32_t s) : type(t), size(s) {}
+Vector::Vector(const Vector&) = default;
+
+ast::Type* Vector::Build(ProgramBuilder& b) const {
+ return b.ty.vec(type->Build(b), size);
+}
+
+Matrix::Matrix(const Type* t, uint32_t c, uint32_t r)
+ : type(t), columns(c), rows(r) {}
+Matrix::Matrix(const Matrix&) = default;
+
+ast::Type* Matrix::Build(ProgramBuilder& b) const {
+ return b.ty.mat(type->Build(b), columns, rows);
+}
+
+Array::Array(const Type* t, uint32_t sz, uint32_t st)
+ : type(t), size(sz), stride(st) {}
+Array::Array(const Array&) = default;
+
+ast::Type* Array::Build(ProgramBuilder& b) const {
+ return b.ty.array(type->Build(b), size, stride);
+}
+
+AccessControl::AccessControl(const Type* t, ast::AccessControl::Access a)
+ : type(t), access(a) {}
+AccessControl::AccessControl(const AccessControl&) = default;
+
+ast::Type* AccessControl::Build(ProgramBuilder& b) const {
+ return b.ty.access(access, type->Build(b));
+}
+
+Sampler::Sampler(ast::SamplerKind k) : kind(k) {}
+Sampler::Sampler(const Sampler&) = default;
+
+ast::Type* Sampler::Build(ProgramBuilder& b) const {
+ return b.ty.sampler(kind);
+}
+
+Texture::Texture(ast::TextureDimension d) : dims(d) {}
+Texture::Texture(const Texture&) = default;
+
+DepthTexture::DepthTexture(ast::TextureDimension d) : Base(d) {}
+DepthTexture::DepthTexture(const DepthTexture&) = default;
+
+ast::Type* DepthTexture::Build(ProgramBuilder& b) const {
+ return b.ty.depth_texture(dims);
+}
+
+MultisampledTexture::MultisampledTexture(ast::TextureDimension d, const Type* t)
+ : Base(d), type(t) {}
+MultisampledTexture::MultisampledTexture(const MultisampledTexture&) = default;
+
+ast::Type* MultisampledTexture::Build(ProgramBuilder& b) const {
+ return b.ty.multisampled_texture(dims, type->Build(b));
+}
+
+SampledTexture::SampledTexture(ast::TextureDimension d, const Type* t)
+ : Base(d), type(t) {}
+SampledTexture::SampledTexture(const SampledTexture&) = default;
+
+ast::Type* SampledTexture::Build(ProgramBuilder& b) const {
+ return b.ty.sampled_texture(dims, type->Build(b));
+}
+
+StorageTexture::StorageTexture(ast::TextureDimension d, ast::ImageFormat f)
+ : Base(d), format(f) {}
+StorageTexture::StorageTexture(const StorageTexture&) = default;
+
+ast::Type* StorageTexture::Build(ProgramBuilder& b) const {
+ return b.ty.storage_texture(dims, format);
+}
+
+Named::Named(Symbol n) : name(n) {}
+Named::Named(const Named&) = default;
+Named::~Named() = default;
+
+Alias::Alias(Symbol n, const Type* ty) : Base(n), type(ty) {}
+Alias::Alias(const Alias&) = default;
+
+ast::Type* Alias::Build(ProgramBuilder& b) const {
+ return b.ty.type_name(name);
+}
+
+Struct::Struct(Symbol n, TypeList m) : Base(n), members(std::move(m)) {}
+Struct::Struct(const Struct&) = default;
+Struct::~Struct() = default;
+
+ast::Type* Struct::Build(ProgramBuilder& b) const {
+ return b.ty.type_name(name);
+}
+
+/// The PIMPL state of the Types object.
+struct TypeManager::State {
+ /// The allocator of types
+ BlockAllocator<Type> allocator_;
+ /// The lazily-created Void type
+ spirv::Void const* void_ = nullptr;
+ /// The lazily-created Bool type
+ spirv::Bool const* bool_ = nullptr;
+ /// The lazily-created U32 type
+ spirv::U32 const* u32_ = nullptr;
+ /// The lazily-created F32 type
+ spirv::F32 const* f32_ = nullptr;
+ /// The lazily-created I32 type
+ spirv::I32 const* i32_ = nullptr;
+ /// Map of Pointer to the returned Pointer type instance
+ std::unordered_map<spirv::Pointer, const spirv::Pointer*, PointerHasher>
+ pointers_;
+ /// Map of Vector to the returned Vector type instance
+ std::unordered_map<spirv::Vector, const spirv::Vector*, VectorHasher>
+ vectors_;
+ /// Map of Matrix to the returned Matrix type instance
+ std::unordered_map<spirv::Matrix, const spirv::Matrix*, MatrixHasher>
+ matrices_;
+ /// Map of Array to the returned Array type instance
+ std::unordered_map<spirv::Array, const spirv::Array*, ArrayHasher> arrays_;
+ /// Map of AccessControl to the returned AccessControl type instance
+ std::unordered_map<spirv::AccessControl,
+ const spirv::AccessControl*,
+ AccessControlHasher>
+ access_controls_;
+ /// Map of type name to returned Alias instance
+ std::unordered_map<Symbol, const spirv::Alias*> aliases_;
+ /// Map of type name to returned Struct instance
+ std::unordered_map<Symbol, const spirv::Struct*> structs_;
+ /// Map of ast::SamplerKind to returned Sampler instance
+ std::unordered_map<ast::SamplerKind, const spirv::Sampler*> samplers_;
+ /// Map of ast::TextureDimension to returned DepthTexture instance
+ std::unordered_map<ast::TextureDimension, const spirv::DepthTexture*>
+ depth_textures_;
+ /// Map of MultisampledTexture to the returned MultisampledTexture type
+ /// instance
+ std::unordered_map<spirv::MultisampledTexture,
+ const spirv::MultisampledTexture*,
+ MultisampledTextureHasher>
+ multisampled_textures_;
+ /// Map of SampledTexture to the returned SampledTexture type instance
+ std::unordered_map<spirv::SampledTexture,
+ const spirv::SampledTexture*,
+ SampledTextureHasher>
+ sampled_textures_;
+ /// Map of StorageTexture to the returned StorageTexture type instance
+ std::unordered_map<spirv::StorageTexture,
+ const spirv::StorageTexture*,
+ StorageTextureHasher>
+ storage_textures_;
+};
+
+const Type* Type::UnwrapPtrIfNeeded() const {
+ if (auto* ptr = As<Pointer>()) {
+ return ptr->type;
+ }
+ return this;
+}
+
+const Type* Type::UnwrapAliasIfNeeded() const {
+ const Type* unwrapped = this;
+ while (auto* ptr = unwrapped->As<Alias>()) {
+ unwrapped = ptr->type;
+ }
+ return unwrapped;
+}
+
+const Type* Type::UnwrapIfNeeded() const {
+ auto* where = this;
+ while (true) {
+ if (auto* alias = where->As<Alias>()) {
+ where = alias->type;
+ } else if (auto* access = where->As<AccessControl>()) {
+ where = access->type;
+ } else {
+ break;
+ }
+ }
+ return where;
+}
+
+const Type* Type::UnwrapAll() const {
+ return UnwrapIfNeeded()->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
+}
+
+bool Type::IsFloatScalar() const {
+ return Is<F32>();
+}
+
+bool Type::IsFloatScalarOrVector() const {
+ return IsFloatScalar() || IsFloatVector();
+}
+
+bool Type::IsFloatVector() const {
+ return Is<Vector>([](const Vector* v) { return v->type->IsFloatScalar(); });
+}
+
+bool Type::IsIntegerScalar() const {
+ return IsAnyOf<U32, I32>();
+}
+
+bool Type::IsIntegerScalarOrVector() const {
+ return IsUnsignedScalarOrVector() || IsSignedScalarOrVector();
+}
+
+bool Type::IsScalar() const {
+ return IsAnyOf<F32, U32, I32, Bool>();
+}
+
+bool Type::IsSignedIntegerVector() const {
+ return Is<Vector>([](const Vector* v) { return v->type->Is<I32>(); });
+}
+
+bool Type::IsSignedScalarOrVector() const {
+ return Is<I32>() || IsSignedIntegerVector();
+}
+
+bool Type::IsUnsignedIntegerVector() const {
+ return Is<Vector>([](const Vector* v) { return v->type->Is<U32>(); });
+}
+
+bool Type::IsUnsignedScalarOrVector() const {
+ return Is<U32>() || IsUnsignedIntegerVector();
+}
+
+TypeManager::TypeManager() {
+ state = std::make_unique<State>();
+}
+
+TypeManager::~TypeManager() = default;
+
+const spirv::Void* TypeManager::Void() {
+ if (!state->void_) {
+ state->void_ = state->allocator_.Create<spirv::Void>();
+ }
+ return state->void_;
+}
+
+const spirv::Bool* TypeManager::Bool() {
+ if (!state->bool_) {
+ state->bool_ = state->allocator_.Create<spirv::Bool>();
+ }
+ return state->bool_;
+}
+
+const spirv::U32* TypeManager::U32() {
+ if (!state->u32_) {
+ state->u32_ = state->allocator_.Create<spirv::U32>();
+ }
+ return state->u32_;
+}
+
+const spirv::F32* TypeManager::F32() {
+ if (!state->f32_) {
+ state->f32_ = state->allocator_.Create<spirv::F32>();
+ }
+ return state->f32_;
+}
+
+const spirv::I32* TypeManager::I32() {
+ if (!state->i32_) {
+ state->i32_ = state->allocator_.Create<spirv::I32>();
+ }
+ return state->i32_;
+}
+
+const spirv::Pointer* TypeManager::Pointer(const Type* el,
+ ast::StorageClass sc) {
+ return utils::GetOrCreate(state->pointers_, spirv::Pointer(el, sc), [&] {
+ return state->allocator_.Create<spirv::Pointer>(el, sc);
+ });
+}
+
+const spirv::Vector* TypeManager::Vector(const Type* el, uint32_t size) {
+ return utils::GetOrCreate(state->vectors_, spirv::Vector(el, size), [&] {
+ return state->allocator_.Create<spirv::Vector>(el, size);
+ });
+}
+
+const spirv::Matrix* TypeManager::Matrix(const Type* el,
+ uint32_t columns,
+ uint32_t rows) {
+ return utils::GetOrCreate(
+ state->matrices_, spirv::Matrix(el, columns, rows), [&] {
+ return state->allocator_.Create<spirv::Matrix>(el, columns, rows);
+ });
+}
+
+const spirv::Array* TypeManager::Array(const Type* el,
+ uint32_t size,
+ uint32_t stride) {
+ return utils::GetOrCreate(
+ state->arrays_, spirv::Array(el, size, stride),
+ [&] { return state->allocator_.Create<spirv::Array>(el, size, stride); });
+}
+
+const spirv::AccessControl* TypeManager::AccessControl(
+ const Type* ty,
+ ast::AccessControl::Access ac) {
+ return utils::GetOrCreate(
+ state->access_controls_, spirv::AccessControl(ty, ac),
+ [&] { return state->allocator_.Create<spirv::AccessControl>(ty, ac); });
+}
+
+const spirv::Alias* TypeManager::Alias(Symbol name, const Type* ty) {
+ return utils::GetOrCreate(state->aliases_, name, [&] {
+ return state->allocator_.Create<spirv::Alias>(name, ty);
+ });
+}
+
+const spirv::Struct* TypeManager::Struct(Symbol name, TypeList members) {
+ return utils::GetOrCreate(state->structs_, name, [&] {
+ return state->allocator_.Create<spirv::Struct>(name, std::move(members));
+ });
+}
+
+const spirv::Sampler* TypeManager::Sampler(ast::SamplerKind kind) {
+ return utils::GetOrCreate(state->samplers_, kind, [&] {
+ return state->allocator_.Create<spirv::Sampler>(kind);
+ });
+}
+
+const spirv::DepthTexture* TypeManager::DepthTexture(
+ ast::TextureDimension dims) {
+ return utils::GetOrCreate(state->depth_textures_, dims, [&] {
+ return state->allocator_.Create<spirv::DepthTexture>(dims);
+ });
+}
+
+const spirv::MultisampledTexture* TypeManager::MultisampledTexture(
+ ast::TextureDimension dims,
+ const Type* ty) {
+ return utils::GetOrCreate(
+ state->multisampled_textures_, spirv::MultisampledTexture(dims, ty), [&] {
+ return state->allocator_.Create<spirv::MultisampledTexture>(dims, ty);
+ });
+}
+
+const spirv::SampledTexture* TypeManager::SampledTexture(
+ ast::TextureDimension dims,
+ const Type* ty) {
+ return utils::GetOrCreate(
+ state->sampled_textures_, spirv::SampledTexture(dims, ty), [&] {
+ return state->allocator_.Create<spirv::SampledTexture>(dims, ty);
+ });
+}
+
+const spirv::StorageTexture* TypeManager::StorageTexture(
+ ast::TextureDimension dims,
+ ast::ImageFormat fmt) {
+ return utils::GetOrCreate(
+ state->storage_textures_, spirv::StorageTexture(dims, fmt), [&] {
+ return state->allocator_.Create<spirv::StorageTexture>(dims, fmt);
+ });
+}
+
+} // namespace spirv
+} // namespace reader
+} // namespace tint
diff --git a/src/reader/spirv/parser_type.h b/src/reader/spirv/parser_type.h
new file mode 100644
index 0000000..c3ca2bd
--- /dev/null
+++ b/src/reader/spirv/parser_type.h
@@ -0,0 +1,498 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_READER_SPIRV_PARSER_TYPE_H_
+#define SRC_READER_SPIRV_PARSER_TYPE_H_
+
+#include <memory>
+#include <vector>
+
+#include "src/ast/access_control.h"
+#include "src/ast/sampler.h"
+#include "src/ast/storage_class.h"
+#include "src/ast/storage_texture.h"
+#include "src/ast/texture.h"
+#include "src/block_allocator.h"
+#include "src/castable.h"
+
+// Forward declarations
+namespace tint {
+class ProgramBuilder;
+namespace ast {
+class Type;
+} // namespace ast
+} // namespace tint
+
+namespace tint {
+namespace reader {
+namespace spirv {
+
+/// Type is the base class for all types
+class Type : public Castable<Type> {
+ public:
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ virtual ast::Type* Build(ProgramBuilder& b) const = 0;
+
+ /// @returns the pointee type if this is a pointer, `this` otherwise
+ const Type* UnwrapPtrIfNeeded() const;
+
+ /// @returns the most deeply nested aliased type if this is an alias, `this`
+ /// otherwise
+ const Type* UnwrapAliasIfNeeded() const;
+
+ /// Removes all levels of aliasing and access control.
+ /// This is just enough to assist with WGSL translation
+ /// in that you want see through one level of pointer to get from an
+ /// identifier-like expression as an l-value to its corresponding r-value,
+ /// plus see through the wrappers on either side.
+ /// @returns the completely unaliased type.
+ const Type* UnwrapIfNeeded() const;
+
+ /// Returns the type found after:
+ /// - removing all layers of aliasing and access control if they exist, then
+ /// - removing the pointer, if it exists, then
+ /// - removing all further layers of aliasing or access control, if they exist
+ /// @returns the unwrapped type
+ const Type* UnwrapAll() const;
+
+ /// @returns true if this type is a float scalar
+ bool IsFloatScalar() const;
+ /// @returns true if this type is a float scalar or vector
+ bool IsFloatScalarOrVector() const;
+ /// @returns true if this type is a float vector
+ bool IsFloatVector() const;
+ /// @returns true if this type is an integer scalar
+ bool IsIntegerScalar() const;
+ /// @returns true if this type is an integer scalar or vector
+ bool IsIntegerScalarOrVector() const;
+ /// @returns true if this type is a scalar
+ bool IsScalar() const;
+ /// @returns true if this type is a signed integer vector
+ bool IsSignedIntegerVector() const;
+ /// @returns true if this type is a signed scalar or vector
+ bool IsSignedScalarOrVector() const;
+ /// @returns true if this type is an unsigned integer vector
+ bool IsUnsignedIntegerVector() const;
+ /// @returns true if this type is an unsigned scalar or vector
+ bool IsUnsignedScalarOrVector() const;
+};
+
+using TypeList = std::vector<const Type*>;
+
+/// `void` type
+struct Void : public Castable<Void, Type> {
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+};
+
+/// `bool` type
+struct Bool : public Castable<Bool, Type> {
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+};
+
+/// `u32` type
+struct U32 : public Castable<U32, Type> {
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+};
+
+/// `f32` type
+struct F32 : public Castable<F32, Type> {
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+};
+
+/// `i32` type
+struct I32 : public Castable<I32, Type> {
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+};
+
+/// `ptr<SC, T>` type
+struct Pointer : public Castable<Pointer, Type> {
+ /// Constructor
+ /// @param ty the pointee type
+ /// @param sc the pointer storage class
+ Pointer(const Type* ty, ast::StorageClass sc);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ Pointer(const Pointer& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the pointee type
+ Type const* const type;
+ /// the pointer storage class
+ ast::StorageClass const storage_class;
+};
+
+/// `vecN<T>` type
+struct Vector : public Castable<Vector, Type> {
+ /// Constructor
+ /// @param ty the element type
+ /// @param sz the number of elements in the vector
+ Vector(const Type* ty, uint32_t sz);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ Vector(const Vector& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the element type
+ Type const* const type;
+ /// the number of elements in the vector
+ uint32_t const size;
+};
+
+/// `matNxM<T>` type
+struct Matrix : public Castable<Matrix, Type> {
+ /// Constructor
+ /// @param ty the matrix element type
+ /// @param c the number of columns in the matrix
+ /// @param r the number of rows in the matrix
+ Matrix(const Type* ty, uint32_t c, uint32_t r);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ Matrix(const Matrix& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the matrix element type
+ Type const* const type;
+ /// the number of columns in the matrix
+ uint32_t const columns;
+ /// the number of rows in the matrix
+ uint32_t const rows;
+};
+
+/// `array<T, N>` type
+struct Array : public Castable<Array, Type> {
+ /// Constructor
+ /// @param el the element type
+ /// @param sz the number of elements in the array. 0 represents runtime-sized
+ /// array.
+ /// @param st the byte stride of the array
+ Array(const Type* el, uint32_t sz, uint32_t st);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ Array(const Array& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the element type
+ Type const* const type;
+ /// the number of elements in the array. 0 represents runtime-sized array.
+ uint32_t const size;
+ /// the byte stride of the array
+ uint32_t const stride;
+};
+
+/// `[[access]]` type
+struct AccessControl : public Castable<AccessControl, Type> {
+ /// Constructor
+ /// @param ty the inner type
+ /// @param ac the access control
+ AccessControl(const Type* ty, ast::AccessControl::Access ac);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ AccessControl(const AccessControl& other);
+
+ /// @return the
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the inner type
+ Type const* const type;
+ /// the access control
+ ast::AccessControl::Access const access;
+};
+
+/// `sampler` type
+struct Sampler : public Castable<Sampler, Type> {
+ /// Constructor
+ /// @param k the sampler kind
+ explicit Sampler(ast::SamplerKind k);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ Sampler(const Sampler& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the sampler kind
+ ast::SamplerKind const kind;
+};
+
+/// Base class for texture types
+struct Texture : public Castable<Texture, Type> {
+ /// Constructor
+ /// @param d the texture dimensions
+ explicit Texture(ast::TextureDimension d);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ Texture(const Texture& other);
+
+ /// the texture dimensions
+ ast::TextureDimension const dims;
+};
+
+/// `texture_depth_D` type
+struct DepthTexture : public Castable<DepthTexture, Texture> {
+ /// Constructor
+ /// @param d the texture dimensions
+ explicit DepthTexture(ast::TextureDimension d);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ DepthTexture(const DepthTexture& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+};
+
+/// `texture_multisampled_D<T>` type
+struct MultisampledTexture : public Castable<MultisampledTexture, Texture> {
+ /// Constructor
+ /// @param d the texture dimensions
+ /// @param t the multisampled texture type
+ MultisampledTexture(ast::TextureDimension d, const Type* t);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ MultisampledTexture(const MultisampledTexture& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the multisampled texture type
+ Type const* const type;
+};
+
+/// `texture_D<T>` type
+struct SampledTexture : public Castable<SampledTexture, Texture> {
+ /// Constructor
+ /// @param d the texture dimensions
+ /// @param t the sampled texture type
+ SampledTexture(ast::TextureDimension d, const Type* t);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ SampledTexture(const SampledTexture& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the sampled texture type
+ Type const* const type;
+};
+
+/// `texture_storage_D<F>` type
+struct StorageTexture : public Castable<StorageTexture, Texture> {
+ /// Constructor
+ /// @param d the texture dimensions
+ /// @param f the storage image format
+ StorageTexture(ast::TextureDimension d, ast::ImageFormat f);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ StorageTexture(const StorageTexture& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the storage image format
+ ast::ImageFormat const format;
+};
+
+/// Base class for named types
+struct Named : public Castable<Named, Type> {
+ /// Constructor
+ /// @param n the type name
+ explicit Named(Symbol n);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ Named(const Named& other);
+
+ /// Destructor
+ ~Named() override;
+
+ /// the type name
+ Symbol const name;
+};
+
+/// `type T = N` type
+struct Alias : public Castable<Alias, Named> {
+ /// Constructor
+ /// @param n the alias name
+ /// @param t the aliased type
+ Alias(Symbol n, const Type* t);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ Alias(const Alias& other);
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the aliased type
+ Type const* const type;
+};
+
+/// `struct N { ... };` type
+struct Struct : public Castable<Struct, Named> {
+ /// Constructor
+ /// @param n the struct name
+ /// @param m the member types
+ Struct(Symbol n, TypeList m);
+
+ /// Copy constructor
+ /// @param other the other type to copy
+ Struct(const Struct& other);
+
+ /// Destructor
+ ~Struct() override;
+
+ /// @param b the ProgramBuilder used to construct the AST types
+ /// @returns the constructed ast::Type node for the given type
+ ast::Type* Build(ProgramBuilder& b) const override;
+
+ /// the member types
+ TypeList const members;
+};
+
+/// A manager of types
+class TypeManager {
+ public:
+ /// Constructor
+ TypeManager();
+
+ /// Destructor
+ ~TypeManager();
+
+ /// @return a Void type. Repeated calls will return the same pointer.
+ const spirv::Void* Void();
+ /// @return a Bool type. Repeated calls will return the same pointer.
+ const spirv::Bool* Bool();
+ /// @return a U32 type. Repeated calls will return the same pointer.
+ const spirv::U32* U32();
+ /// @return a F32 type. Repeated calls will return the same pointer.
+ const spirv::F32* F32();
+ /// @return a I32 type. Repeated calls will return the same pointer.
+ const spirv::I32* I32();
+ /// @param ty the pointee type
+ /// @param sc the pointer storage class
+ /// @return a Pointer type. Repeated calls with the same arguments will return
+ /// the same pointer.
+ const spirv::Pointer* Pointer(const Type* ty, ast::StorageClass sc);
+ /// @param ty the element type
+ /// @param sz the number of elements in the vector
+ /// @return a Vector type. Repeated calls with the same arguments will return
+ /// the same pointer.
+ const spirv::Vector* Vector(const Type* ty, uint32_t sz);
+ /// @param ty the matrix element type
+ /// @param c the number of columns in the matrix
+ /// @param r the number of rows in the matrix
+ /// @return a Matrix type. Repeated calls with the same arguments will return
+ /// the same pointer.
+ const spirv::Matrix* Matrix(const Type* ty, uint32_t c, uint32_t r);
+ /// @param el the element type
+ /// @param sz the number of elements in the array. 0 represents runtime-sized
+ /// array.
+ /// @param st the byte stride of the array
+ /// @return a Array type. Repeated calls with the same arguments will return
+ /// the same pointer.
+ const spirv::Array* Array(const Type* el, uint32_t sz, uint32_t st);
+ /// @param ty the inner type
+ /// @param ac the access control
+ /// @return a AccessControl type. Repeated calls with the same arguments will
+ /// return the same pointer.
+ const spirv::AccessControl* AccessControl(const Type* ty,
+ ast::AccessControl::Access ac);
+ /// @param n the alias name
+ /// @param t the aliased type
+ /// @return a Alias type. Repeated calls with the same arguments will return
+ /// the same pointer.
+ const spirv::Alias* Alias(Symbol n, const Type* t);
+ /// @param n the struct name
+ /// @param m the member types
+ /// @return a Struct type. Repeated calls with the same arguments will return
+ /// the same pointer.
+ const spirv::Struct* Struct(Symbol n, TypeList m);
+ /// @param k the sampler kind
+ /// @return a Sampler type. Repeated calls with the same arguments will return
+ /// the same pointer.
+ const spirv::Sampler* Sampler(ast::SamplerKind k);
+ /// @param d the texture dimensions
+ /// @return a DepthTexture type. Repeated calls with the same arguments will
+ /// return the same pointer.
+ const spirv::DepthTexture* DepthTexture(ast::TextureDimension d);
+ /// @param d the texture dimensions
+ /// @param t the multisampled texture type
+ /// @return a MultisampledTexture type. Repeated calls with the same arguments
+ /// will return the same pointer.
+ const spirv::MultisampledTexture* MultisampledTexture(ast::TextureDimension d,
+ const Type* t);
+ /// @param d the texture dimensions
+ /// @param t the sampled texture type
+ /// @return a SampledTexture type. Repeated calls with the same arguments will
+ /// return the same pointer.
+ const spirv::SampledTexture* SampledTexture(ast::TextureDimension d,
+ const Type* t);
+ /// @param d the texture dimensions
+ /// @param f the storage image format
+ /// @return a StorageTexture type. Repeated calls with the same arguments will
+ /// return the same pointer.
+ const spirv::StorageTexture* StorageTexture(ast::TextureDimension d,
+ ast::ImageFormat f);
+
+ private:
+ struct State;
+ std::unique_ptr<State> state;
+};
+
+} // namespace spirv
+} // namespace reader
+} // namespace tint
+
+#endif // SRC_READER_SPIRV_PARSER_TYPE_H_
diff --git a/src/reader/spirv/parser_type_test.cc b/src/reader/spirv/parser_type_test.cc
new file mode 100644
index 0000000..523e736
--- /dev/null
+++ b/src/reader/spirv/parser_type_test.cc
@@ -0,0 +1,104 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+
+#include "src/reader/spirv/parser_type.h"
+
+namespace tint {
+namespace reader {
+namespace spirv {
+namespace {
+
+TEST(SpvParserTypeTest, SameArgumentsGivesSamePointer) {
+ Symbol sym(Symbol(1, {}));
+
+ TypeManager ty;
+ EXPECT_EQ(ty.Void(), ty.Void());
+ EXPECT_EQ(ty.Bool(), ty.Bool());
+ EXPECT_EQ(ty.U32(), ty.U32());
+ EXPECT_EQ(ty.F32(), ty.F32());
+ EXPECT_EQ(ty.I32(), ty.I32());
+ EXPECT_EQ(ty.Pointer(ty.I32(), ast::StorageClass::kNone),
+ ty.Pointer(ty.I32(), ast::StorageClass::kNone));
+ EXPECT_EQ(ty.Vector(ty.I32(), 3), ty.Vector(ty.I32(), 3));
+ EXPECT_EQ(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.I32(), 3, 2));
+ EXPECT_EQ(ty.Array(ty.I32(), 3, 2), ty.Array(ty.I32(), 3, 2));
+ EXPECT_EQ(ty.AccessControl(ty.I32(), ast::AccessControl::Access::kReadOnly),
+ ty.AccessControl(ty.I32(), ast::AccessControl::Access::kReadOnly));
+ EXPECT_EQ(ty.Alias(sym, ty.I32()), ty.Alias(sym, ty.I32()));
+ EXPECT_EQ(ty.Struct(sym, {ty.I32()}), ty.Struct(sym, {ty.I32()}));
+ EXPECT_EQ(ty.Sampler(ast::SamplerKind::kSampler),
+ ty.Sampler(ast::SamplerKind::kSampler));
+ EXPECT_EQ(ty.DepthTexture(ast::TextureDimension::k2d),
+ ty.DepthTexture(ast::TextureDimension::k2d));
+ EXPECT_EQ(ty.MultisampledTexture(ast::TextureDimension::k2d, ty.I32()),
+ ty.MultisampledTexture(ast::TextureDimension::k2d, ty.I32()));
+ EXPECT_EQ(ty.SampledTexture(ast::TextureDimension::k2d, ty.I32()),
+ ty.SampledTexture(ast::TextureDimension::k2d, ty.I32()));
+ EXPECT_EQ(
+ ty.StorageTexture(ast::TextureDimension::k2d, ast::ImageFormat::kR16Sint),
+ ty.StorageTexture(ast::TextureDimension::k2d,
+ ast::ImageFormat::kR16Sint));
+}
+
+TEST(SpvParserTypeTest, DifferentArgumentsGivesDifferentPointer) {
+ Symbol sym_a(Symbol(1, {}));
+ Symbol sym_b(Symbol(2, {}));
+
+ TypeManager ty;
+ EXPECT_NE(ty.Pointer(ty.I32(), ast::StorageClass::kNone),
+ ty.Pointer(ty.U32(), ast::StorageClass::kNone));
+ EXPECT_NE(ty.Pointer(ty.I32(), ast::StorageClass::kNone),
+ ty.Pointer(ty.I32(), ast::StorageClass::kInput));
+ EXPECT_NE(ty.Vector(ty.I32(), 3), ty.Vector(ty.U32(), 3));
+ EXPECT_NE(ty.Vector(ty.I32(), 3), ty.Vector(ty.I32(), 2));
+ EXPECT_NE(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.U32(), 3, 2));
+ EXPECT_NE(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.I32(), 2, 2));
+ EXPECT_NE(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.I32(), 3, 3));
+ EXPECT_NE(ty.Array(ty.I32(), 3, 2), ty.Array(ty.U32(), 3, 2));
+ EXPECT_NE(ty.Array(ty.I32(), 3, 2), ty.Array(ty.I32(), 2, 2));
+ EXPECT_NE(ty.Array(ty.I32(), 3, 2), ty.Array(ty.I32(), 3, 3));
+ EXPECT_NE(ty.AccessControl(ty.I32(), ast::AccessControl::Access::kReadOnly),
+ ty.AccessControl(ty.U32(), ast::AccessControl::Access::kReadOnly));
+ EXPECT_NE(ty.AccessControl(ty.I32(), ast::AccessControl::Access::kReadOnly),
+ ty.AccessControl(ty.I32(), ast::AccessControl::Access::kWriteOnly));
+ EXPECT_NE(ty.Alias(sym_a, ty.I32()), ty.Alias(sym_b, ty.I32()));
+ EXPECT_NE(ty.Struct(sym_a, {ty.I32()}), ty.Struct(sym_b, {ty.I32()}));
+ EXPECT_NE(ty.Sampler(ast::SamplerKind::kSampler),
+ ty.Sampler(ast::SamplerKind::kComparisonSampler));
+ EXPECT_NE(ty.DepthTexture(ast::TextureDimension::k2d),
+ ty.DepthTexture(ast::TextureDimension::k1d));
+ EXPECT_NE(ty.MultisampledTexture(ast::TextureDimension::k2d, ty.I32()),
+ ty.MultisampledTexture(ast::TextureDimension::k3d, ty.I32()));
+ EXPECT_NE(ty.MultisampledTexture(ast::TextureDimension::k2d, ty.I32()),
+ ty.MultisampledTexture(ast::TextureDimension::k2d, ty.U32()));
+ EXPECT_NE(ty.SampledTexture(ast::TextureDimension::k2d, ty.I32()),
+ ty.SampledTexture(ast::TextureDimension::k3d, ty.I32()));
+ EXPECT_NE(ty.SampledTexture(ast::TextureDimension::k2d, ty.I32()),
+ ty.SampledTexture(ast::TextureDimension::k2d, ty.U32()));
+ EXPECT_NE(
+ ty.StorageTexture(ast::TextureDimension::k2d, ast::ImageFormat::kR16Sint),
+ ty.StorageTexture(ast::TextureDimension::k3d,
+ ast::ImageFormat::kR16Sint));
+ EXPECT_NE(
+ ty.StorageTexture(ast::TextureDimension::k2d, ast::ImageFormat::kR16Sint),
+ ty.StorageTexture(ast::TextureDimension::k2d,
+ ast::ImageFormat::kR32Sint));
+}
+
+} // namespace
+} // namespace spirv
+} // namespace reader
+} // namespace tint
diff --git a/test/BUILD.gn b/test/BUILD.gn
index 740ce6b..b691997 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -349,6 +349,7 @@
"../src/reader/spirv/parser_impl_test_helper.cc",
"../src/reader/spirv/parser_impl_test_helper.h",
"../src/reader/spirv/parser_impl_user_name_test.cc",
+ "../src/reader/spirv/parser_type_test.cc",
"../src/reader/spirv/parser_test.cc",
"../src/reader/spirv/spirv_tools_helpers_test.cc",
"../src/reader/spirv/spirv_tools_helpers_test.h",