reader/spirv - add type hierarchy Don't create disjoint AST type nodes. Instead use a new bespoke type hierarchy that can Build() the required AST nodes. Change-Id: I523f97054de2c553095056c0bafc17c48064cf53 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49966 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: David Neto <dneto@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn index 0be18bd..4cb6563 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn
@@ -590,6 +590,8 @@ "reader/spirv/function.h", "reader/spirv/namer.cc", "reader/spirv/namer.h", + "reader/spirv/parser_type.cc", + "reader/spirv/parser_type.h", "reader/spirv/parser.cc", "reader/spirv/parser.h", "reader/spirv/parser_impl.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b067715..4b2467d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt
@@ -360,6 +360,8 @@ reader/spirv/function.h reader/spirv/namer.cc reader/spirv/namer.h + reader/spirv/parser_type.cc + reader/spirv/parser_type.h reader/spirv/parser.cc reader/spirv/parser.h reader/spirv/parser_impl.cc @@ -627,6 +629,7 @@ reader/spirv/parser_impl_test_helper.h reader/spirv/parser_impl_test.cc reader/spirv/parser_impl_user_name_test.cc + reader/spirv/parser_type_test.cc reader/spirv/parser_test.cc reader/spirv/spirv_tools_helpers_test.cc reader/spirv/spirv_tools_helpers_test.h
diff --git a/src/program_builder.h b/src/program_builder.h index b013ffb..41d3842 100644 --- a/src/program_builder.h +++ b/src/program_builder.h
@@ -617,26 +617,32 @@ /// @param subtype the array element type /// @param n the array size. 0 represents a runtime-array - /// @param stride the array stride + /// @param stride the array stride. 0 represents implicit stride /// @return the tint AST type for a array of size `n` of type `T` ast::Array* array(typ::Type subtype, uint32_t n, uint32_t stride) const { subtype = MaybeCreateTypename(subtype); - return array(subtype, n, - {builder->create<ast::StrideDecoration>(stride)}); + ast::DecorationList decos; + if (stride) { + decos.emplace_back(builder->create<ast::StrideDecoration>(stride)); + } + return array(subtype, n, std::move(decos)); } /// @param source the Source of the node /// @param subtype the array element type /// @param n the array size. 0 represents a runtime-array - /// @param stride the array stride + /// @param stride the array stride. 0 represents implicit stride /// @return the tint AST type for a array of size `n` of type `T` ast::Array* array(const Source& source, typ::Type subtype, uint32_t n, uint32_t stride) const { subtype = MaybeCreateTypename(subtype); - return array(source, subtype, n, - {builder->create<ast::StrideDecoration>(stride)}); + ast::DecorationList decos; + if (stride) { + decos.emplace_back(builder->create<ast::StrideDecoration>(stride)); + } + return array(source, subtype, n, std::move(decos)); } /// @return the tint AST type for an array of size `N` of type `T`
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 1cbbc34..e9c15c9 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc
@@ -713,6 +713,7 @@ const spvtools::opt::Function& function, const EntryPointInfo* ep_info) : parser_impl_(*pi), + ty_(pi->type_manager()), builder_(pi->builder()), ir_context_(*(pi->ir_context())), def_use_mgr_(ir_context_.get_def_use_mgr()), @@ -733,6 +734,7 @@ FunctionEmitter::FunctionEmitter(FunctionEmitter&& other) : parser_impl_(other.parser_impl_), + ty_(other.ty_), builder_(other.builder_), ir_context_(other.ir_context_), def_use_mgr_(ir_context_.get_def_use_mgr()), @@ -875,7 +877,7 @@ auto* body = create<ast::BlockStatement>(Source{}, statements); builder_.AST().AddFunction(create<ast::Function>( decl.source, builder_.Symbols().Register(decl.name), - std::move(decl.params), decl.return_type, body, + std::move(decl.params), decl.return_type->Build(builder_), body, std::move(decl.decorations), ast::DecorationList{})); // Maintain the invariant by repopulating the one and only element. @@ -943,7 +945,7 @@ return success(); } -ast::Type* FunctionEmitter::GetVariableStoreType( +const Type* FunctionEmitter::GetVariableStoreType( const spvtools::opt::Instruction& var_decl_inst) { const auto type_id = var_decl_inst.type_id(); auto* var_ref_type = type_mgr_->GetType(type_id); @@ -2042,7 +2044,7 @@ << id; return {}; case SkipReason::kPointSizeBuiltinValue: { - return {create<ast::F32>(), + return {ty_.F32(), create<ast::ScalarConstructorExpression>( Source{}, create<ast::FloatLiteral>(Source{}, 1.0f))}; } @@ -2072,7 +2074,7 @@ case SkipReason::kSampleMaskOutBuiltinPointer: // The result type is always u32. auto name = namer_.Name(sample_mask_out_id); - return TypedExpression{builder_.ty.u32(), + return TypedExpression{ty_.U32(), create<ast::IdentifierExpression>( Source{}, builder_.Symbols().Register(name))}; } @@ -2348,14 +2350,10 @@ const std::string guard_name = block_info.flow_guard_name; if (!guard_name.empty()) { // Declare the guard variable just before the "if", initialized to true. - auto* guard_var = create<ast::Variable>( - Source{}, // source - builder_.Symbols().Register(guard_name), // symbol - ast::StorageClass::kFunction, // storage_class - builder_.ty.bool_(), // type - false, // is_const - MakeTrue(Source{}), // constructor - ast::DecorationList{}); // decorations + auto* guard_var = + create<ast::Variable>(Source{}, builder_.Symbols().Register(guard_name), + ast::StorageClass::kFunction, builder_.ty.bool_(), + false, MakeTrue(Source{}), ast::DecorationList{}); auto* guard_decl = create<ast::VariableDeclStatement>(Source{}, guard_var); AddStatement(guard_decl); } @@ -2557,7 +2555,7 @@ // The rest of this module can handle up to 64 bit switch values. // The Tint AST handles 32-bit values. const uint32_t value32 = uint32_t(value & 0xFFFFFFFF); - if (selector.type->is_unsigned_scalar_or_vector()) { + if (selector.type->IsUnsignedScalarOrVector()) { selectors.emplace_back(create<ast::UintLiteral>(Source{}, value32)); } else { selectors.emplace_back(create<ast::SintLiteral>(Source{}, value32)); @@ -2914,13 +2912,10 @@ const auto phi_var_name = GetDefInfo(id)->phi_var; TINT_ASSERT(!phi_var_name.empty()); auto* var = create<ast::Variable>( - Source{}, // source - builder_.Symbols().Register(phi_var_name), // symbol - ast::StorageClass::kFunction, // storage_class - parser_impl_.ConvertType(def_inst->type_id()), // type - false, // is_const - nullptr, // constructor - ast::DecorationList{}); // decorations + Source{}, builder_.Symbols().Register(phi_var_name), + ast::StorageClass::kFunction, + parser_impl_.ConvertType(def_inst->type_id())->Build(builder_), false, + nullptr, ast::DecorationList{}); AddStatement(create<ast::VariableDeclStatement>(Source{}, var)); } @@ -3102,9 +3097,9 @@ case SkipReason::kSampleMaskOutBuiltinPointer: ptr_id = sample_mask_out_id; - if (!rhs.type->Is<ast::U32>()) { + if (!rhs.type->Is<U32>()) { // WGSL requires sample_mask_out to be signed. - rhs = TypedExpression{builder_.ty.u32(), + rhs = TypedExpression{ty_.U32(), create<ast::TypeConstructorExpression>( Source{}, builder_.ty.u32(), ast::ExpressionList{rhs.expr})}; @@ -3148,7 +3143,7 @@ ast::Expression* id_expr = create<ast::IdentifierExpression>( Source{}, builder_.Symbols().Register(name)); auto expr = TypedExpression{ - builder_.ty.i32(), + ty_.I32(), create<ast::TypeConstructorExpression>( Source{}, builder_.ty.i32(), ast::ExpressionList{id_expr})}; return EmitConstDefinition(inst, expr); @@ -3159,10 +3154,10 @@ Source{}, builder_.Symbols().Register(name)); auto* load_result_type = parser_impl_.ConvertType(inst.type_id()); ast::Expression* ast_expr = nullptr; - if (load_result_type->Is<ast::I32>()) { + if (load_result_type->Is<I32>()) { ast_expr = create<ast::TypeConstructorExpression>( Source{}, builder_.ty.i32(), ast::ExpressionList{id_expr}); - } else if (load_result_type->Is<ast::U32>()) { + } else if (load_result_type->Is<U32>()) { ast_expr = id_expr; } else { return Fail() << "loading the whole SampleMask input array is not " @@ -3181,8 +3176,8 @@ } // The load result type is the pointee type of its operand. - TINT_ASSERT(expr.type->Is<ast::Pointer>()); - expr.type = expr.type->As<ast::Pointer>()->type(); + TINT_ASSERT(expr.type->Is<Pointer>()); + expr.type = expr.type->As<Pointer>()->type; return EmitConstDefOrWriteToHoistedVar(inst, expr); } @@ -3284,7 +3279,7 @@ const auto opcode = inst.opcode(); - ast::Type* ast_type = + const Type* ast_type = inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr; auto binary_op = ConvertBinaryOp(opcode); @@ -3329,8 +3324,9 @@ } if (opcode == SpvOpBitcast) { - return {ast_type, create<ast::BitcastExpression>( - Source{}, ast_type, MakeOperand(inst, 0).expr)}; + return {ast_type, + create<ast::BitcastExpression>(Source{}, ast_type->Build(builder_), + MakeOperand(inst, 0).expr)}; } if (opcode == SpvOpShiftLeftLogical || opcode == SpvOpShiftRightLogical || @@ -3390,9 +3386,9 @@ for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) { operands.emplace_back(MakeOperand(inst, iarg).expr); } - return {ast_type, create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(ast_type), - std::move(operands))}; + return {ast_type, + create<ast::TypeConstructorExpression>( + Source{}, ast_type->Build(builder_), std::move(operands))}; } if (opcode == SpvOpCompositeExtract) { @@ -3457,7 +3453,7 @@ auto* func = create<ast::IdentifierExpression>( Source{}, builder_.Symbols().Register(name)); ast::ExpressionList operands; - ast::Type* first_operand_type = nullptr; + const Type* first_operand_type = nullptr; // All parameters to GLSL.std.450 extended instructions are IDs. for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) { TypedExpression operand = MakeOperand(inst, iarg); @@ -3703,7 +3699,7 @@ type_mgr_->FindPointerToType(pointee_type_id, storage_class); auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id); TINT_ASSERT(ast_pointer_type); - TINT_ASSERT(ast_pointer_type->Is<ast::Pointer>()); + TINT_ASSERT(ast_pointer_type->Is<Pointer>()); current_expr = TypedExpression{ast_pointer_type, next_expr}; } return current_expr; @@ -3887,8 +3883,8 @@ // Generate an ast::TypeConstructor expression. // Assume the literal indices are valid, and there is a valid number of them. auto source = GetSourceForInst(inst); - ast::Vector* result_type = - parser_impl_.ConvertType(inst.type_id())->As<ast::Vector>(); + const Vector* result_type = + As<Vector>(parser_impl_.ConvertType(inst.type_id())); ast::ExpressionList values; for (uint32_t i = 2; i < inst.NumInOperands(); ++i) { const auto index = inst.GetSingleWordInOperand(i); @@ -3910,16 +3906,15 @@ source, expr.expr, Swizzle(sub_index))); } else if (index == 0xFFFFFFFF) { // By rule, this maps to OpUndef. Instead, make it zero. - values.emplace_back(parser_impl_.MakeNullValue(result_type->type())); + values.emplace_back(parser_impl_.MakeNullValue(result_type->type)); } else { Fail() << "invalid vectorshuffle ID %" << inst.result_id() << ": index too large: " << index; return {}; } } - return {result_type, - create<ast::TypeConstructorExpression>( - source, builder_.ty.MaybeCreateTypename(result_type), values)}; + return {result_type, create<ast::TypeConstructorExpression>( + source, result_type->Build(builder_), values)}; } bool FunctionEmitter::RegisterSpecialBuiltInVariables() { @@ -3988,8 +3983,8 @@ if (type) { if (type->AsPointer()) { if (auto* ast_type = parser_impl_.ConvertType(inst.type_id())) { - if (auto* ptr = ast_type->As<ast::Pointer>()) { - info->storage_class = ptr->storage_class(); + if (auto* ptr = ast_type->As<Pointer>()) { + info->storage_class = ptr->storage_class; } } switch (inst.opcode()) { @@ -4033,21 +4028,21 @@ const auto type_id = def_use_mgr_->GetDef(id)->type_id(); if (type_id) { auto* ast_type = parser_impl_.ConvertType(type_id); - if (auto* ptr = As<ast::Pointer>(ast_type)) { - return ptr->storage_class(); + if (auto* ptr = As<Pointer>(ast_type)) { + return ptr->storage_class; } } return ast::StorageClass::kNone; } -ast::Type* FunctionEmitter::RemapStorageClass(ast::Type* type, - uint32_t result_id) { - if (auto* ast_ptr_type = type->As<ast::Pointer>()) { +const Type* FunctionEmitter::RemapStorageClass(const Type* type, + uint32_t result_id) { + if (auto* ast_ptr_type = As<Pointer>(type)) { // Remap an old-style storage buffer pointer to a new-style storage // buffer pointer. const auto sc = GetStorageClassForPointerValue(result_id); - if (ast_ptr_type->storage_class() != sc) { - return builder_.ty.pointer(ast_ptr_type->type(), sc); + if (ast_ptr_type->storage_class != sc) { + return ty_.Pointer(ast_ptr_type->type, sc); } } return type; @@ -4230,30 +4225,27 @@ return {}; } - ast::Type* expr_type = nullptr; + const Type* expr_type = nullptr; if ((opcode == SpvOpConvertSToF) || (opcode == SpvOpConvertUToF)) { - if (arg_expr.type->is_integer_scalar_or_vector()) { + if (arg_expr.type->IsIntegerScalarOrVector()) { expr_type = requested_type; } else { Fail() << "operand for conversion to floating point must be integral " - "scalar or vector, but got: " - << arg_expr.type->type_name(); + "scalar or vector"; } } else if (inst.opcode() == SpvOpConvertFToU) { - if (arg_expr.type->is_float_scalar_or_vector()) { + if (arg_expr.type->IsFloatScalarOrVector()) { expr_type = parser_impl_.GetUnsignedIntMatchingShape(arg_expr.type); } else { Fail() << "operand for conversion to unsigned integer must be floating " - "point scalar or vector, but got: " - << arg_expr.type->type_name(); + "point scalar or vector"; } } else if (inst.opcode() == SpvOpConvertFToS) { - if (arg_expr.type->is_float_scalar_or_vector()) { + if (arg_expr.type->IsFloatScalarOrVector()) { expr_type = parser_impl_.GetSignedIntMatchingShape(arg_expr.type); } else { Fail() << "operand for conversion to signed integer must be floating " - "point scalar or vector, but got: " - << arg_expr.type->type_name(); + "point scalar or vector"; } } if (expr_type == nullptr) { @@ -4265,14 +4257,14 @@ params.push_back(arg_expr.expr); TypedExpression result{ expr_type, create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(expr_type), - std::move(params))}; + Source{}, expr_type->Build(builder_), std::move(params))}; - if (AstTypesEquivalent(requested_type, expr_type)) { + if (requested_type == expr_type) { return result; } - return {requested_type, create<ast::BitcastExpression>( - Source{}, requested_type, result.expr)}; + return {requested_type, + create<ast::BitcastExpression>( + Source{}, requested_type->Build(builder_), result.expr)}; } bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) { @@ -4296,7 +4288,7 @@ << inst.PrettyPrint(); } - if (result_type->Is<ast::Void>()) { + if (result_type->Is<Void>()) { return nullptr != AddStatement(create<ast::CallStatement>(Source{}, call_expr)); } @@ -4359,7 +4351,7 @@ Source{}, builder_.Symbols().Register(name)); ast::ExpressionList params; - ast::Type* first_operand_type = nullptr; + const Type* first_operand_type = nullptr; for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) { TypedExpression operand = MakeOperand(inst, iarg); if (first_operand_type == nullptr) { @@ -4391,8 +4383,8 @@ // - you can't select over pointers or pointer vectors, unless you also have // a VariablePointers* capability, which is not allowed in by WebGPU. auto* op_ty = operand1.type; - if (op_ty->Is<ast::Vector>() || op_ty->is_float_scalar() || - op_ty->is_integer_scalar() || op_ty->Is<ast::Bool>()) { + if (op_ty->Is<Vector>() || op_ty->IsFloatScalar() || + op_ty->IsIntegerScalar() || op_ty->Is<Bool>()) { ast::ExpressionList params; params.push_back(operand1.expr); params.push_back(operand2.expr); @@ -4430,9 +4422,9 @@ return image; } -ast::Texture* FunctionEmitter::GetImageType( +const Texture* FunctionEmitter::GetImageType( const spvtools::opt::Instruction& image) { - ast::Pointer* ptr_type = parser_impl_.GetTypeForHandleVar(image); + const Pointer* ptr_type = parser_impl_.GetTypeForHandleVar(image); if (!parser_impl_.success()) { Fail(); return {}; @@ -4441,7 +4433,7 @@ Fail() << "invalid texture type for " << image.PrettyPrint(); return {}; } - auto* result = ptr_type->type()->UnwrapAll()->As<ast::Texture>(); + auto* result = ptr_type->type->UnwrapAll()->As<Texture>(); if (!result) { Fail() << "invalid texture type for " << image.PrettyPrint(); return {}; @@ -4496,12 +4488,12 @@ } } - ast::Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image); + const Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image); if (!texture_ptr_type) { return Fail(); } - ast::Texture* texture_type = - texture_ptr_type->type()->UnwrapAll()->As<ast::Texture>(); + const Texture* texture_type = + texture_ptr_type->type->UnwrapAll()->As<Texture>(); if (!texture_type) { return Fail(); @@ -4604,7 +4596,7 @@ } TypedExpression lod = MakeOperand(inst, arg_index); // When sampling from a depth texture, the Lod operand must be an I32. - if (texture_type->Is<ast::DepthTexture>()) { + if (texture_type->Is<DepthTexture>()) { // Convert it to a signed integer type. lod = ToI32(lod); } @@ -4612,11 +4604,11 @@ image_operands_mask ^= SpvImageOperandsLodMask; arg_index++; } else if ((opcode == SpvOpImageFetch) && - (texture_type->Is<ast::SampledTexture>() || - texture_type->Is<ast::DepthTexture>())) { + (texture_type->Is<SampledTexture>() || + texture_type->Is<DepthTexture>())) { // textureLoad on sampled texture and depth texture requires an explicit // level-of-detail parameter. - params.push_back(parser_impl_.MakeNullValue(builder_.ty.i32())); + params.push_back(parser_impl_.MakeNullValue(ty_.I32())); } if (arg_index + 1 < num_args && (image_operands_mask & SpvImageOperandsGradMask)) { @@ -4637,7 +4629,7 @@ return Fail() << "ConstOffset is only permitted for sampling operations: " << inst.PrettyPrint(); } - switch (texture_type->dim()) { + switch (texture_type->dims) { case ast::TextureDimension::k2d: case ast::TextureDimension::k2dArray: case ast::TextureDimension::k3d: @@ -4676,8 +4668,8 @@ // The result type, derived from the SPIR-V instruction. auto* result_type = parser_impl_.ConvertType(inst.type_id()); auto* result_component_type = result_type; - if (auto* result_vector_type = result_type->As<ast::Vector>()) { - result_component_type = result_vector_type->type(); + if (auto* result_vector_type = As<Vector>(result_type)) { + result_component_type = result_vector_type->type; } // For depth textures, the arity might mot match WGSL: @@ -4691,11 +4683,11 @@ // dref gather vec4 ImageFetch vec4 TODO(dneto) // Construct a 4-element vector with the result from the builtin in the // first component. - if (texture_type->Is<ast::DepthTexture>()) { + if (texture_type->Is<DepthTexture>()) { if (is_non_dref_sample || (opcode == SpvOpImageFetch)) { value = create<ast::TypeConstructorExpression>( Source{}, - builder_.ty.MaybeCreateTypename(result_type), // a vec4 + result_type->Build(builder_), // a vec4 ast::ExpressionList{ value, parser_impl_.MakeNullValue(result_component_type), parser_impl_.MakeNullValue(result_component_type), @@ -4714,13 +4706,13 @@ } auto* expected_component_type = parser_impl_.ConvertType(spirv_image_type->GetSingleWordInOperand(0)); - if (!AstTypesEquivalent(expected_component_type, result_component_type)) { + if (expected_component_type != result_component_type) { // This occurs if one is signed integer and the other is unsigned integer, // or vice versa. Perform a bitcast. - value = create<ast::BitcastExpression>(Source{}, result_type, call_expr); + value = create<ast::BitcastExpression>( + Source{}, result_type->Build(builder_), call_expr); } - if (!expected_component_type->Is<ast::F32>() && - IsSampledImageAccess(opcode)) { + if (!expected_component_type->Is<F32>() && IsSampledImageAccess(opcode)) { // WGSL permits sampled image access only on float textures. // Reject this case in the SPIR-V reader, at least until SPIR-V validation // catches up with this rule and can reject it earlier in the workflow. @@ -4763,7 +4755,7 @@ } exprs.push_back( create<ast::CallExpression>(Source{}, dims_ident, dims_args)); - if (ast::IsTextureArray(texture_type->dim())) { + if (ast::IsTextureArray(texture_type->dims)) { auto* layers_ident = create<ast::IdentifierExpression>( Source{}, builder_.Symbols().Register("textureNumLayers")); exprs.push_back(create<ast::CallExpression>( @@ -4772,9 +4764,8 @@ } auto* result_type = parser_impl_.ConvertType(inst.type_id()); TypedExpression expr = { - result_type, - create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(result_type), exprs)}; + result_type, create<ast::TypeConstructorExpression>( + Source{}, result_type->Build(builder_), exprs)}; return EmitConstDefOrWriteToHoistedVar(inst, expr); } case SpvOpImageQueryLod: @@ -4794,9 +4785,9 @@ auto* result_type = parser_impl_.ConvertType(inst.type_id()); // The SPIR-V result type must be integer scalar. The WGSL bulitin // returns i32. If they aren't the same then convert the result. - if (!result_type->Is<ast::I32>()) { + if (!result_type->Is<I32>()) { ast_expr = create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(result_type), + Source{}, result_type->Build(builder_), ast::ExpressionList{ast_expr}); } TypedExpression expr{result_type, ast_expr}; @@ -4840,28 +4831,27 @@ if (!raw_coords.type) { return {}; } - ast::Texture* texture_type = GetImageType(*image); + const Texture* texture_type = GetImageType(*image); if (!texture_type) { return {}; } - ast::TextureDimension dim = texture_type->dim(); + ast::TextureDimension dim = texture_type->dims; // Number of regular coordinates. uint32_t num_axes = ast::NumCoordinateAxes(dim); bool is_arrayed = ast::IsTextureArray(dim); if ((num_axes == 0) || (num_axes > 3)) { Fail() << "unsupported image dimensionality for " - << texture_type->type_name() << " prompted by " + << texture_type->TypeInfo().name << " prompted by " << inst.PrettyPrint(); } const auto num_coords_required = num_axes + (is_arrayed ? 1 : 0); uint32_t num_coords_supplied = 0; auto* component_type = raw_coords.type; - if (component_type->is_float_scalar() || - component_type->is_integer_scalar()) { + if (component_type->IsFloatScalar() || component_type->IsIntegerScalar()) { num_coords_supplied = 1; - } else if (auto* vec_type = raw_coords.type->As<ast::Vector>()) { - component_type = vec_type->type(); - num_coords_supplied = vec_type->size(); + } else if (auto* vec_type = As<Vector>(raw_coords.type)) { + component_type = vec_type->type; + num_coords_supplied = vec_type->size; } if (num_coords_supplied == 0) { Fail() << "bad or unsupported coordinate type for image access: " @@ -4884,10 +4874,8 @@ // will actually use them. auto prefix_swizzle_expr = [this, num_axes, component_type, raw_coords]() -> ast::Expression* { - auto* swizzle_type = (num_axes == 1) - ? component_type - : static_cast<ast::Type*>( - builder_.ty.vec(component_type, num_axes)); + auto* swizzle_type = + (num_axes == 1) ? component_type : ty_.Vector(component_type, num_axes); auto* swizzle = create<ast::MemberAccessorExpression>( Source{}, raw_coords.expr, PrefixSwizzle(num_axes)); return ToSignedIfUnsigned({swizzle_type, swizzle}).expr; @@ -4921,32 +4909,32 @@ ast::Expression* FunctionEmitter::ConvertTexelForStorage( const spvtools::opt::Instruction& inst, TypedExpression texel, - ast::Texture* texture_type) { - auto* storage_texture_type = texture_type->As<ast::StorageTexture>(); + const Texture* texture_type) { + auto* storage_texture_type = As<StorageTexture>(texture_type); auto* src_type = texel.type; if (!storage_texture_type) { Fail() << "writing to other than storage texture: " << inst.PrettyPrint(); return nullptr; } - const auto format = storage_texture_type->image_format(); + const auto format = storage_texture_type->format; auto* dest_type = parser_impl_.GetTexelTypeForFormat(format); if (!dest_type) { Fail(); return nullptr; } - if (AstTypesEquivalent(src_type, dest_type)) { + if (src_type == dest_type) { return texel.expr; } const uint32_t dest_count = - dest_type->is_scalar() ? 1 : dest_type->As<ast::Vector>()->size(); + dest_type->IsScalar() ? 1 : dest_type->As<Vector>()->size; if (dest_count == 3) { Fail() << "3-channel storage textures are not supported: " << inst.PrettyPrint(); return nullptr; } const uint32_t src_count = - src_type->is_scalar() ? 1 : src_type->As<ast::Vector>()->size(); + src_type->IsScalar() ? 1 : src_type->As<Vector>()->size; if (src_count < dest_count) { Fail() << "texel has too few components for storage texture: " << src_count << " provided but " << dest_count @@ -4961,29 +4949,29 @@ : create<ast::MemberAccessorExpression>(Source{}, texel.expr, PrefixSwizzle(dest_count)); - if (!(dest_type->is_float_scalar_or_vector() || - dest_type->is_unsigned_scalar_or_vector() || - dest_type->is_signed_scalar_or_vector())) { + if (!(dest_type->IsFloatScalarOrVector() || + dest_type->IsUnsignedScalarOrVector() || + dest_type->IsSignedScalarOrVector())) { Fail() << "invalid destination type for storage texture write: " - << dest_type->type_name(); + << dest_type->TypeInfo().name; return nullptr; } - if (!(src_type->is_float_scalar_or_vector() || - src_type->is_unsigned_scalar_or_vector() || - src_type->is_signed_scalar_or_vector())) { + if (!(src_type->IsFloatScalarOrVector() || + src_type->IsUnsignedScalarOrVector() || + src_type->IsSignedScalarOrVector())) { Fail() << "invalid texel type for storage texture write: " << inst.PrettyPrint(); return nullptr; } - if (dest_type->is_float_scalar_or_vector() && - !src_type->is_float_scalar_or_vector()) { + if (dest_type->IsFloatScalarOrVector() && + !src_type->IsFloatScalarOrVector()) { Fail() << "can only write float or float vector to a storage image with " "floating texel format: " << inst.PrettyPrint(); return nullptr; } - if (!dest_type->is_float_scalar_or_vector() && - src_type->is_float_scalar_or_vector()) { + if (!dest_type->IsFloatScalarOrVector() && + src_type->IsFloatScalarOrVector()) { Fail() << "float or float vector can only be written to a storage image with " "floating texel format: " @@ -4991,36 +4979,37 @@ return nullptr; } - if (dest_type->is_float_scalar_or_vector()) { + if (dest_type->IsFloatScalarOrVector()) { return texel_prefix; } // The only remaining cases are signed/unsigned source, and signed/unsigned // destination. - if (dest_type->is_unsigned_scalar_or_vector() == - src_type->is_unsigned_scalar_or_vector()) { + if (dest_type->IsUnsignedScalarOrVector() == + src_type->IsUnsignedScalarOrVector()) { return texel_prefix; } // We must do a bitcast conversion. - return create<ast::BitcastExpression>(Source{}, dest_type, texel_prefix); + return create<ast::BitcastExpression>(Source{}, dest_type->Build(builder_), + texel_prefix); } TypedExpression FunctionEmitter::ToI32(TypedExpression value) { - if (!value.type || value.type->Is<ast::I32>()) { + if (!value.type || value.type->Is<I32>()) { return value; } - return {builder_.ty.i32(), + return {ty_.I32(), create<ast::TypeConstructorExpression>( Source{}, builder_.ty.i32(), ast::ExpressionList{value.expr})}; } TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) { - if (!value.type || !value.type->is_unsigned_scalar_or_vector()) { + if (!value.type || !value.type->IsUnsignedScalarOrVector()) { return value; } - if (auto* vec_type = value.type->As<ast::Vector>()) { - auto new_type = builder_.ty.vec(builder_.ty.i32(), vec_type->size()); - return {new_type, - builder_.Construct(new_type, ast::ExpressionList{value.expr})}; + if (auto* vec_type = value.type->As<Vector>()) { + auto* new_type = ty_.Vector(ty_.I32(), vec_type->size); + return {new_type, builder_.Construct(new_type->Build(builder_), + ast::ExpressionList{value.expr})}; } return ToI32(value); } @@ -5073,14 +5062,12 @@ // Synthesize the result. auto col = MakeOperand(inst, 0); auto row = MakeOperand(inst, 1); - auto* col_ty = col.type->As<ast::Vector>(); - auto* row_ty = row.type->As<ast::Vector>(); - auto* result_ty = parser_impl_.ConvertType(inst.type_id())->As<ast::Matrix>(); - if (!col_ty || !col_ty || !result_ty || - !AstTypesEquivalent(result_ty->type(), col_ty->type()) || - !AstTypesEquivalent(result_ty->type(), row_ty->type()) || - result_ty->columns() != row_ty->size() || - result_ty->rows() != col_ty->size()) { + auto* col_ty = As<Vector>(col.type); + auto* row_ty = As<Vector>(row.type); + auto* result_ty = As<Matrix>(parser_impl_.ConvertType(inst.type_id())); + if (!col_ty || !col_ty || !result_ty || result_ty->type != col_ty->type || + result_ty->type != row_ty->type || result_ty->columns != row_ty->size || + result_ty->rows != col_ty->size) { Fail() << "invalid outer product instruction: bad types " << inst.PrettyPrint(); return {}; @@ -5096,11 +5083,11 @@ // | c.z * r.x c.z * r.y | ast::ExpressionList result_columns; - for (uint32_t icol = 0; icol < result_ty->columns(); icol++) { + for (uint32_t icol = 0; icol < result_ty->columns; icol++) { ast::ExpressionList result_row; auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr, Swizzle(icol)); - for (uint32_t irow = 0; irow < result_ty->rows(); irow++) { + for (uint32_t irow = 0; irow < result_ty->rows; irow++) { auto* column_factor = create<ast::MemberAccessorExpression>( Source{}, col.expr, Swizzle(irow)); auto* elem = create<ast::BinaryExpression>( @@ -5108,11 +5095,10 @@ result_row.push_back(elem); } result_columns.push_back(create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(col_ty), result_row)); + Source{}, col_ty->Build(builder_), result_row)); } return {result_ty, create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(result_ty), - result_columns)}; + Source{}, result_ty->Build(builder_), result_columns)}; } bool FunctionEmitter::MakeVectorInsertDynamic( @@ -5142,8 +5128,7 @@ auto* temp_var = create<ast::Variable>( Source{}, registered_temp_name, ast::StorageClass::kFunction, - builder_.ty.MaybeCreateTypename(ast_type), false, src_vector.expr, - ast::DecorationList{}); + ast_type->Build(builder_), false, src_vector.expr, ast::DecorationList{}); AddStatement(create<ast::VariableDeclStatement>(Source{}, temp_var)); auto* lhs = create<ast::ArrayAccessorExpression>( @@ -5189,7 +5174,7 @@ auto* temp_var = create<ast::Variable>( Source{}, registered_temp_name, ast::StorageClass::kFunction, - builder_.ty.MaybeCreateTypename(ast_type), false, src_composite.expr, + ast_type->Build(builder_), false, src_composite.expr, ast::DecorationList{}); AddStatement(create<ast::VariableDeclStatement>(Source{}, temp_var));
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 1f41bd1..64c2a93 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h
@@ -515,7 +515,7 @@ /// @param type the AST type /// @param result_id the SPIR-V ID for the locally defined value /// @returns an possibly updated type - ast::Type* RemapStorageClass(ast::Type* type, uint32_t result_id); + const Type* RemapStorageClass(const Type* type, uint32_t result_id); /// Marks locally defined values when they should get a 'const' /// definition in WGSL, or a 'var' definition at an outer scope. @@ -856,7 +856,7 @@ /// Function parameters ast::VariableList params; /// Function return type - ast::Type* return_type; + const Type* return_type; /// Function decorations ast::DecorationList decorations; }; @@ -869,7 +869,7 @@ /// @returns the store type for the OpVariable instruction, or /// null on failure. - ast::Type* GetVariableStoreType( + const Type* GetVariableStoreType( const spvtools::opt::Instruction& var_decl_inst); /// Returns an expression for an instruction operand. Signedness conversion is @@ -937,7 +937,7 @@ /// Get the AST texture the SPIR-V image memory object declaration. /// @param inst the SPIR-V memory object declaration for the image. /// @returns a texture type, or null on error - ast::Texture* GetImageType(const spvtools::opt::Instruction& inst); + const Texture* GetImageType(const spvtools::opt::Instruction& inst); /// Get the expression for the image operand from the first operand to the /// given instruction. @@ -974,7 +974,7 @@ ast::Expression* ConvertTexelForStorage( const spvtools::opt::Instruction& inst, TypedExpression texel, - ast::Texture* texture_type); + const Texture* texture_type); /// Returns an expression for an OpSelect, if its operands are scalars /// or vectors. These translate directly to WGSL select. Otherwise, return @@ -1128,6 +1128,7 @@ using StatementsStack = std::vector<StatementBlock>; ParserImpl& parser_impl_; + TypeManager& ty_; ProgramBuilder& builder_; spvtools::opt::IRContext& ir_context_; spvtools::opt::analysis::DefUseManager* def_use_mgr_;
diff --git a/src/reader/spirv/function_conversion_test.cc b/src/reader/spirv/function_conversion_test.cc index 0d18996..8dc4999 100644 --- a/src/reader/spirv/function_conversion_test.cc +++ b/src/reader/spirv/function_conversion_test.cc
@@ -202,7 +202,7 @@ EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), HasSubstr("operand for conversion to floating point must be " - "integral scalar or vector, but got: __bool")); + "integral scalar or vector")); } TEST_F(SpvUnaryConversionTest, ConvertSToF_Vector_BadArgType) { @@ -220,7 +220,7 @@ EXPECT_THAT( p->error(), HasSubstr("operand for conversion to floating point must be integral " - "scalar or vector, but got: __vec_2__bool")); + "scalar or vector")); } TEST_F(SpvUnaryConversionTest, ConvertSToF_Scalar_FromSigned) { @@ -344,7 +344,7 @@ auto fe = p->function_emitter(100); EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), Eq("operand for conversion to floating point must be " - "integral scalar or vector, but got: __bool")); + "integral scalar or vector")); } TEST_F(SpvUnaryConversionTest, ConvertUToF_Vector_BadArgType) { @@ -361,7 +361,7 @@ EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), Eq("operand for conversion to floating point must be integral " - "scalar or vector, but got: __vec_2__bool")); + "scalar or vector")); } TEST_F(SpvUnaryConversionTest, ConvertUToF_Scalar_FromSigned) { @@ -486,7 +486,7 @@ EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), Eq("operand for conversion to signed integer must be floating " - "point scalar or vector, but got: __u32")); + "point scalar or vector")); } TEST_F(SpvUnaryConversionTest, ConvertFToS_Vector_BadArgType) { @@ -503,7 +503,7 @@ EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), Eq("operand for conversion to signed integer must be floating " - "point scalar or vector, but got: __vec_2__bool")); + "point scalar or vector")); } TEST_F(SpvUnaryConversionTest, ConvertFToS_Scalar_ToSigned) { @@ -628,7 +628,7 @@ EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), Eq("operand for conversion to unsigned integer must be floating " - "point scalar or vector, but got: __u32")); + "point scalar or vector")); } TEST_F(SpvUnaryConversionTest, ConvertFToU_Vector_BadArgType) { @@ -645,7 +645,7 @@ EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), Eq("operand for conversion to unsigned integer must be floating " - "point scalar or vector, but got: __vec_2__bool")); + "point scalar or vector")); } TEST_F(SpvUnaryConversionTest, ConvertFToU_Scalar_ToSigned_IsError) {
diff --git a/src/reader/spirv/function_memory_test.cc b/src/reader/spirv/function_memory_test.cc index 7121502..696e09f 100644 --- a/src/reader/spirv/function_memory_test.cc +++ b/src/reader/spirv/function_memory_test.cc
@@ -836,7 +836,7 @@ Struct S { [[block]] StructMember{[[ offset 0 ]] field0: __u32} - StructMember{[[ offset 4 ]] field1: __alias_RTArr__array__u32_stride_4} + StructMember{[[ offset 4 ]] field1: __type_name_RTArr} } Variable{ Decorations{
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index e1f6484..368b257 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc
@@ -240,7 +240,7 @@ TypedExpression& TypedExpression::operator=(const TypedExpression&) = default; -TypedExpression::TypedExpression(ast::Type* type_in, ast::Expression* expr_in) +TypedExpression::TypedExpression(const Type* type_in, ast::Expression* expr_in) : type(type_in), expr(expr_in) {} ParserImpl::ParserImpl(const std::vector<uint32_t>& spv_binary) @@ -305,7 +305,7 @@ return tint::Program(std::move(builder_)); } -ast::Type* ParserImpl::ConvertType(uint32_t type_id) { +const Type* ParserImpl::ConvertType(uint32_t type_id) { if (!success_) { return nullptr; } @@ -322,18 +322,18 @@ } auto maybe_generate_alias = [this, type_id, - spirv_type](ast::Type* type) -> ast::Type* { + spirv_type](const Type* type) -> const Type* { if (type != nullptr) { return MaybeGenerateAlias(type_id, spirv_type, type); } - return {}; + return type; }; switch (spirv_type->kind()) { case spvtools::opt::analysis::Type::kVoid: - return maybe_generate_alias(builder_.ty.void_()); + return maybe_generate_alias(ty_.Void()); case spvtools::opt::analysis::Type::kBool: - return maybe_generate_alias(builder_.ty.bool_()); + return maybe_generate_alias(ty_.Bool()); case spvtools::opt::analysis::Type::kInteger: return maybe_generate_alias(ConvertType(spirv_type->AsInteger())); case spvtools::opt::analysis::Type::kFloat: @@ -362,7 +362,7 @@ case spvtools::opt::analysis::Type::kImage: // Fake it for sampler and texture types. These are handled in an // entirely different way. - return maybe_generate_alias(builder_.ty.void_()); + return maybe_generate_alias(ty_.Void()); default: break; } @@ -774,36 +774,36 @@ return success_; } -ast::Type* ParserImpl::ConvertType( +const Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Integer* int_ty) { if (int_ty->width() == 32) { - return int_ty->IsSigned() ? static_cast<ast::Type*>(builder_.ty.i32()) - : static_cast<ast::Type*>(builder_.ty.u32()); + return int_ty->IsSigned() ? static_cast<const Type*>(ty_.I32()) + : static_cast<const Type*>(ty_.U32()); } Fail() << "unhandled integer width: " << int_ty->width(); return nullptr; } -ast::Type* ParserImpl::ConvertType( +const Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Float* float_ty) { if (float_ty->width() == 32) { - return builder_.ty.f32(); + return ty_.F32(); } Fail() << "unhandled float width: " << float_ty->width(); return nullptr; } -ast::Type* ParserImpl::ConvertType( +const Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Vector* vec_ty) { const auto num_elem = vec_ty->element_count(); auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type())); if (ast_elem_ty == nullptr) { - return nullptr; + return ast_elem_ty; } - return builder_.ty.vec(ast_elem_ty, num_elem); + return ty_.Vector(ast_elem_ty, num_elem); } -ast::Type* ParserImpl::ConvertType( +const Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Matrix* mat_ty) { const auto* vec_ty = mat_ty->element_type()->AsVector(); const auto* scalar_ty = vec_ty->element_type(); @@ -813,23 +813,23 @@ if (ast_scalar_ty == nullptr) { return nullptr; } - return builder_.ty.mat(ast_scalar_ty, num_columns, num_rows); + return ty_.Matrix(ast_scalar_ty, num_columns, num_rows); } -ast::Type* ParserImpl::ConvertType( +const Type* ParserImpl::ConvertType( const spvtools::opt::analysis::RuntimeArray* rtarr_ty) { auto* ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type())); if (ast_elem_ty == nullptr) { return nullptr; } - ast::DecorationList decorations; - if (!ParseArrayDecorations(rtarr_ty, &decorations)) { + uint32_t array_stride = 0; + if (!ParseArrayDecorations(rtarr_ty, &array_stride)) { return nullptr; } - return builder_.ty.array(ast_elem_ty, 0, std::move(decorations)); + return ty_.Array(ast_elem_ty, 0, array_stride); } -ast::Type* ParserImpl::ConvertType( +const Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Array* arr_ty) { const auto elem_type_id = type_mgr_->GetId(arr_ty->element_type()); auto* ast_elem_ty = ConvertType(elem_type_id); @@ -863,21 +863,19 @@ << num_elem; return nullptr; } - ast::DecorationList decorations; - if (!ParseArrayDecorations(arr_ty, &decorations)) { + uint32_t array_stride = 0; + if (!ParseArrayDecorations(arr_ty, &array_stride)) { return nullptr; } - if (remap_buffer_block_type_.count(elem_type_id)) { remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty)); } - return builder_.ty.array(ast_elem_ty, static_cast<uint32_t>(num_elem), - std::move(decorations)); + return ty_.Array(ast_elem_ty, static_cast<uint32_t>(num_elem), array_stride); } bool ParserImpl::ParseArrayDecorations( const spvtools::opt::analysis::Type* spv_type, - ast::DecorationList* decorations) { + uint32_t* array_stride) { bool has_array_stride = false; const auto type_id = type_mgr_->GetId(spv_type); for (auto& decoration : this->GetDecorationsFor(type_id)) { @@ -892,7 +890,7 @@ << ": multiple ArrayStride decorations"; } has_array_stride = true; - decorations->push_back(create<ast::StrideDecoration>(Source{}, stride)); + *array_stride = stride; } else { return Fail() << "invalid array type ID " << type_id << ": unknown decoration " @@ -904,7 +902,7 @@ return true; } -ast::Type* ParserImpl::ConvertType( +const Type* ParserImpl::ConvertType( uint32_t type_id, const spvtools::opt::analysis::Struct* struct_ty) { // Compute the struct decoration. @@ -930,6 +928,7 @@ // Compute members ast::StructMemberList ast_members; const auto members = struct_ty->element_types(); + TypeList ast_member_types; unsigned num_non_writable_members = 0; for (uint32_t member_index = 0; member_index < members.size(); ++member_index) { @@ -940,6 +939,8 @@ return nullptr; } + ast_member_types.emplace_back(ast_member_ty); + // Scan member for built-in decorations. Some vertex built-ins are handled // specially, and should not generate a structure member. bool create_ast_member = true; @@ -1003,8 +1004,8 @@ } const auto member_name = namer_.GetMemberName(type_id, member_index); auto* ast_struct_member = create<ast::StructMember>( - Source{}, builder_.Symbols().Register(member_name), ast_member_ty, - std::move(ast_member_decorations)); + Source{}, builder_.Symbols().Register(member_name), + ast_member_ty->Build(builder_), std::move(ast_member_decorations)); ast_members.push_back(ast_struct_member); } @@ -1030,7 +1031,7 @@ read_only_struct_types_.insert(ast_struct->name()); } AddConstructedType(sym, ast_struct); - return ast_struct; + return ty_.Struct(sym, std::move(ast_member_types)); } void ParserImpl::AddConstructedType(Symbol name, ast::NamedType* type) { @@ -1040,8 +1041,8 @@ } } -ast::Type* ParserImpl::ConvertType(uint32_t type_id, - const spvtools::opt::analysis::Pointer*) { +const Type* ParserImpl::ConvertType(uint32_t type_id, + const spvtools::opt::analysis::Pointer*) { const auto* inst = def_use_mgr_->GetDef(type_id); const auto pointee_type_id = inst->GetSingleWordInOperand(1); const auto storage_class = SpvStorageClass(inst->GetSingleWordInOperand(0)); @@ -1080,8 +1081,7 @@ } } - ast_elem_ty = builder_.ty.MaybeCreateTypename(ast_elem_ty); - return builder_.ty.pointer(ast_elem_ty, ast_storage_class); + return ty_.Pointer(ast_elem_ty, ast_storage_class); } bool ParserImpl::RegisterTypes() { @@ -1114,7 +1114,7 @@ // that is OpSpecConstantTrue, OpSpecConstantFalse, or OpSpecConstant. for (auto& inst : module_->types_values()) { // These will be populated for a valid scalar spec constant. - ast::Type* ast_type = nullptr; + const Type* ast_type = nullptr; ast::ScalarConstructorExpression* ast_expr = nullptr; switch (inst.opcode()) { @@ -1129,15 +1129,15 @@ case SpvOpSpecConstant: { ast_type = ConvertType(inst.type_id()); const uint32_t literal_value = inst.GetSingleWordInOperand(0); - if (ast_type->Is<ast::I32>()) { + if (ast_type->Is<I32>()) { ast_expr = create<ast::ScalarConstructorExpression>( Source{}, create<ast::SintLiteral>( Source{}, static_cast<int32_t>(literal_value))); - } else if (ast_type->Is<ast::U32>()) { + } else if (ast_type->Is<U32>()) { ast_expr = create<ast::ScalarConstructorExpression>( Source{}, create<ast::UintLiteral>( Source{}, static_cast<uint32_t>(literal_value))); - } else if (ast_type->Is<ast::F32>()) { + } else if (ast_type->Is<F32>()) { float float_value; // Copy the bits so we can read them as a float. std::memcpy(&float_value, &literal_value, sizeof(float_value)); @@ -1173,12 +1173,12 @@ return success_; } -ast::Type* ParserImpl::MaybeGenerateAlias( +const Type* ParserImpl::MaybeGenerateAlias( uint32_t type_id, const spvtools::opt::analysis::Type* type, - ast::Type* ast_type) { + const Type* ast_type) { if (!success_) { - return {}; + return nullptr; } // We only care about arrays, and runtime arrays. @@ -1202,16 +1202,17 @@ auto* ast_underlying_type = ast_type; if (ast_underlying_type == nullptr) { Fail() << "internal error: no type registered for SPIR-V ID: " << type_id; - return {}; + return nullptr; } const auto name = namer_.GetName(type_id); const auto sym = builder_.Symbols().Register(name); - auto ast_alias_type = builder_.ty.alias(sym, ast_underlying_type); + auto ast_alias_type = + builder_.ty.alias(sym, ast_underlying_type->Build(builder_)); // Record this new alias as the AST type for this SPIR-V ID. AddConstructedType(sym, ast_alias_type); - return ast_alias_type; + return ty_.Alias(sym, ast_underlying_type); } bool ParserImpl::EmitModuleScopeVariables() { @@ -1252,7 +1253,7 @@ if (!success_) { return false; } - ast::Type* ast_type; + const Type* ast_type = nullptr; if (spirv_storage_class == SpvStorageClassUniformConstant) { // These are opaque handles: samplers or textures ast_type = GetTypeForHandleVar(var); @@ -1266,14 +1267,14 @@ "SPIR-V type with ID: " << var.type_id(); } - if (!ast_type->Is<ast::Pointer>()) { + if (!ast_type->Is<Pointer>()) { return Fail() << "variable with ID " << var.result_id() << " has non-pointer type " << var.type_id(); } } - auto* ast_store_type = ast_type->As<ast::Pointer>()->type(); - auto ast_storage_class = ast_type->As<ast::Pointer>()->storage_class(); + auto* ast_store_type = ast_type->As<Pointer>()->type; + auto ast_storage_class = ast_type->As<Pointer>()->storage_class; ast::Expression* ast_constructor = nullptr; if (var.NumInOperands() > 1) { // SPIR-V initializers are always constants. @@ -1336,7 +1337,7 @@ ast::Variable* ParserImpl::MakeVariable(uint32_t id, ast::StorageClass sc, - ast::Type* type, + const Type* type, bool is_const, ast::Expression* constructor, ast::DecorationList decorations) { @@ -1347,14 +1348,14 @@ if (sc == ast::StorageClass::kStorage) { bool read_only = false; - if (auto* tn = type->As<ast::TypeName>()) { - read_only = read_only_struct_types_.count(tn->name()) > 0; + if (auto* tn = type->As<Named>()) { + read_only = read_only_struct_types_.count(tn->name) > 0; } // Apply the access(read) or access(read_write) modifier. auto access = read_only ? ast::AccessControl::kReadOnly : ast::AccessControl::kReadWrite; - type = builder_.ty.access(access, type); + type = ty_.AccessControl(type, access); } for (auto& deco : GetDecorationsFor(id)) { @@ -1396,7 +1397,7 @@ "SampleMask must be an array of 1 element."; } special_builtins_[id] = spv_builtin; - type = builder_.ty.u32(); + type = ty_.U32(); break; } default: @@ -1439,8 +1440,8 @@ std::string name = namer_.Name(id); return create<ast::Variable>(Source{}, builder_.Symbols().Register(name), sc, - builder_.ty.MaybeCreateTypename(type), is_const, - constructor, decorations); + type->Build(builder_), is_const, constructor, + decorations); } TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { @@ -1476,27 +1477,27 @@ // So canonicalization should map that way too. // Currently "null<type>" is missing from the WGSL parser. // See https://bugs.chromium.org/p/tint/issues/detail?id=34 - if (ast_type->Is<ast::U32>()) { - return {ast_type, create<ast::ScalarConstructorExpression>( - Source{}, create<ast::UintLiteral>( - source, spirv_const->GetU32()))}; + if (ast_type->Is<U32>()) { + return {ty_.U32(), create<ast::ScalarConstructorExpression>( + Source{}, create<ast::UintLiteral>( + source, spirv_const->GetU32()))}; } - if (ast_type->Is<ast::I32>()) { - return {ast_type, create<ast::ScalarConstructorExpression>( - Source{}, create<ast::SintLiteral>( - source, spirv_const->GetS32()))}; + if (ast_type->Is<I32>()) { + return {ty_.I32(), create<ast::ScalarConstructorExpression>( + Source{}, create<ast::SintLiteral>( + source, spirv_const->GetS32()))}; } - if (ast_type->Is<ast::F32>()) { - return {ast_type, create<ast::ScalarConstructorExpression>( - Source{}, create<ast::FloatLiteral>( - source, spirv_const->GetFloat()))}; + if (ast_type->Is<F32>()) { + return {ty_.F32(), create<ast::ScalarConstructorExpression>( + Source{}, create<ast::FloatLiteral>( + source, spirv_const->GetFloat()))}; } - if (ast_type->Is<ast::Bool>()) { + if (ast_type->Is<Bool>()) { const bool value = spirv_const->AsNullConstant() ? false : spirv_const->AsBoolConstant()->value(); - return {ast_type, create<ast::ScalarConstructorExpression>( - Source{}, create<ast::BoolLiteral>(source, value))}; + return {ty_.Bool(), create<ast::ScalarConstructorExpression>( + Source{}, create<ast::BoolLiteral>(source, value))}; } auto* spirv_composite_const = spirv_const->AsCompositeConstant(); if (spirv_composite_const != nullptr) { @@ -1521,10 +1522,9 @@ } ast_components.emplace_back(ast_component.expr); } - return {original_ast_type, - create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(original_ast_type), - std::move(ast_components))}; + return {original_ast_type, create<ast::TypeConstructorExpression>( + Source{}, original_ast_type->Build(builder_), + std::move(ast_components))}; } auto* spirv_null_const = spirv_const->AsNullConstant(); if (spirv_null_const != nullptr) { @@ -1535,7 +1535,7 @@ return {}; } -ast::Expression* ParserImpl::MakeNullValue(ast::Type* type) { +ast::Expression* ParserImpl::MakeNullValue(const Type* type) { // TODO(dneto): Use the no-operands constructor syntax when it becomes // available in Tint. // https://github.com/gpuweb/gpuweb/issues/685 @@ -1549,93 +1549,89 @@ auto* original_type = type; type = type->UnwrapIfNeeded(); - if (type->Is<ast::Bool>()) { + if (type->Is<Bool>()) { return create<ast::ScalarConstructorExpression>( Source{}, create<ast::BoolLiteral>(Source{}, false)); } - if (type->Is<ast::U32>()) { + if (type->Is<U32>()) { return create<ast::ScalarConstructorExpression>( Source{}, create<ast::UintLiteral>(Source{}, 0u)); } - if (type->Is<ast::I32>()) { + if (type->Is<I32>()) { return create<ast::ScalarConstructorExpression>( Source{}, create<ast::SintLiteral>(Source{}, 0)); } - if (type->Is<ast::F32>()) { + if (type->Is<F32>()) { return create<ast::ScalarConstructorExpression>( Source{}, create<ast::FloatLiteral>(Source{}, 0.0f)); } - if (type->Is<ast::TypeName>()) { + if (type->Is<Alias>()) { // TODO(amaiorano): No type constructor for TypeName (yet?) ast::ExpressionList ast_components; - return create<ast::TypeConstructorExpression>(Source{}, original_type, - std::move(ast_components)); + return create<ast::TypeConstructorExpression>( + Source{}, original_type->Build(builder_), std::move(ast_components)); } - if (auto* vec_ty = type->As<ast::Vector>()) { + if (auto* vec_ty = type->As<Vector>()) { ast::ExpressionList ast_components; - for (size_t i = 0; i < vec_ty->size(); ++i) { - ast_components.emplace_back(MakeNullValue(vec_ty->type())); + for (size_t i = 0; i < vec_ty->size; ++i) { + ast_components.emplace_back(MakeNullValue(vec_ty->type)); } return create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(type), - std::move(ast_components)); + Source{}, type->Build(builder_), std::move(ast_components)); } - if (auto* mat_ty = type->As<ast::Matrix>()) { + if (auto* mat_ty = type->As<Matrix>()) { // Matrix components are columns - auto column_ty = builder_.ty.vec(mat_ty->type(), mat_ty->rows()); + auto* column_ty = ty_.Vector(mat_ty->type, mat_ty->rows); ast::ExpressionList ast_components; - for (size_t i = 0; i < mat_ty->columns(); ++i) { + for (size_t i = 0; i < mat_ty->columns; ++i) { ast_components.emplace_back(MakeNullValue(column_ty)); } return create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(type), - std::move(ast_components)); + Source{}, type->Build(builder_), std::move(ast_components)); } - if (auto* arr_ty = type->As<ast::Array>()) { + if (auto* arr_ty = type->As<Array>()) { ast::ExpressionList ast_components; - for (size_t i = 0; i < arr_ty->size(); ++i) { - ast_components.emplace_back(MakeNullValue(arr_ty->type())); + for (size_t i = 0; i < arr_ty->size; ++i) { + ast_components.emplace_back(MakeNullValue(arr_ty->type)); } return create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(original_type), - std::move(ast_components)); + Source{}, original_type->Build(builder_), std::move(ast_components)); } - if (auto* struct_ty = type->As<ast::Struct>()) { + if (auto* struct_ty = type->As<Struct>()) { ast::ExpressionList ast_components; - for (auto* member : struct_ty->members()) { - ast_components.emplace_back(MakeNullValue(member->type())); + for (auto* member : struct_ty->members) { + ast_components.emplace_back(MakeNullValue(member)); } return create<ast::TypeConstructorExpression>( - Source{}, builder_.ty.MaybeCreateTypename(original_type), - std::move(ast_components)); + Source{}, original_type->Build(builder_), std::move(ast_components)); } - Fail() << "can't make null value for type: " << type->type_name(); + Fail() << "can't make null value for type: " << type->TypeInfo().name; return nullptr; } -TypedExpression ParserImpl::MakeNullExpression(ast::Type* type) { +TypedExpression ParserImpl::MakeNullExpression(const Type* type) { return {type, MakeNullValue(type)}; } -ast::Type* ParserImpl::UnsignedTypeFor(ast::Type* type) { - if (type->Is<ast::I32>()) { - return builder_.ty.u32(); +const Type* ParserImpl::UnsignedTypeFor(const Type* type) { + if (type->Is<I32>()) { + return ty_.U32(); } - if (auto* v = type->As<ast::Vector>()) { - if (v->type()->Is<ast::I32>()) { - return builder_.ty.vec(builder_.ty.u32(), v->size()); + if (auto* v = type->As<Vector>()) { + if (v->type->Is<I32>()) { + return ty_.Vector(ty_.U32(), v->size); } } return {}; } -ast::Type* ParserImpl::SignedTypeFor(ast::Type* type) { - if (type->Is<ast::U32>()) { - return builder_.ty.i32(); +const Type* ParserImpl::SignedTypeFor(const Type* type) { + if (type->Is<U32>()) { + return ty_.I32(); } - if (auto* v = type->As<ast::Vector>()) { - if (v->type()->Is<ast::U32>()) { - return builder_.ty.vec(builder_.ty.i32(), v->size()); + if (auto* v = type->As<Vector>()) { + if (v->type->Is<U32>()) { + return ty_.Vector(ty_.I32(), v->size); } } return {}; @@ -1674,13 +1670,14 @@ if (auto* unsigned_ty = UnsignedTypeFor(type)) { // Conversion is required. return {unsigned_ty, - create<ast::BitcastExpression>(Source{}, unsigned_ty, expr.expr)}; + create<ast::BitcastExpression>( + Source{}, unsigned_ty->Build(builder_), expr.expr)}; } } else if (requires_signed) { if (auto* signed_ty = SignedTypeFor(type)) { // Conversion is required. - return {signed_ty, - create<ast::BitcastExpression>(Source{}, signed_ty, expr.expr)}; + return {signed_ty, create<ast::BitcastExpression>( + Source{}, signed_ty->Build(builder_), expr.expr)}; } } // We should not reach here. @@ -1689,21 +1686,22 @@ TypedExpression ParserImpl::RectifySecondOperandSignedness( const spvtools::opt::Instruction& inst, - ast::Type* first_operand_type, + const Type* first_operand_type, TypedExpression&& second_operand_expr) { - if (!AstTypesEquivalent(first_operand_type, second_operand_expr.type) && + if ((first_operand_type != second_operand_expr.type) && AssumesSecondOperandSignednessMatchesFirstOperand(inst.opcode())) { // Conversion is required. return {first_operand_type, - create<ast::BitcastExpression>(Source{}, first_operand_type, + create<ast::BitcastExpression>(Source{}, + first_operand_type->Build(builder_), second_operand_expr.expr)}; } // No conversion necessary. return std::move(second_operand_expr); } -ast::Type* ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst, - ast::Type* first_operand_type) { +const Type* ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst, + const Type* first_operand_type) { const auto opcode = inst.opcode(); if (AssumesResultSignednessMatchesFirstOperand(opcode)) { return first_operand_type; @@ -1718,66 +1716,63 @@ return nullptr; } -ast::Type* ParserImpl::GetSignedIntMatchingShape(ast::Type* other) { +const Type* ParserImpl::GetSignedIntMatchingShape(const Type* other) { if (other == nullptr) { Fail() << "no type provided"; } - auto i32 = builder_.ty.i32(); - if (other->Is<ast::F32>() || other->Is<ast::U32>() || other->Is<ast::I32>()) { - return i32; + if (other->Is<F32>() || other->Is<U32>() || other->Is<I32>()) { + return ty_.I32(); } - auto* vec_ty = other->As<ast::Vector>(); - if (vec_ty) { - return builder_.ty.vec(i32, vec_ty->size()); + if (auto* vec_ty = other->As<Vector>()) { + return ty_.Vector(ty_.I32(), vec_ty->size); } - Fail() << "required numeric scalar or vector, but got " << other->type_name(); + Fail() << "required numeric scalar or vector, but got " + << other->TypeInfo().name; return nullptr; } -ast::Type* ParserImpl::GetUnsignedIntMatchingShape(ast::Type* other) { +const Type* ParserImpl::GetUnsignedIntMatchingShape(const Type* other) { if (other == nullptr) { Fail() << "no type provided"; return nullptr; } - auto u32 = builder_.ty.u32(); - if (other->Is<ast::F32>() || other->Is<ast::U32>() || other->Is<ast::I32>()) { - return u32; + if (other->Is<F32>() || other->Is<U32>() || other->Is<I32>()) { + return ty_.U32(); } - auto* vec_ty = other->As<ast::Vector>(); - if (vec_ty) { - return builder_.ty.vec(u32, vec_ty->size()); + if (auto* vec_ty = other->As<Vector>()) { + return ty_.Vector(ty_.U32(), vec_ty->size); } - Fail() << "required numeric scalar or vector, but got " << other->type_name(); + Fail() << "required numeric scalar or vector, but got " + << other->TypeInfo().name; return nullptr; } TypedExpression ParserImpl::RectifyForcedResultType( TypedExpression expr, const spvtools::opt::Instruction& inst, - ast::Type* first_operand_type) { + const Type* first_operand_type) { auto* forced_result_ty = ForcedResultType(inst, first_operand_type); - if ((forced_result_ty == nullptr) || - AstTypesEquivalent(forced_result_ty, expr.type)) { + if ((!forced_result_ty) || (forced_result_ty == expr.type)) { return expr; } - return {expr.type, - create<ast::BitcastExpression>(Source{}, expr.type, expr.expr)}; + return {expr.type, create<ast::BitcastExpression>( + Source{}, expr.type->Build(builder_), expr.expr)}; } TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) { - if (expr.type && expr.type->is_signed_scalar_or_vector()) { + if (expr.type && expr.type->IsSignedScalarOrVector()) { auto* new_type = GetUnsignedIntMatchingShape(expr.type); - return {new_type, - create<ast::BitcastExpression>(Source{}, new_type, expr.expr)}; + return {new_type, create<ast::BitcastExpression>( + Source{}, new_type->Build(builder_), expr.expr)}; } return expr; } TypedExpression ParserImpl::AsSigned(TypedExpression expr) { - if (expr.type && expr.type->is_unsigned_scalar_or_vector()) { + if (expr.type && expr.type->IsUnsignedScalarOrVector()) { auto* new_type = GetSignedIntMatchingShape(expr.type); - return {new_type, - create<ast::BitcastExpression>(Source{}, new_type, expr.expr)}; + return {new_type, create<ast::BitcastExpression>( + Source{}, new_type->Build(builder_), expr.expr)}; } return expr; } @@ -1952,7 +1947,7 @@ return raw_handle_type; } -ast::Pointer* ParserImpl::GetTypeForHandleVar( +const Pointer* ParserImpl::GetTypeForHandleVar( const spvtools::opt::Instruction& var) { auto where = handle_type_.find(&var); if (where != handle_type_.end()) { @@ -2036,11 +2031,11 @@ } // Construct the Tint handle type. - ast::Type* ast_store_type; + const Type* ast_store_type = nullptr; if (usage.IsSampler()) { - ast_store_type = builder_.ty.sampler( - usage.IsComparisonSampler() ? ast::SamplerKind::kComparisonSampler - : ast::SamplerKind::kSampler); + ast_store_type = ty_.Sampler(usage.IsComparisonSampler() + ? ast::SamplerKind::kComparisonSampler + : ast::SamplerKind::kSampler); } else if (usage.IsTexture()) { const spvtools::opt::analysis::Image* image_type = type_mgr_->GetType(raw_handle_type->result_id())->AsImage(); @@ -2069,14 +2064,13 @@ // OpImage variable with an OpImage*Dref* instruction. In WGSL we must // treat that as a depth texture. if (image_type->depth() || usage.IsDepthTexture()) { - ast_store_type = builder_.ty.depth_texture(dim); + ast_store_type = ty_.DepthTexture(dim); } else if (image_type->is_multisampled()) { // Multisampled textures are never depth textures. ast_store_type = - builder_.ty.multisampled_texture(dim, ast_sampled_component_type); + ty_.MultisampledTexture(dim, ast_sampled_component_type); } else { - ast_store_type = - builder_.ty.sampled_texture(dim, ast_sampled_component_type); + ast_store_type = ty_.SampledTexture(dim, ast_sampled_component_type); } } else { const auto access = usage.IsStorageReadTexture() @@ -2087,7 +2081,7 @@ return nullptr; } ast_store_type = - builder_.ty.access(access, builder_.ty.storage_texture(dim, format)); + ty_.AccessControl(ty_.StorageTexture(dim, format), access); } } else { Fail() << "unsupported: UniformConstant variable is not a recognized " @@ -2097,14 +2091,14 @@ } // Form the pointer type. - auto result = - builder_.ty.pointer(ast_store_type, ast::StorageClass::kUniformConstant); + auto* result = + ty_.Pointer(ast_store_type, ast::StorageClass::kUniformConstant); // Remember it for later. handle_type_[&var] = result; return result; } -ast::Type* ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) { +const Type* ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) { switch (format) { case ast::ImageFormat::kR8Uint: case ast::ImageFormat::kR16Uint: @@ -2115,7 +2109,7 @@ case ast::ImageFormat::kRg32Uint: case ast::ImageFormat::kRgba16Uint: case ast::ImageFormat::kRgba32Uint: - return builder_.ty.u32(); + return ty_.U32(); case ast::ImageFormat::kR8Sint: case ast::ImageFormat::kR16Sint: @@ -2126,7 +2120,7 @@ case ast::ImageFormat::kRg32Sint: case ast::ImageFormat::kRgba16Sint: case ast::ImageFormat::kRgba32Sint: - return builder_.ty.i32(); + return ty_.I32(); case ast::ImageFormat::kR8Unorm: case ast::ImageFormat::kRg8Unorm: @@ -2145,7 +2139,7 @@ case ast::ImageFormat::kRg32Float: case ast::ImageFormat::kRgba16Float: case ast::ImageFormat::kRgba32Float: - return builder_.ty.f32(); + return ty_.F32(); default: break; } @@ -2153,7 +2147,7 @@ return nullptr; } -ast::Type* ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) { +const Type* ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) { auto* component_type = GetComponentTypeForFormat(format); if (!component_type) { return nullptr; @@ -2185,7 +2179,7 @@ case ast::ImageFormat::kRg8Uint: case ast::ImageFormat::kRg8Unorm: // Two channels - return builder_.ty.vec(component_type, 2); + return ty_.Vector(component_type, 2); case ast::ImageFormat::kBgra8Unorm: case ast::ImageFormat::kBgra8UnormSrgb: @@ -2202,7 +2196,7 @@ case ast::ImageFormat::kRgba8Unorm: case ast::ImageFormat::kRgba8UnormSrgb: // Four channels - return builder_.ty.vec(component_type, 4); + return ty_.Vector(component_type, 4); default: break;
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 900165b..7253cdd 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h
@@ -28,6 +28,7 @@ #include "src/reader/spirv/entry_point_info.h" #include "src/reader/spirv/enum_converter.h" #include "src/reader/spirv/namer.h" +#include "src/reader/spirv/parser_type.h" #include "src/reader/spirv/usage.h" /// This is the implementation of the SPIR-V parser for Tint. @@ -51,14 +52,6 @@ namespace reader { namespace spirv { -/// Returns true of the two input ast types are semantically equivalent -/// @param lhs first type to compare -/// @param rhs other type to compare -/// @returns true if both types are semantically equivalent -inline bool AstTypesEquivalent(ast::Type* lhs, ast::Type* rhs) { - return lhs->type_name() == rhs->type_name(); -} - /// The binary representation of a SPIR-V decoration enum followed by its /// operands, if any. /// Example: { SpvDecorationBlock } @@ -81,10 +74,10 @@ /// Constructor /// @param type_in the type of the expression /// @param expr_in the expression - TypedExpression(ast::Type* type_in, ast::Expression* expr_in); + TypedExpression(const Type* type_in, ast::Expression* expr_in); /// The type - ast::Type* type; + Type const* type = nullptr; /// The expression ast::Expression* expr = nullptr; }; @@ -110,6 +103,9 @@ /// program. To be used only for testing. ProgramBuilder& builder() { return builder_; } + /// @returns the type manager + TypeManager& type_manager() { return ty_; } + /// Logs failure, ands return a failure stream to accumulate diagnostic /// messages. By convention, a failure should only be logged along with /// a non-empty string diagnostic. @@ -163,7 +159,7 @@ /// after the internal representation of the module has been built. /// @param type_id the SPIR-V ID of a type. /// @returns a Tint type, or nullptr - ast::Type* ConvertType(uint32_t type_id); + const Type* ConvertType(uint32_t type_id); /// Emits an alias type declaration for the given type, if necessary, and /// also updates the mapping of the SPIR-V type ID to the alias type. @@ -176,9 +172,9 @@ /// @param type the type that might get an alias /// @param ast_type the ast type that might get an alias /// @returns an alias type or `ast_type` if no alias was created - ast::Type* MaybeGenerateAlias(uint32_t type_id, - const spvtools::opt::analysis::Type* type, - ast::Type* ast_type); + const Type* MaybeGenerateAlias(uint32_t type_id, + const spvtools::opt::analysis::Type* type, + const Type* ast_type); /// @returns the fail stream object FailStream& fail_stream() { return fail_stream_; } @@ -328,7 +324,7 @@ /// in the error case ast::Variable* MakeVariable(uint32_t id, ast::StorageClass sc, - ast::Type* type, + const Type* type, bool is_const, ast::Expression* constructor, ast::DecorationList decorations); @@ -341,12 +337,12 @@ /// Creates an AST expression node for the null value for the given type. /// @param type the AST type /// @returns a new expression - ast::Expression* MakeNullValue(ast::Type* type); + ast::Expression* MakeNullValue(const Type* type); /// Make a typed expression for the null value for the given type. /// @param type the AST type /// @returns a new typed expression - TypedExpression MakeNullExpression(ast::Type* type); + TypedExpression MakeNullExpression(const Type* type); /// Converts a given expression to the signedness demanded for an operand /// of the given SPIR-V instruction, if required. If the instruction assumes @@ -371,7 +367,7 @@ /// @returns second_operand_expr, or a cast of it TypedExpression RectifySecondOperandSignedness( const spvtools::opt::Instruction& inst, - ast::Type* first_operand_type, + const Type* first_operand_type, TypedExpression&& second_operand_expr); /// Returns the "forced" result type for the given SPIR-V instruction. @@ -382,8 +378,8 @@ /// @param inst the SPIR-V instruction /// @param first_operand_type the AST type for the first operand. /// @returns the forced AST result type, or nullptr if no forcing is required. - ast::Type* ForcedResultType(const spvtools::opt::Instruction& inst, - ast::Type* first_operand_type); + const Type* ForcedResultType(const spvtools::opt::Instruction& inst, + const Type* first_operand_type); /// Returns a signed integer scalar or vector type matching the shape (scalar, /// vector, and component bit width) of another type, which itself is a @@ -391,7 +387,7 @@ /// requirement. /// @param other the type whose shape must be matched /// @returns the signed scalar or vector type - ast::Type* GetSignedIntMatchingShape(ast::Type* other); + const Type* GetSignedIntMatchingShape(const Type* other); /// Returns a signed integer scalar or vector type matching the shape (scalar, /// vector, and component bit width) of another type, which itself is a @@ -399,7 +395,7 @@ /// requirement. /// @param other the type whose shape must be matched /// @returns the unsigned scalar or vector type - ast::Type* GetUnsignedIntMatchingShape(ast::Type* other); + const Type* GetUnsignedIntMatchingShape(const Type* other); /// Wraps the given expression in an as-cast to the given expression's type, /// when the underlying operation produces a forced result type different @@ -412,18 +408,20 @@ TypedExpression RectifyForcedResultType( TypedExpression expr, const spvtools::opt::Instruction& inst, - ast::Type* first_operand_type); + const Type* first_operand_type); - /// @returns the given expression, but ensuring it's an unsigned type of the - /// same shape as the operand. Wraps the expresion with a bitcast if needed. + /// Returns the given expression, but ensuring it's an unsigned type of the + /// same shape as the operand. Wraps the expression with a bitcast if needed. /// Assumes the given expresion is a integer scalar or vector. /// @param expr an integer scalar or integer vector expression. + /// @return the potentially cast TypedExpression TypedExpression AsUnsigned(TypedExpression expr); - /// @returns the given expression, but ensuring it's a signed type of the - /// same shape as the operand. Wraps the expresion with a bitcast if needed. + /// Returns the given expression, but ensuring it's a signed type of the + /// same shape as the operand. Wraps the expression with a bitcast if needed. /// Assumes the given expresion is a integer scalar or vector. /// @param expr an integer scalar or integer vector expression. + /// @return the potentially cast TypedExpression TypedExpression AsSigned(TypedExpression expr); /// Bookkeeping used for tracking the "position" builtin variable. @@ -512,18 +510,18 @@ /// @param var the OpVariable instruction /// @returns the Tint AST type for the poiner-to-{sampler|texture} or null on /// error - ast::Pointer* GetTypeForHandleVar(const spvtools::opt::Instruction& var); + const Pointer* GetTypeForHandleVar(const spvtools::opt::Instruction& var); /// Returns the channel component type corresponding to the given image /// format. /// @param format image texel format /// @returns the component type, one of f32, i32, u32 - ast::Type* GetComponentTypeForFormat(ast::ImageFormat format); + const Type* GetComponentTypeForFormat(ast::ImageFormat format); /// Returns texel type corresponding to the given image format. /// @param format image texel format /// @returns the texel format - ast::Type* GetTexelTypeForFormat(ast::ImageFormat format); + const Type* GetTexelTypeForFormat(ast::ImageFormat format); /// Returns the SPIR-V instruction with the given ID, or nullptr. /// @param id the SPIR-V result ID @@ -561,19 +559,20 @@ private: /// Converts a specific SPIR-V type to a Tint type. Integer case - ast::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty); + const Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty); /// Converts a specific SPIR-V type to a Tint type. Float case - ast::Type* ConvertType(const spvtools::opt::analysis::Float* float_ty); + const Type* ConvertType(const spvtools::opt::analysis::Float* float_ty); /// Converts a specific SPIR-V type to a Tint type. Vector case - ast::Type* ConvertType(const spvtools::opt::analysis::Vector* vec_ty); + const Type* ConvertType(const spvtools::opt::analysis::Vector* vec_ty); /// Converts a specific SPIR-V type to a Tint type. Matrix case - ast::Type* ConvertType(const spvtools::opt::analysis::Matrix* mat_ty); + const Type* ConvertType(const spvtools::opt::analysis::Matrix* mat_ty); /// Converts a specific SPIR-V type to a Tint type. RuntimeArray case /// @param rtarr_ty the Tint type - ast::Type* ConvertType(const spvtools::opt::analysis::RuntimeArray* rtarr_ty); + const Type* ConvertType( + const spvtools::opt::analysis::RuntimeArray* rtarr_ty); /// Converts a specific SPIR-V type to a Tint type. Array case /// @param arr_ty the Tint type - ast::Type* ConvertType(const spvtools::opt::analysis::Array* arr_ty); + const Type* ConvertType(const spvtools::opt::analysis::Array* arr_ty); /// Converts a specific SPIR-V type to a Tint type. Struct case. /// SPIR-V allows distinct struct type definitions for two OpTypeStruct /// that otherwise have the same set of members (and struct and member @@ -585,34 +584,34 @@ /// not significant to the optimizer's module representation. /// @param type_id the SPIR-V ID for the type. /// @param struct_ty the Tint type - ast::Type* ConvertType(uint32_t type_id, - const spvtools::opt::analysis::Struct* struct_ty); + const Type* ConvertType(uint32_t type_id, + const spvtools::opt::analysis::Struct* struct_ty); /// Converts a specific SPIR-V type to a Tint type. Pointer case /// The pointer to gl_PerVertex maps to nullptr, and instead is recorded /// in member #builtin_position_. /// @param type_id the SPIR-V ID for the type. /// @param ptr_ty the Tint type - ast::Type* ConvertType(uint32_t type_id, - const spvtools::opt::analysis::Pointer* ptr_ty); + const Type* ConvertType(uint32_t type_id, + const spvtools::opt::analysis::Pointer* ptr_ty); /// If `type` is a signed integral, or vector of signed integral, /// returns the unsigned type, otherwise returns `type`. /// @param type the possibly signed type /// @returns the unsigned type - ast::Type* UnsignedTypeFor(ast::Type* type); + const Type* UnsignedTypeFor(const Type* type); /// If `type` is a unsigned integral, or vector of unsigned integral, /// returns the signed type, otherwise returns `type`. /// @param type the possibly unsigned type /// @returns the signed type - ast::Type* SignedTypeFor(ast::Type* type); + const Type* SignedTypeFor(const Type* type); /// Parses the array or runtime-array decorations. /// @param spv_type the SPIR-V array or runtime-array type. - /// @param decorations the populated decoration list + /// @param array_stride pointer to the array stride /// @returns true on success. bool ParseArrayDecorations(const spvtools::opt::analysis::Type* spv_type, - ast::DecorationList* decorations); + uint32_t* array_stride); /// Adds `type` as a constructed type if it hasn't been added yet. /// @param name the type's unique name @@ -633,6 +632,9 @@ // The program builder. ProgramBuilder builder_; + // The type manager. + TypeManager ty_; + // Is the parse successful? bool success_ = true; // Collector for diagnostic messages. @@ -716,7 +718,7 @@ // usages implied by usages of the memory-object-declaration. std::unordered_map<const spvtools::opt::Instruction*, Usage> handle_usage_; // The inferred pointer type for the given handle variable. - std::unordered_map<const spvtools::opt::Instruction*, ast::Pointer*> + std::unordered_map<const spvtools::opt::Instruction*, const Pointer*> handle_type_; // Set of symbols of constructed types that have been added, used to avoid
diff --git a/src/reader/spirv/parser_impl_convert_type_test.cc b/src/reader/spirv/parser_impl_convert_type_test.cc index 180f119..3b32374 100644 --- a/src/reader/spirv/parser_impl_convert_type_test.cc +++ b/src/reader/spirv/parser_impl_convert_type_test.cc
@@ -75,7 +75,7 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(1); - EXPECT_TRUE(type->Is<ast::Void>()); + EXPECT_TRUE(type->Is<Void>()); EXPECT_TRUE(p->error().empty()); } @@ -84,7 +84,7 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(100); - EXPECT_TRUE(type->Is<ast::Bool>()); + EXPECT_TRUE(type->Is<Bool>()); EXPECT_TRUE(p->error().empty()); } @@ -93,7 +93,7 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(2); - EXPECT_TRUE(type->Is<ast::I32>()); + EXPECT_TRUE(type->Is<I32>()); EXPECT_TRUE(p->error().empty()); } @@ -102,7 +102,7 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::U32>()); + EXPECT_TRUE(type->Is<U32>()); EXPECT_TRUE(p->error().empty()); } @@ -111,7 +111,7 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(4); - EXPECT_TRUE(type->Is<ast::F32>()); + EXPECT_TRUE(type->Is<F32>()); EXPECT_TRUE(p->error().empty()); } @@ -155,19 +155,19 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* v2xf32 = p->ConvertType(20); - EXPECT_TRUE(v2xf32->Is<ast::Vector>()); - EXPECT_TRUE(v2xf32->As<ast::Vector>()->type()->Is<ast::F32>()); - EXPECT_EQ(v2xf32->As<ast::Vector>()->size(), 2u); + EXPECT_TRUE(v2xf32->Is<Vector>()); + EXPECT_TRUE(v2xf32->As<Vector>()->type->Is<F32>()); + EXPECT_EQ(v2xf32->As<Vector>()->size, 2u); auto* v3xf32 = p->ConvertType(30); - EXPECT_TRUE(v3xf32->Is<ast::Vector>()); - EXPECT_TRUE(v3xf32->As<ast::Vector>()->type()->Is<ast::F32>()); - EXPECT_EQ(v3xf32->As<ast::Vector>()->size(), 3u); + EXPECT_TRUE(v3xf32->Is<Vector>()); + EXPECT_TRUE(v3xf32->As<Vector>()->type->Is<F32>()); + EXPECT_EQ(v3xf32->As<Vector>()->size, 3u); auto* v4xf32 = p->ConvertType(40); - EXPECT_TRUE(v4xf32->Is<ast::Vector>()); - EXPECT_TRUE(v4xf32->As<ast::Vector>()->type()->Is<ast::F32>()); - EXPECT_EQ(v4xf32->As<ast::Vector>()->size(), 4u); + EXPECT_TRUE(v4xf32->Is<Vector>()); + EXPECT_TRUE(v4xf32->As<Vector>()->type->Is<F32>()); + EXPECT_EQ(v4xf32->As<Vector>()->size, 4u); EXPECT_TRUE(p->error().empty()); } @@ -182,19 +182,19 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* v2xi32 = p->ConvertType(20); - EXPECT_TRUE(v2xi32->Is<ast::Vector>()); - EXPECT_TRUE(v2xi32->As<ast::Vector>()->type()->Is<ast::I32>()); - EXPECT_EQ(v2xi32->As<ast::Vector>()->size(), 2u); + EXPECT_TRUE(v2xi32->Is<Vector>()); + EXPECT_TRUE(v2xi32->As<Vector>()->type->Is<I32>()); + EXPECT_EQ(v2xi32->As<Vector>()->size, 2u); auto* v3xi32 = p->ConvertType(30); - EXPECT_TRUE(v3xi32->Is<ast::Vector>()); - EXPECT_TRUE(v3xi32->As<ast::Vector>()->type()->Is<ast::I32>()); - EXPECT_EQ(v3xi32->As<ast::Vector>()->size(), 3u); + EXPECT_TRUE(v3xi32->Is<Vector>()); + EXPECT_TRUE(v3xi32->As<Vector>()->type->Is<I32>()); + EXPECT_EQ(v3xi32->As<Vector>()->size, 3u); auto* v4xi32 = p->ConvertType(40); - EXPECT_TRUE(v4xi32->Is<ast::Vector>()); - EXPECT_TRUE(v4xi32->As<ast::Vector>()->type()->Is<ast::I32>()); - EXPECT_EQ(v4xi32->As<ast::Vector>()->size(), 4u); + EXPECT_TRUE(v4xi32->Is<Vector>()); + EXPECT_TRUE(v4xi32->As<Vector>()->type->Is<I32>()); + EXPECT_EQ(v4xi32->As<Vector>()->size, 4u); EXPECT_TRUE(p->error().empty()); } @@ -209,19 +209,19 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* v2xu32 = p->ConvertType(20); - EXPECT_TRUE(v2xu32->Is<ast::Vector>()); - EXPECT_TRUE(v2xu32->As<ast::Vector>()->type()->Is<ast::U32>()); - EXPECT_EQ(v2xu32->As<ast::Vector>()->size(), 2u); + EXPECT_TRUE(v2xu32->Is<Vector>()); + EXPECT_TRUE(v2xu32->As<Vector>()->type->Is<U32>()); + EXPECT_EQ(v2xu32->As<Vector>()->size, 2u); auto* v3xu32 = p->ConvertType(30); - EXPECT_TRUE(v3xu32->Is<ast::Vector>()); - EXPECT_TRUE(v3xu32->As<ast::Vector>()->type()->Is<ast::U32>()); - EXPECT_EQ(v3xu32->As<ast::Vector>()->size(), 3u); + EXPECT_TRUE(v3xu32->Is<Vector>()); + EXPECT_TRUE(v3xu32->As<Vector>()->type->Is<U32>()); + EXPECT_EQ(v3xu32->As<Vector>()->size, 3u); auto* v4xu32 = p->ConvertType(40); - EXPECT_TRUE(v4xu32->Is<ast::Vector>()); - EXPECT_TRUE(v4xu32->As<ast::Vector>()->type()->Is<ast::U32>()); - EXPECT_EQ(v4xu32->As<ast::Vector>()->size(), 4u); + EXPECT_TRUE(v4xu32->Is<Vector>()); + EXPECT_TRUE(v4xu32->As<Vector>()->type->Is<U32>()); + EXPECT_EQ(v4xu32->As<Vector>()->size, 4u); EXPECT_TRUE(p->error().empty()); } @@ -261,58 +261,58 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* m22 = p->ConvertType(22); - EXPECT_TRUE(m22->Is<ast::Matrix>()); - EXPECT_TRUE(m22->As<ast::Matrix>()->type()->Is<ast::F32>()); - EXPECT_EQ(m22->As<ast::Matrix>()->rows(), 2u); - EXPECT_EQ(m22->As<ast::Matrix>()->columns(), 2u); + EXPECT_TRUE(m22->Is<Matrix>()); + EXPECT_TRUE(m22->As<Matrix>()->type->Is<F32>()); + EXPECT_EQ(m22->As<Matrix>()->rows, 2u); + EXPECT_EQ(m22->As<Matrix>()->columns, 2u); auto* m23 = p->ConvertType(23); - EXPECT_TRUE(m23->Is<ast::Matrix>()); - EXPECT_TRUE(m23->As<ast::Matrix>()->type()->Is<ast::F32>()); - EXPECT_EQ(m23->As<ast::Matrix>()->rows(), 2u); - EXPECT_EQ(m23->As<ast::Matrix>()->columns(), 3u); + EXPECT_TRUE(m23->Is<Matrix>()); + EXPECT_TRUE(m23->As<Matrix>()->type->Is<F32>()); + EXPECT_EQ(m23->As<Matrix>()->rows, 2u); + EXPECT_EQ(m23->As<Matrix>()->columns, 3u); auto* m24 = p->ConvertType(24); - EXPECT_TRUE(m24->Is<ast::Matrix>()); - EXPECT_TRUE(m24->As<ast::Matrix>()->type()->Is<ast::F32>()); - EXPECT_EQ(m24->As<ast::Matrix>()->rows(), 2u); - EXPECT_EQ(m24->As<ast::Matrix>()->columns(), 4u); + EXPECT_TRUE(m24->Is<Matrix>()); + EXPECT_TRUE(m24->As<Matrix>()->type->Is<F32>()); + EXPECT_EQ(m24->As<Matrix>()->rows, 2u); + EXPECT_EQ(m24->As<Matrix>()->columns, 4u); auto* m32 = p->ConvertType(32); - EXPECT_TRUE(m32->Is<ast::Matrix>()); - EXPECT_TRUE(m32->As<ast::Matrix>()->type()->Is<ast::F32>()); - EXPECT_EQ(m32->As<ast::Matrix>()->rows(), 3u); - EXPECT_EQ(m32->As<ast::Matrix>()->columns(), 2u); + EXPECT_TRUE(m32->Is<Matrix>()); + EXPECT_TRUE(m32->As<Matrix>()->type->Is<F32>()); + EXPECT_EQ(m32->As<Matrix>()->rows, 3u); + EXPECT_EQ(m32->As<Matrix>()->columns, 2u); auto* m33 = p->ConvertType(33); - EXPECT_TRUE(m33->Is<ast::Matrix>()); - EXPECT_TRUE(m33->As<ast::Matrix>()->type()->Is<ast::F32>()); - EXPECT_EQ(m33->As<ast::Matrix>()->rows(), 3u); - EXPECT_EQ(m33->As<ast::Matrix>()->columns(), 3u); + EXPECT_TRUE(m33->Is<Matrix>()); + EXPECT_TRUE(m33->As<Matrix>()->type->Is<F32>()); + EXPECT_EQ(m33->As<Matrix>()->rows, 3u); + EXPECT_EQ(m33->As<Matrix>()->columns, 3u); auto* m34 = p->ConvertType(34); - EXPECT_TRUE(m34->Is<ast::Matrix>()); - EXPECT_TRUE(m34->As<ast::Matrix>()->type()->Is<ast::F32>()); - EXPECT_EQ(m34->As<ast::Matrix>()->rows(), 3u); - EXPECT_EQ(m34->As<ast::Matrix>()->columns(), 4u); + EXPECT_TRUE(m34->Is<Matrix>()); + EXPECT_TRUE(m34->As<Matrix>()->type->Is<F32>()); + EXPECT_EQ(m34->As<Matrix>()->rows, 3u); + EXPECT_EQ(m34->As<Matrix>()->columns, 4u); auto* m42 = p->ConvertType(42); - EXPECT_TRUE(m42->Is<ast::Matrix>()); - EXPECT_TRUE(m42->As<ast::Matrix>()->type()->Is<ast::F32>()); - EXPECT_EQ(m42->As<ast::Matrix>()->rows(), 4u); - EXPECT_EQ(m42->As<ast::Matrix>()->columns(), 2u); + EXPECT_TRUE(m42->Is<Matrix>()); + EXPECT_TRUE(m42->As<Matrix>()->type->Is<F32>()); + EXPECT_EQ(m42->As<Matrix>()->rows, 4u); + EXPECT_EQ(m42->As<Matrix>()->columns, 2u); auto* m43 = p->ConvertType(43); - EXPECT_TRUE(m43->Is<ast::Matrix>()); - EXPECT_TRUE(m43->As<ast::Matrix>()->type()->Is<ast::F32>()); - EXPECT_EQ(m43->As<ast::Matrix>()->rows(), 4u); - EXPECT_EQ(m43->As<ast::Matrix>()->columns(), 3u); + EXPECT_TRUE(m43->Is<Matrix>()); + EXPECT_TRUE(m43->As<Matrix>()->type->Is<F32>()); + EXPECT_EQ(m43->As<Matrix>()->rows, 4u); + EXPECT_EQ(m43->As<Matrix>()->columns, 3u); auto* m44 = p->ConvertType(44); - EXPECT_TRUE(m44->Is<ast::Matrix>()); - EXPECT_TRUE(m44->As<ast::Matrix>()->type()->Is<ast::F32>()); - EXPECT_EQ(m44->As<ast::Matrix>()->rows(), 4u); - EXPECT_EQ(m44->As<ast::Matrix>()->columns(), 4u); + EXPECT_TRUE(m44->Is<Matrix>()); + EXPECT_TRUE(m44->As<Matrix>()->type->Is<F32>()); + EXPECT_EQ(m44->As<Matrix>()->rows, 4u); + EXPECT_EQ(m44->As<Matrix>()->columns, 4u); EXPECT_TRUE(p->error().empty()); } @@ -326,15 +326,14 @@ auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type->UnwrapAliasIfNeeded()->Is<ast::Array>()); - auto* arr_type = type->UnwrapAliasIfNeeded()->As<ast::Array>(); - EXPECT_TRUE(arr_type->IsRuntimeArray()); + EXPECT_TRUE(type->UnwrapAll()->Is<Array>()); + auto* arr_type = type->UnwrapAll()->As<Array>(); ASSERT_NE(arr_type, nullptr); - EXPECT_EQ(arr_type->size(), 0u); - EXPECT_EQ(arr_type->decorations().size(), 0u); - auto* elem_type = arr_type->type(); + EXPECT_EQ(arr_type->size, 0u); + EXPECT_EQ(arr_type->stride, 0u); + auto* elem_type = arr_type->type; ASSERT_NE(elem_type, nullptr); - EXPECT_TRUE(elem_type->Is<ast::U32>()); + EXPECT_TRUE(elem_type->Is<U32>()); EXPECT_TRUE(p->error().empty()); } @@ -361,14 +360,10 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - auto* arr_type = type->UnwrapAliasIfNeeded()->As<ast::Array>(); - EXPECT_TRUE(arr_type->IsRuntimeArray()); + auto* arr_type = type->UnwrapAll()->As<Array>(); + EXPECT_EQ(arr_type->size, 0u); ASSERT_NE(arr_type, nullptr); - ASSERT_EQ(arr_type->decorations().size(), 1u); - auto* stride = arr_type->decorations()[0]; - ASSERT_TRUE(stride->Is<ast::StrideDecoration>()); - ASSERT_EQ(stride->As<ast::StrideDecoration>()->stride(), 64u); - EXPECT_TRUE(p->error().empty()); + EXPECT_EQ(arr_type->stride, 64u); } TEST_F(SpvParserTest, ConvertType_RuntimeArray_ArrayStride_ZeroIsError) { @@ -409,15 +404,14 @@ auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type->Is<ast::Array>()); - auto* arr_type = type->As<ast::Array>(); - EXPECT_FALSE(arr_type->IsRuntimeArray()); + EXPECT_TRUE(type->Is<Array>()); + auto* arr_type = type->As<Array>(); ASSERT_NE(arr_type, nullptr); - EXPECT_EQ(arr_type->size(), 42u); - EXPECT_EQ(arr_type->decorations().size(), 0u); - auto* elem_type = arr_type->type(); + EXPECT_EQ(arr_type->size, 42u); + EXPECT_EQ(arr_type->stride, 0u); + auto* elem_type = arr_type->type; ASSERT_NE(elem_type, nullptr); - EXPECT_TRUE(elem_type->Is<ast::U32>()); + EXPECT_TRUE(elem_type->Is<U32>()); EXPECT_TRUE(p->error().empty()); } @@ -496,15 +490,10 @@ auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type->UnwrapAliasIfNeeded()->Is<ast::Array>()); - auto* arr_type = type->UnwrapAliasIfNeeded()->As<ast::Array>(); + EXPECT_TRUE(type->UnwrapAll()->Is<Array>()); + auto* arr_type = type->UnwrapAll()->As<Array>(); ASSERT_NE(arr_type, nullptr); - - ASSERT_EQ(arr_type->decorations().size(), 1u); - auto* stride = arr_type->decorations()[0]; - ASSERT_TRUE(stride->Is<ast::StrideDecoration>()); - ASSERT_EQ(stride->As<ast::StrideDecoration>()->stride(), 8u); - + EXPECT_EQ(arr_type->stride, 8u); EXPECT_TRUE(p->error().empty()); } @@ -550,14 +539,11 @@ auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type->Is<ast::Struct>()); + EXPECT_TRUE(type->Is<Struct>()); + auto* str = type->Build(p->builder()); Program program = p->program(); - EXPECT_THAT(program.str(type->As<ast::Struct>()), Eq(R"(Struct S { - StructMember{field0: __u32} - StructMember{field1: __f32} -} -)")); + EXPECT_THAT(program.str(str), Eq(R"(__type_name_S)")); } TEST_F(SpvParserTest, ConvertType_StructWithBlockDecoration) { @@ -571,14 +557,11 @@ auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type->Is<ast::Struct>()); + EXPECT_TRUE(type->Is<Struct>()); + auto* str = type->Build(p->builder()); Program program = p->program(); - EXPECT_THAT(program.str(type->As<ast::Struct>()), Eq(R"(Struct S { - [[block]] - StructMember{field0: __u32} -} -)")); + EXPECT_THAT(program.str(str), Eq(R"(__type_name_S)")); } TEST_F(SpvParserTest, ConvertType_StructWithMemberDecorations) { @@ -596,15 +579,11 @@ auto* type = p->ConvertType(10); ASSERT_NE(type, nullptr); - EXPECT_TRUE(type->Is<ast::Struct>()); + EXPECT_TRUE(type->Is<Struct>()); + auto* str = type->Build(p->builder()); Program program = p->program(); - EXPECT_THAT(program.str(type->As<ast::Struct>()), Eq(R"(Struct S { - StructMember{[[ offset 0 ]] field0: __f32} - StructMember{[[ offset 8 ]] field1: __vec_2__f32} - StructMember{[[ offset 16 ]] field2: __mat_2_2__f32} -} -)")); + EXPECT_THAT(program.str(str), Eq(R"(__type_name_S)")); } // TODO(dneto): Demonstrate other member decorations. Blocked on @@ -645,11 +624,11 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + EXPECT_TRUE(type->Is<Pointer>()); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput); + EXPECT_TRUE(ptr_ty->type->Is<F32>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kInput); EXPECT_TRUE(p->error().empty()); } @@ -661,11 +640,11 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + EXPECT_TRUE(type->Is<Pointer>()); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kOutput); + EXPECT_TRUE(ptr_ty->type->Is<F32>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kOutput); EXPECT_TRUE(p->error().empty()); } @@ -677,11 +656,11 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + EXPECT_TRUE(type->Is<Pointer>()); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniform); + EXPECT_TRUE(ptr_ty->type->Is<F32>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kUniform); EXPECT_TRUE(p->error().empty()); } @@ -693,11 +672,11 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + EXPECT_TRUE(type->Is<Pointer>()); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kWorkgroup); + EXPECT_TRUE(ptr_ty->type->Is<F32>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kWorkgroup); EXPECT_TRUE(p->error().empty()); } @@ -709,11 +688,11 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + EXPECT_TRUE(type->Is<Pointer>()); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kUniformConstant); + EXPECT_TRUE(ptr_ty->type->Is<F32>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kUniformConstant); EXPECT_TRUE(p->error().empty()); } @@ -725,11 +704,11 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + EXPECT_TRUE(type->Is<Pointer>()); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kStorage); + EXPECT_TRUE(ptr_ty->type->Is<F32>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kStorage); EXPECT_TRUE(p->error().empty()); } @@ -741,11 +720,11 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + EXPECT_TRUE(type->Is<Pointer>()); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kImage); + EXPECT_TRUE(ptr_ty->type->Is<F32>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kImage); EXPECT_TRUE(p->error().empty()); } @@ -757,11 +736,11 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + EXPECT_TRUE(type->Is<Pointer>()); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kPrivate); + EXPECT_TRUE(ptr_ty->type->Is<F32>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kPrivate); EXPECT_TRUE(p->error().empty()); } @@ -773,11 +752,11 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(3); - EXPECT_TRUE(type->Is<ast::Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + EXPECT_TRUE(type->Is<Pointer>()); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_TRUE(ptr_ty->type()->Is<ast::F32>()); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kFunction); + EXPECT_TRUE(ptr_ty->type->Is<F32>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kFunction); EXPECT_TRUE(p->error().empty()); } @@ -792,17 +771,17 @@ auto* type = p->ConvertType(3); EXPECT_NE(type, nullptr); - EXPECT_TRUE(type->Is<ast::Pointer>()); + EXPECT_TRUE(type->Is<Pointer>()); - auto* ptr_ty = type->As<ast::Pointer>(); + auto* ptr_ty = type->As<Pointer>(); EXPECT_NE(ptr_ty, nullptr); - EXPECT_EQ(ptr_ty->storage_class(), ast::StorageClass::kInput); - EXPECT_TRUE(ptr_ty->type()->Is<ast::Pointer>()); + EXPECT_EQ(ptr_ty->storage_class, ast::StorageClass::kInput); + EXPECT_TRUE(ptr_ty->type->Is<Pointer>()); - auto* ptr_ptr_ty = ptr_ty->type()->As<ast::Pointer>(); + auto* ptr_ptr_ty = ptr_ty->type->As<Pointer>(); EXPECT_NE(ptr_ptr_ty, nullptr); - EXPECT_EQ(ptr_ptr_ty->storage_class(), ast::StorageClass::kOutput); - EXPECT_TRUE(ptr_ptr_ty->type()->Is<ast::F32>()); + EXPECT_EQ(ptr_ptr_ty->storage_class, ast::StorageClass::kOutput); + EXPECT_TRUE(ptr_ptr_ty->type->Is<F32>()); EXPECT_TRUE(p->error().empty()); } @@ -815,7 +794,7 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(1); - EXPECT_TRUE(type->Is<ast::Void>()); + EXPECT_TRUE(type->Is<Void>()); EXPECT_TRUE(p->error().empty()); } @@ -828,7 +807,7 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(1); - EXPECT_TRUE(type->Is<ast::Void>()); + EXPECT_TRUE(type->Is<Void>()); EXPECT_TRUE(p->error().empty()); } @@ -841,7 +820,7 @@ EXPECT_TRUE(p->BuildInternalModule()); auto* type = p->ConvertType(1); - EXPECT_TRUE(type->Is<ast::Void>()); + EXPECT_TRUE(type->Is<Void>()); EXPECT_TRUE(p->error().empty()); }
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index 91002a0..c8afb2c 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -1772,7 +1772,7 @@ [[block]] StructMember{[[ offset 0 ]] field0: __u32} StructMember{[[ offset 4 ]] field1: __f32} - StructMember{[[ offset 8 ]] field2: __alias_Arr__array__u32_2_stride_4} + StructMember{[[ offset 8 ]] field2: __type_name_Arr} } Variable{ x_1
diff --git a/src/reader/spirv/parser_impl_test_helper.h b/src/reader/spirv/parser_impl_test_helper.h index bfa5ca6..dd83ca6 100644 --- a/src/reader/spirv/parser_impl_test_helper.h +++ b/src/reader/spirv/parser_impl_test_helper.h
@@ -158,7 +158,7 @@ /// after the internal representation of the module has been built. /// @param id the SPIR-V ID of a type. /// @returns a Tint type, or nullptr - ast::Type* ConvertType(uint32_t id) { return impl_.ConvertType(id); } + const Type* ConvertType(uint32_t id) { return impl_.ConvertType(id); } /// Gets the list of decorations for a SPIR-V result ID. Returns an empty /// vector if the ID is not a result ID, or if no decorations target that ID.
diff --git a/src/reader/spirv/parser_type.cc b/src/reader/spirv/parser_type.cc new file mode 100644 index 0000000..4f2ca5a --- /dev/null +++ b/src/reader/spirv/parser_type.cc
@@ -0,0 +1,514 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or stateied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/reader/spirv/parser_type.h" + +#include <unordered_map> +#include <utility> + +#include "src/program_builder.h" +#include "src/utils/get_or_create.h" +#include "src/utils/hash.h" + +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Type); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Void); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Bool); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::U32); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::F32); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::I32); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Pointer); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Vector); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Matrix); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Array); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::AccessControl); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Sampler); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Texture); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::DepthTexture); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::MultisampledTexture); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::SampledTexture); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::StorageTexture); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Named); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Alias); +TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::Struct); + +namespace tint { +namespace reader { +namespace spirv { + +namespace { +struct PointerHasher { + size_t operator()(const Pointer& t) const { + return utils::Hash(t.type, t.storage_class); + } +}; + +struct VectorHasher { + size_t operator()(const Vector& t) const { + return utils::Hash(t.type, t.size); + } +}; + +struct MatrixHasher { + size_t operator()(const Matrix& t) const { + return utils::Hash(t.type, t.columns, t.rows); + } +}; + +struct ArrayHasher { + size_t operator()(const Array& t) const { + return utils::Hash(t.type, t.size, t.stride); + } +}; + +struct AccessControlHasher { + size_t operator()(const AccessControl& t) const { + return utils::Hash(t.type, t.access); + } +}; + +struct MultisampledTextureHasher { + size_t operator()(const MultisampledTexture& t) const { + return utils::Hash(t.dims, t.type); + } +}; + +struct SampledTextureHasher { + size_t operator()(const SampledTexture& t) const { + return utils::Hash(t.dims, t.type); + } +}; + +struct StorageTextureHasher { + size_t operator()(const StorageTexture& t) const { + return utils::Hash(t.dims, t.format); + } +}; +} // namespace + +static bool operator==(const Pointer& a, const Pointer& b) { + return a.type == b.type && a.storage_class == b.storage_class; +} + +static bool operator==(const Vector& a, const Vector& b) { + return a.type == b.type && a.size == b.size; +} + +static bool operator==(const Matrix& a, const Matrix& b) { + return a.type == b.type && a.columns == b.columns && a.rows == b.rows; +} + +static bool operator==(const Array& a, const Array& b) { + return a.type == b.type && a.size == b.size && a.stride == b.stride; +} + +static bool operator==(const AccessControl& a, const AccessControl& b) { + return a.type == b.type && a.access == b.access; +} + +static bool operator==(const MultisampledTexture& a, + const MultisampledTexture& b) { + return a.dims == b.dims && a.type == b.type; +} + +static bool operator==(const SampledTexture& a, const SampledTexture& b) { + return a.dims == b.dims && a.type == b.type; +} + +static bool operator==(const StorageTexture& a, const StorageTexture& b) { + return a.dims == b.dims && a.format == b.format; +} + +ast::Type* Void::Build(ProgramBuilder& b) const { + return b.ty.void_(); +} + +ast::Type* Bool::Build(ProgramBuilder& b) const { + return b.ty.bool_(); +} + +ast::Type* U32::Build(ProgramBuilder& b) const { + return b.ty.u32(); +} + +ast::Type* F32::Build(ProgramBuilder& b) const { + return b.ty.f32(); +} + +ast::Type* I32::Build(ProgramBuilder& b) const { + return b.ty.i32(); +} + +Pointer::Pointer(const Type* t, ast::StorageClass s) + : type(t), storage_class(s) {} +Pointer::Pointer(const Pointer&) = default; + +ast::Type* Pointer::Build(ProgramBuilder& b) const { + return b.ty.pointer(type->Build(b), storage_class); +} + +Vector::Vector(const Type* t, uint32_t s) : type(t), size(s) {} +Vector::Vector(const Vector&) = default; + +ast::Type* Vector::Build(ProgramBuilder& b) const { + return b.ty.vec(type->Build(b), size); +} + +Matrix::Matrix(const Type* t, uint32_t c, uint32_t r) + : type(t), columns(c), rows(r) {} +Matrix::Matrix(const Matrix&) = default; + +ast::Type* Matrix::Build(ProgramBuilder& b) const { + return b.ty.mat(type->Build(b), columns, rows); +} + +Array::Array(const Type* t, uint32_t sz, uint32_t st) + : type(t), size(sz), stride(st) {} +Array::Array(const Array&) = default; + +ast::Type* Array::Build(ProgramBuilder& b) const { + return b.ty.array(type->Build(b), size, stride); +} + +AccessControl::AccessControl(const Type* t, ast::AccessControl::Access a) + : type(t), access(a) {} +AccessControl::AccessControl(const AccessControl&) = default; + +ast::Type* AccessControl::Build(ProgramBuilder& b) const { + return b.ty.access(access, type->Build(b)); +} + +Sampler::Sampler(ast::SamplerKind k) : kind(k) {} +Sampler::Sampler(const Sampler&) = default; + +ast::Type* Sampler::Build(ProgramBuilder& b) const { + return b.ty.sampler(kind); +} + +Texture::Texture(ast::TextureDimension d) : dims(d) {} +Texture::Texture(const Texture&) = default; + +DepthTexture::DepthTexture(ast::TextureDimension d) : Base(d) {} +DepthTexture::DepthTexture(const DepthTexture&) = default; + +ast::Type* DepthTexture::Build(ProgramBuilder& b) const { + return b.ty.depth_texture(dims); +} + +MultisampledTexture::MultisampledTexture(ast::TextureDimension d, const Type* t) + : Base(d), type(t) {} +MultisampledTexture::MultisampledTexture(const MultisampledTexture&) = default; + +ast::Type* MultisampledTexture::Build(ProgramBuilder& b) const { + return b.ty.multisampled_texture(dims, type->Build(b)); +} + +SampledTexture::SampledTexture(ast::TextureDimension d, const Type* t) + : Base(d), type(t) {} +SampledTexture::SampledTexture(const SampledTexture&) = default; + +ast::Type* SampledTexture::Build(ProgramBuilder& b) const { + return b.ty.sampled_texture(dims, type->Build(b)); +} + +StorageTexture::StorageTexture(ast::TextureDimension d, ast::ImageFormat f) + : Base(d), format(f) {} +StorageTexture::StorageTexture(const StorageTexture&) = default; + +ast::Type* StorageTexture::Build(ProgramBuilder& b) const { + return b.ty.storage_texture(dims, format); +} + +Named::Named(Symbol n) : name(n) {} +Named::Named(const Named&) = default; +Named::~Named() = default; + +Alias::Alias(Symbol n, const Type* ty) : Base(n), type(ty) {} +Alias::Alias(const Alias&) = default; + +ast::Type* Alias::Build(ProgramBuilder& b) const { + return b.ty.type_name(name); +} + +Struct::Struct(Symbol n, TypeList m) : Base(n), members(std::move(m)) {} +Struct::Struct(const Struct&) = default; +Struct::~Struct() = default; + +ast::Type* Struct::Build(ProgramBuilder& b) const { + return b.ty.type_name(name); +} + +/// The PIMPL state of the Types object. +struct TypeManager::State { + /// The allocator of types + BlockAllocator<Type> allocator_; + /// The lazily-created Void type + spirv::Void const* void_ = nullptr; + /// The lazily-created Bool type + spirv::Bool const* bool_ = nullptr; + /// The lazily-created U32 type + spirv::U32 const* u32_ = nullptr; + /// The lazily-created F32 type + spirv::F32 const* f32_ = nullptr; + /// The lazily-created I32 type + spirv::I32 const* i32_ = nullptr; + /// Map of Pointer to the returned Pointer type instance + std::unordered_map<spirv::Pointer, const spirv::Pointer*, PointerHasher> + pointers_; + /// Map of Vector to the returned Vector type instance + std::unordered_map<spirv::Vector, const spirv::Vector*, VectorHasher> + vectors_; + /// Map of Matrix to the returned Matrix type instance + std::unordered_map<spirv::Matrix, const spirv::Matrix*, MatrixHasher> + matrices_; + /// Map of Array to the returned Array type instance + std::unordered_map<spirv::Array, const spirv::Array*, ArrayHasher> arrays_; + /// Map of AccessControl to the returned AccessControl type instance + std::unordered_map<spirv::AccessControl, + const spirv::AccessControl*, + AccessControlHasher> + access_controls_; + /// Map of type name to returned Alias instance + std::unordered_map<Symbol, const spirv::Alias*> aliases_; + /// Map of type name to returned Struct instance + std::unordered_map<Symbol, const spirv::Struct*> structs_; + /// Map of ast::SamplerKind to returned Sampler instance + std::unordered_map<ast::SamplerKind, const spirv::Sampler*> samplers_; + /// Map of ast::TextureDimension to returned DepthTexture instance + std::unordered_map<ast::TextureDimension, const spirv::DepthTexture*> + depth_textures_; + /// Map of MultisampledTexture to the returned MultisampledTexture type + /// instance + std::unordered_map<spirv::MultisampledTexture, + const spirv::MultisampledTexture*, + MultisampledTextureHasher> + multisampled_textures_; + /// Map of SampledTexture to the returned SampledTexture type instance + std::unordered_map<spirv::SampledTexture, + const spirv::SampledTexture*, + SampledTextureHasher> + sampled_textures_; + /// Map of StorageTexture to the returned StorageTexture type instance + std::unordered_map<spirv::StorageTexture, + const spirv::StorageTexture*, + StorageTextureHasher> + storage_textures_; +}; + +const Type* Type::UnwrapPtrIfNeeded() const { + if (auto* ptr = As<Pointer>()) { + return ptr->type; + } + return this; +} + +const Type* Type::UnwrapAliasIfNeeded() const { + const Type* unwrapped = this; + while (auto* ptr = unwrapped->As<Alias>()) { + unwrapped = ptr->type; + } + return unwrapped; +} + +const Type* Type::UnwrapIfNeeded() const { + auto* where = this; + while (true) { + if (auto* alias = where->As<Alias>()) { + where = alias->type; + } else if (auto* access = where->As<AccessControl>()) { + where = access->type; + } else { + break; + } + } + return where; +} + +const Type* Type::UnwrapAll() const { + return UnwrapIfNeeded()->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); +} + +bool Type::IsFloatScalar() const { + return Is<F32>(); +} + +bool Type::IsFloatScalarOrVector() const { + return IsFloatScalar() || IsFloatVector(); +} + +bool Type::IsFloatVector() const { + return Is<Vector>([](const Vector* v) { return v->type->IsFloatScalar(); }); +} + +bool Type::IsIntegerScalar() const { + return IsAnyOf<U32, I32>(); +} + +bool Type::IsIntegerScalarOrVector() const { + return IsUnsignedScalarOrVector() || IsSignedScalarOrVector(); +} + +bool Type::IsScalar() const { + return IsAnyOf<F32, U32, I32, Bool>(); +} + +bool Type::IsSignedIntegerVector() const { + return Is<Vector>([](const Vector* v) { return v->type->Is<I32>(); }); +} + +bool Type::IsSignedScalarOrVector() const { + return Is<I32>() || IsSignedIntegerVector(); +} + +bool Type::IsUnsignedIntegerVector() const { + return Is<Vector>([](const Vector* v) { return v->type->Is<U32>(); }); +} + +bool Type::IsUnsignedScalarOrVector() const { + return Is<U32>() || IsUnsignedIntegerVector(); +} + +TypeManager::TypeManager() { + state = std::make_unique<State>(); +} + +TypeManager::~TypeManager() = default; + +const spirv::Void* TypeManager::Void() { + if (!state->void_) { + state->void_ = state->allocator_.Create<spirv::Void>(); + } + return state->void_; +} + +const spirv::Bool* TypeManager::Bool() { + if (!state->bool_) { + state->bool_ = state->allocator_.Create<spirv::Bool>(); + } + return state->bool_; +} + +const spirv::U32* TypeManager::U32() { + if (!state->u32_) { + state->u32_ = state->allocator_.Create<spirv::U32>(); + } + return state->u32_; +} + +const spirv::F32* TypeManager::F32() { + if (!state->f32_) { + state->f32_ = state->allocator_.Create<spirv::F32>(); + } + return state->f32_; +} + +const spirv::I32* TypeManager::I32() { + if (!state->i32_) { + state->i32_ = state->allocator_.Create<spirv::I32>(); + } + return state->i32_; +} + +const spirv::Pointer* TypeManager::Pointer(const Type* el, + ast::StorageClass sc) { + return utils::GetOrCreate(state->pointers_, spirv::Pointer(el, sc), [&] { + return state->allocator_.Create<spirv::Pointer>(el, sc); + }); +} + +const spirv::Vector* TypeManager::Vector(const Type* el, uint32_t size) { + return utils::GetOrCreate(state->vectors_, spirv::Vector(el, size), [&] { + return state->allocator_.Create<spirv::Vector>(el, size); + }); +} + +const spirv::Matrix* TypeManager::Matrix(const Type* el, + uint32_t columns, + uint32_t rows) { + return utils::GetOrCreate( + state->matrices_, spirv::Matrix(el, columns, rows), [&] { + return state->allocator_.Create<spirv::Matrix>(el, columns, rows); + }); +} + +const spirv::Array* TypeManager::Array(const Type* el, + uint32_t size, + uint32_t stride) { + return utils::GetOrCreate( + state->arrays_, spirv::Array(el, size, stride), + [&] { return state->allocator_.Create<spirv::Array>(el, size, stride); }); +} + +const spirv::AccessControl* TypeManager::AccessControl( + const Type* ty, + ast::AccessControl::Access ac) { + return utils::GetOrCreate( + state->access_controls_, spirv::AccessControl(ty, ac), + [&] { return state->allocator_.Create<spirv::AccessControl>(ty, ac); }); +} + +const spirv::Alias* TypeManager::Alias(Symbol name, const Type* ty) { + return utils::GetOrCreate(state->aliases_, name, [&] { + return state->allocator_.Create<spirv::Alias>(name, ty); + }); +} + +const spirv::Struct* TypeManager::Struct(Symbol name, TypeList members) { + return utils::GetOrCreate(state->structs_, name, [&] { + return state->allocator_.Create<spirv::Struct>(name, std::move(members)); + }); +} + +const spirv::Sampler* TypeManager::Sampler(ast::SamplerKind kind) { + return utils::GetOrCreate(state->samplers_, kind, [&] { + return state->allocator_.Create<spirv::Sampler>(kind); + }); +} + +const spirv::DepthTexture* TypeManager::DepthTexture( + ast::TextureDimension dims) { + return utils::GetOrCreate(state->depth_textures_, dims, [&] { + return state->allocator_.Create<spirv::DepthTexture>(dims); + }); +} + +const spirv::MultisampledTexture* TypeManager::MultisampledTexture( + ast::TextureDimension dims, + const Type* ty) { + return utils::GetOrCreate( + state->multisampled_textures_, spirv::MultisampledTexture(dims, ty), [&] { + return state->allocator_.Create<spirv::MultisampledTexture>(dims, ty); + }); +} + +const spirv::SampledTexture* TypeManager::SampledTexture( + ast::TextureDimension dims, + const Type* ty) { + return utils::GetOrCreate( + state->sampled_textures_, spirv::SampledTexture(dims, ty), [&] { + return state->allocator_.Create<spirv::SampledTexture>(dims, ty); + }); +} + +const spirv::StorageTexture* TypeManager::StorageTexture( + ast::TextureDimension dims, + ast::ImageFormat fmt) { + return utils::GetOrCreate( + state->storage_textures_, spirv::StorageTexture(dims, fmt), [&] { + return state->allocator_.Create<spirv::StorageTexture>(dims, fmt); + }); +} + +} // namespace spirv +} // namespace reader +} // namespace tint
diff --git a/src/reader/spirv/parser_type.h b/src/reader/spirv/parser_type.h new file mode 100644 index 0000000..c3ca2bd --- /dev/null +++ b/src/reader/spirv/parser_type.h
@@ -0,0 +1,498 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_READER_SPIRV_PARSER_TYPE_H_ +#define SRC_READER_SPIRV_PARSER_TYPE_H_ + +#include <memory> +#include <vector> + +#include "src/ast/access_control.h" +#include "src/ast/sampler.h" +#include "src/ast/storage_class.h" +#include "src/ast/storage_texture.h" +#include "src/ast/texture.h" +#include "src/block_allocator.h" +#include "src/castable.h" + +// Forward declarations +namespace tint { +class ProgramBuilder; +namespace ast { +class Type; +} // namespace ast +} // namespace tint + +namespace tint { +namespace reader { +namespace spirv { + +/// Type is the base class for all types +class Type : public Castable<Type> { + public: + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + virtual ast::Type* Build(ProgramBuilder& b) const = 0; + + /// @returns the pointee type if this is a pointer, `this` otherwise + const Type* UnwrapPtrIfNeeded() const; + + /// @returns the most deeply nested aliased type if this is an alias, `this` + /// otherwise + const Type* UnwrapAliasIfNeeded() const; + + /// Removes all levels of aliasing and access control. + /// This is just enough to assist with WGSL translation + /// in that you want see through one level of pointer to get from an + /// identifier-like expression as an l-value to its corresponding r-value, + /// plus see through the wrappers on either side. + /// @returns the completely unaliased type. + const Type* UnwrapIfNeeded() const; + + /// Returns the type found after: + /// - removing all layers of aliasing and access control if they exist, then + /// - removing the pointer, if it exists, then + /// - removing all further layers of aliasing or access control, if they exist + /// @returns the unwrapped type + const Type* UnwrapAll() const; + + /// @returns true if this type is a float scalar + bool IsFloatScalar() const; + /// @returns true if this type is a float scalar or vector + bool IsFloatScalarOrVector() const; + /// @returns true if this type is a float vector + bool IsFloatVector() const; + /// @returns true if this type is an integer scalar + bool IsIntegerScalar() const; + /// @returns true if this type is an integer scalar or vector + bool IsIntegerScalarOrVector() const; + /// @returns true if this type is a scalar + bool IsScalar() const; + /// @returns true if this type is a signed integer vector + bool IsSignedIntegerVector() const; + /// @returns true if this type is a signed scalar or vector + bool IsSignedScalarOrVector() const; + /// @returns true if this type is an unsigned integer vector + bool IsUnsignedIntegerVector() const; + /// @returns true if this type is an unsigned scalar or vector + bool IsUnsignedScalarOrVector() const; +}; + +using TypeList = std::vector<const Type*>; + +/// `void` type +struct Void : public Castable<Void, Type> { + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; +}; + +/// `bool` type +struct Bool : public Castable<Bool, Type> { + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; +}; + +/// `u32` type +struct U32 : public Castable<U32, Type> { + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; +}; + +/// `f32` type +struct F32 : public Castable<F32, Type> { + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; +}; + +/// `i32` type +struct I32 : public Castable<I32, Type> { + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; +}; + +/// `ptr<SC, T>` type +struct Pointer : public Castable<Pointer, Type> { + /// Constructor + /// @param ty the pointee type + /// @param sc the pointer storage class + Pointer(const Type* ty, ast::StorageClass sc); + + /// Copy constructor + /// @param other the other type to copy + Pointer(const Pointer& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the pointee type + Type const* const type; + /// the pointer storage class + ast::StorageClass const storage_class; +}; + +/// `vecN<T>` type +struct Vector : public Castable<Vector, Type> { + /// Constructor + /// @param ty the element type + /// @param sz the number of elements in the vector + Vector(const Type* ty, uint32_t sz); + + /// Copy constructor + /// @param other the other type to copy + Vector(const Vector& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the element type + Type const* const type; + /// the number of elements in the vector + uint32_t const size; +}; + +/// `matNxM<T>` type +struct Matrix : public Castable<Matrix, Type> { + /// Constructor + /// @param ty the matrix element type + /// @param c the number of columns in the matrix + /// @param r the number of rows in the matrix + Matrix(const Type* ty, uint32_t c, uint32_t r); + + /// Copy constructor + /// @param other the other type to copy + Matrix(const Matrix& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the matrix element type + Type const* const type; + /// the number of columns in the matrix + uint32_t const columns; + /// the number of rows in the matrix + uint32_t const rows; +}; + +/// `array<T, N>` type +struct Array : public Castable<Array, Type> { + /// Constructor + /// @param el the element type + /// @param sz the number of elements in the array. 0 represents runtime-sized + /// array. + /// @param st the byte stride of the array + Array(const Type* el, uint32_t sz, uint32_t st); + + /// Copy constructor + /// @param other the other type to copy + Array(const Array& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the element type + Type const* const type; + /// the number of elements in the array. 0 represents runtime-sized array. + uint32_t const size; + /// the byte stride of the array + uint32_t const stride; +}; + +/// `[[access]]` type +struct AccessControl : public Castable<AccessControl, Type> { + /// Constructor + /// @param ty the inner type + /// @param ac the access control + AccessControl(const Type* ty, ast::AccessControl::Access ac); + + /// Copy constructor + /// @param other the other type to copy + AccessControl(const AccessControl& other); + + /// @return the + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the inner type + Type const* const type; + /// the access control + ast::AccessControl::Access const access; +}; + +/// `sampler` type +struct Sampler : public Castable<Sampler, Type> { + /// Constructor + /// @param k the sampler kind + explicit Sampler(ast::SamplerKind k); + + /// Copy constructor + /// @param other the other type to copy + Sampler(const Sampler& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the sampler kind + ast::SamplerKind const kind; +}; + +/// Base class for texture types +struct Texture : public Castable<Texture, Type> { + /// Constructor + /// @param d the texture dimensions + explicit Texture(ast::TextureDimension d); + + /// Copy constructor + /// @param other the other type to copy + Texture(const Texture& other); + + /// the texture dimensions + ast::TextureDimension const dims; +}; + +/// `texture_depth_D` type +struct DepthTexture : public Castable<DepthTexture, Texture> { + /// Constructor + /// @param d the texture dimensions + explicit DepthTexture(ast::TextureDimension d); + + /// Copy constructor + /// @param other the other type to copy + DepthTexture(const DepthTexture& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; +}; + +/// `texture_multisampled_D<T>` type +struct MultisampledTexture : public Castable<MultisampledTexture, Texture> { + /// Constructor + /// @param d the texture dimensions + /// @param t the multisampled texture type + MultisampledTexture(ast::TextureDimension d, const Type* t); + + /// Copy constructor + /// @param other the other type to copy + MultisampledTexture(const MultisampledTexture& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the multisampled texture type + Type const* const type; +}; + +/// `texture_D<T>` type +struct SampledTexture : public Castable<SampledTexture, Texture> { + /// Constructor + /// @param d the texture dimensions + /// @param t the sampled texture type + SampledTexture(ast::TextureDimension d, const Type* t); + + /// Copy constructor + /// @param other the other type to copy + SampledTexture(const SampledTexture& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the sampled texture type + Type const* const type; +}; + +/// `texture_storage_D<F>` type +struct StorageTexture : public Castable<StorageTexture, Texture> { + /// Constructor + /// @param d the texture dimensions + /// @param f the storage image format + StorageTexture(ast::TextureDimension d, ast::ImageFormat f); + + /// Copy constructor + /// @param other the other type to copy + StorageTexture(const StorageTexture& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the storage image format + ast::ImageFormat const format; +}; + +/// Base class for named types +struct Named : public Castable<Named, Type> { + /// Constructor + /// @param n the type name + explicit Named(Symbol n); + + /// Copy constructor + /// @param other the other type to copy + Named(const Named& other); + + /// Destructor + ~Named() override; + + /// the type name + Symbol const name; +}; + +/// `type T = N` type +struct Alias : public Castable<Alias, Named> { + /// Constructor + /// @param n the alias name + /// @param t the aliased type + Alias(Symbol n, const Type* t); + + /// Copy constructor + /// @param other the other type to copy + Alias(const Alias& other); + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the aliased type + Type const* const type; +}; + +/// `struct N { ... };` type +struct Struct : public Castable<Struct, Named> { + /// Constructor + /// @param n the struct name + /// @param m the member types + Struct(Symbol n, TypeList m); + + /// Copy constructor + /// @param other the other type to copy + Struct(const Struct& other); + + /// Destructor + ~Struct() override; + + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type* Build(ProgramBuilder& b) const override; + + /// the member types + TypeList const members; +}; + +/// A manager of types +class TypeManager { + public: + /// Constructor + TypeManager(); + + /// Destructor + ~TypeManager(); + + /// @return a Void type. Repeated calls will return the same pointer. + const spirv::Void* Void(); + /// @return a Bool type. Repeated calls will return the same pointer. + const spirv::Bool* Bool(); + /// @return a U32 type. Repeated calls will return the same pointer. + const spirv::U32* U32(); + /// @return a F32 type. Repeated calls will return the same pointer. + const spirv::F32* F32(); + /// @return a I32 type. Repeated calls will return the same pointer. + const spirv::I32* I32(); + /// @param ty the pointee type + /// @param sc the pointer storage class + /// @return a Pointer type. Repeated calls with the same arguments will return + /// the same pointer. + const spirv::Pointer* Pointer(const Type* ty, ast::StorageClass sc); + /// @param ty the element type + /// @param sz the number of elements in the vector + /// @return a Vector type. Repeated calls with the same arguments will return + /// the same pointer. + const spirv::Vector* Vector(const Type* ty, uint32_t sz); + /// @param ty the matrix element type + /// @param c the number of columns in the matrix + /// @param r the number of rows in the matrix + /// @return a Matrix type. Repeated calls with the same arguments will return + /// the same pointer. + const spirv::Matrix* Matrix(const Type* ty, uint32_t c, uint32_t r); + /// @param el the element type + /// @param sz the number of elements in the array. 0 represents runtime-sized + /// array. + /// @param st the byte stride of the array + /// @return a Array type. Repeated calls with the same arguments will return + /// the same pointer. + const spirv::Array* Array(const Type* el, uint32_t sz, uint32_t st); + /// @param ty the inner type + /// @param ac the access control + /// @return a AccessControl type. Repeated calls with the same arguments will + /// return the same pointer. + const spirv::AccessControl* AccessControl(const Type* ty, + ast::AccessControl::Access ac); + /// @param n the alias name + /// @param t the aliased type + /// @return a Alias type. Repeated calls with the same arguments will return + /// the same pointer. + const spirv::Alias* Alias(Symbol n, const Type* t); + /// @param n the struct name + /// @param m the member types + /// @return a Struct type. Repeated calls with the same arguments will return + /// the same pointer. + const spirv::Struct* Struct(Symbol n, TypeList m); + /// @param k the sampler kind + /// @return a Sampler type. Repeated calls with the same arguments will return + /// the same pointer. + const spirv::Sampler* Sampler(ast::SamplerKind k); + /// @param d the texture dimensions + /// @return a DepthTexture type. Repeated calls with the same arguments will + /// return the same pointer. + const spirv::DepthTexture* DepthTexture(ast::TextureDimension d); + /// @param d the texture dimensions + /// @param t the multisampled texture type + /// @return a MultisampledTexture type. Repeated calls with the same arguments + /// will return the same pointer. + const spirv::MultisampledTexture* MultisampledTexture(ast::TextureDimension d, + const Type* t); + /// @param d the texture dimensions + /// @param t the sampled texture type + /// @return a SampledTexture type. Repeated calls with the same arguments will + /// return the same pointer. + const spirv::SampledTexture* SampledTexture(ast::TextureDimension d, + const Type* t); + /// @param d the texture dimensions + /// @param f the storage image format + /// @return a StorageTexture type. Repeated calls with the same arguments will + /// return the same pointer. + const spirv::StorageTexture* StorageTexture(ast::TextureDimension d, + ast::ImageFormat f); + + private: + struct State; + std::unique_ptr<State> state; +}; + +} // namespace spirv +} // namespace reader +} // namespace tint + +#endif // SRC_READER_SPIRV_PARSER_TYPE_H_
diff --git a/src/reader/spirv/parser_type_test.cc b/src/reader/spirv/parser_type_test.cc new file mode 100644 index 0000000..523e736 --- /dev/null +++ b/src/reader/spirv/parser_type_test.cc
@@ -0,0 +1,104 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include "src/reader/spirv/parser_type.h" + +namespace tint { +namespace reader { +namespace spirv { +namespace { + +TEST(SpvParserTypeTest, SameArgumentsGivesSamePointer) { + Symbol sym(Symbol(1, {})); + + TypeManager ty; + EXPECT_EQ(ty.Void(), ty.Void()); + EXPECT_EQ(ty.Bool(), ty.Bool()); + EXPECT_EQ(ty.U32(), ty.U32()); + EXPECT_EQ(ty.F32(), ty.F32()); + EXPECT_EQ(ty.I32(), ty.I32()); + EXPECT_EQ(ty.Pointer(ty.I32(), ast::StorageClass::kNone), + ty.Pointer(ty.I32(), ast::StorageClass::kNone)); + EXPECT_EQ(ty.Vector(ty.I32(), 3), ty.Vector(ty.I32(), 3)); + EXPECT_EQ(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.I32(), 3, 2)); + EXPECT_EQ(ty.Array(ty.I32(), 3, 2), ty.Array(ty.I32(), 3, 2)); + EXPECT_EQ(ty.AccessControl(ty.I32(), ast::AccessControl::Access::kReadOnly), + ty.AccessControl(ty.I32(), ast::AccessControl::Access::kReadOnly)); + EXPECT_EQ(ty.Alias(sym, ty.I32()), ty.Alias(sym, ty.I32())); + EXPECT_EQ(ty.Struct(sym, {ty.I32()}), ty.Struct(sym, {ty.I32()})); + EXPECT_EQ(ty.Sampler(ast::SamplerKind::kSampler), + ty.Sampler(ast::SamplerKind::kSampler)); + EXPECT_EQ(ty.DepthTexture(ast::TextureDimension::k2d), + ty.DepthTexture(ast::TextureDimension::k2d)); + EXPECT_EQ(ty.MultisampledTexture(ast::TextureDimension::k2d, ty.I32()), + ty.MultisampledTexture(ast::TextureDimension::k2d, ty.I32())); + EXPECT_EQ(ty.SampledTexture(ast::TextureDimension::k2d, ty.I32()), + ty.SampledTexture(ast::TextureDimension::k2d, ty.I32())); + EXPECT_EQ( + ty.StorageTexture(ast::TextureDimension::k2d, ast::ImageFormat::kR16Sint), + ty.StorageTexture(ast::TextureDimension::k2d, + ast::ImageFormat::kR16Sint)); +} + +TEST(SpvParserTypeTest, DifferentArgumentsGivesDifferentPointer) { + Symbol sym_a(Symbol(1, {})); + Symbol sym_b(Symbol(2, {})); + + TypeManager ty; + EXPECT_NE(ty.Pointer(ty.I32(), ast::StorageClass::kNone), + ty.Pointer(ty.U32(), ast::StorageClass::kNone)); + EXPECT_NE(ty.Pointer(ty.I32(), ast::StorageClass::kNone), + ty.Pointer(ty.I32(), ast::StorageClass::kInput)); + EXPECT_NE(ty.Vector(ty.I32(), 3), ty.Vector(ty.U32(), 3)); + EXPECT_NE(ty.Vector(ty.I32(), 3), ty.Vector(ty.I32(), 2)); + EXPECT_NE(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.U32(), 3, 2)); + EXPECT_NE(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.I32(), 2, 2)); + EXPECT_NE(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.I32(), 3, 3)); + EXPECT_NE(ty.Array(ty.I32(), 3, 2), ty.Array(ty.U32(), 3, 2)); + EXPECT_NE(ty.Array(ty.I32(), 3, 2), ty.Array(ty.I32(), 2, 2)); + EXPECT_NE(ty.Array(ty.I32(), 3, 2), ty.Array(ty.I32(), 3, 3)); + EXPECT_NE(ty.AccessControl(ty.I32(), ast::AccessControl::Access::kReadOnly), + ty.AccessControl(ty.U32(), ast::AccessControl::Access::kReadOnly)); + EXPECT_NE(ty.AccessControl(ty.I32(), ast::AccessControl::Access::kReadOnly), + ty.AccessControl(ty.I32(), ast::AccessControl::Access::kWriteOnly)); + EXPECT_NE(ty.Alias(sym_a, ty.I32()), ty.Alias(sym_b, ty.I32())); + EXPECT_NE(ty.Struct(sym_a, {ty.I32()}), ty.Struct(sym_b, {ty.I32()})); + EXPECT_NE(ty.Sampler(ast::SamplerKind::kSampler), + ty.Sampler(ast::SamplerKind::kComparisonSampler)); + EXPECT_NE(ty.DepthTexture(ast::TextureDimension::k2d), + ty.DepthTexture(ast::TextureDimension::k1d)); + EXPECT_NE(ty.MultisampledTexture(ast::TextureDimension::k2d, ty.I32()), + ty.MultisampledTexture(ast::TextureDimension::k3d, ty.I32())); + EXPECT_NE(ty.MultisampledTexture(ast::TextureDimension::k2d, ty.I32()), + ty.MultisampledTexture(ast::TextureDimension::k2d, ty.U32())); + EXPECT_NE(ty.SampledTexture(ast::TextureDimension::k2d, ty.I32()), + ty.SampledTexture(ast::TextureDimension::k3d, ty.I32())); + EXPECT_NE(ty.SampledTexture(ast::TextureDimension::k2d, ty.I32()), + ty.SampledTexture(ast::TextureDimension::k2d, ty.U32())); + EXPECT_NE( + ty.StorageTexture(ast::TextureDimension::k2d, ast::ImageFormat::kR16Sint), + ty.StorageTexture(ast::TextureDimension::k3d, + ast::ImageFormat::kR16Sint)); + EXPECT_NE( + ty.StorageTexture(ast::TextureDimension::k2d, ast::ImageFormat::kR16Sint), + ty.StorageTexture(ast::TextureDimension::k2d, + ast::ImageFormat::kR32Sint)); +} + +} // namespace +} // namespace spirv +} // namespace reader +} // namespace tint
diff --git a/test/BUILD.gn b/test/BUILD.gn index 740ce6b..b691997 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn
@@ -349,6 +349,7 @@ "../src/reader/spirv/parser_impl_test_helper.cc", "../src/reader/spirv/parser_impl_test_helper.h", "../src/reader/spirv/parser_impl_user_name_test.cc", + "../src/reader/spirv/parser_type_test.cc", "../src/reader/spirv/parser_test.cc", "../src/reader/spirv/spirv_tools_helpers_test.cc", "../src/reader/spirv/spirv_tools_helpers_test.h",