[tint][ir][val] Check vector and matrix element types
Fixes: 380903152
Fixes: 380898781
Change-Id: I1b66805a3186496f55b9150c573ac2d970a4e0c4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/216615
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index d56cf0b..7677f4c 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -83,6 +83,7 @@
#include "src/tint/lang/core/type/f32.h"
#include "src/tint/lang/core/type/i32.h"
#include "src/tint/lang/core/type/i8.h"
+#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/lang/core/type/memory_view.h"
#include "src/tint/lang/core/type/pointer.h"
#include "src/tint/lang/core/type/reference.h"
@@ -936,7 +937,8 @@
size_t num_results,
size_t num_operands);
- /// Checks that @p type does not use any types that are prohibited by the target capabilities.
+ /// Checks that @p type is allowed by the spec, and does not use any types that are prohibited
+ /// by the target capabilities.
/// @param type the type
/// @param diag a function that creates an error diagnostic for the source of the type
/// @param ignore_caps a set of capabilities to ignore for this check
@@ -1778,6 +1780,20 @@
}
return true;
},
+ [&](const core::type::Vector* v) {
+ if (!v->Type()->IsScalar()) {
+ diag() << "vector elements must be scalars";
+ return false;
+ }
+ return true;
+ },
+ [&](const core::type::Matrix* m) {
+ if (!m->Type()->IsFloatScalar()) {
+ diag() << "matrix elements must be float scalars";
+ return false;
+ }
+ return true;
+ },
[](Default) { return true; });
};
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 44f1f5c..57f72f0 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -4907,6 +4907,106 @@
)");
}
+TEST_F(IR_ValidatorTest, Type_VectorElements) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ b.Append(f->Block(), [&] {
+ b.Var("u32_valid", AddressSpace::kFunction, ty.vec4(ty.u32()));
+ b.Var("i32_valid", AddressSpace::kFunction, ty.vec4(ty.i32()));
+ b.Var("bool_valid", AddressSpace::kFunction, ty.vec2(ty.bool_()));
+ b.Var("f16_valid", AddressSpace::kFunction, ty.vec3(ty.f16()));
+ b.Var("f32_valid", AddressSpace::kFunction, ty.vec3(ty.f32()));
+ b.Var("void_invalid", AddressSpace::kFunction, ty.vec2(ty.void_()));
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(), R"(:8:5 error: var: vector elements must be scalars
+ %void_invalid:ptr<function, vec2<void>, read_write> = var
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ %u32_valid:ptr<function, vec4<u32>, read_write> = var
+ %i32_valid:ptr<function, vec4<i32>, read_write> = var
+ %bool_valid:ptr<function, vec2<bool>, read_write> = var
+ %f16_valid:ptr<function, vec3<f16>, read_write> = var
+ %f32_valid:ptr<function, vec3<f32>, read_write> = var
+ %void_invalid:ptr<function, vec2<void>, read_write> = var
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Type_MatrixElements) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ b.Append(f->Block(), [&] {
+ b.Var("u32_invalid", AddressSpace::kFunction, ty.mat2x2(ty.u32()));
+ b.Var("i32_invalid", AddressSpace::kFunction, ty.mat3x2(ty.i32()));
+ b.Var("bool_invalid", AddressSpace::kFunction, ty.mat4x2(ty.bool_()));
+ b.Var("f16_valid", AddressSpace::kFunction, ty.mat2x3(ty.f16()));
+ b.Var("f32_valid", AddressSpace::kFunction, ty.mat4x4(ty.f32()));
+ b.Var("void_invalid", AddressSpace::kFunction, ty.mat3x3(ty.void_()));
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(), R"(:3:5 error: var: matrix elements must be float scalars
+ %u32_invalid:ptr<function, mat2x2<u32>, read_write> = var
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+:4:5 error: var: matrix elements must be float scalars
+ %i32_invalid:ptr<function, mat3x2<i32>, read_write> = var
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+:5:5 error: var: matrix elements must be float scalars
+ %bool_invalid:ptr<function, mat4x2<bool>, read_write> = var
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+:8:5 error: var: matrix elements must be float scalars
+ %void_invalid:ptr<function, mat3x3<void>, read_write> = var
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ %u32_invalid:ptr<function, mat2x2<u32>, read_write> = var
+ %i32_invalid:ptr<function, mat3x2<i32>, read_write> = var
+ %bool_invalid:ptr<function, mat4x2<bool>, read_write> = var
+ %f16_valid:ptr<function, mat2x3<f16>, read_write> = var
+ %f32_valid:ptr<function, mat4x4<f32>, read_write> = var
+ %void_invalid:ptr<function, mat3x3<void>, read_write> = var
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_ValidatorTest, Var_RootBlock_NullResult) {
auto* v = mod.CreateInstruction<ir::Var>(nullptr);
v->SetInitializer(b.Constant(0_i));
diff --git a/src/tint/lang/core/type/type.cc b/src/tint/lang/core/type/type.cc
index cf77682..10ef6a5 100644
--- a/src/tint/lang/core/type/type.cc
+++ b/src/tint/lang/core/type/type.cc
@@ -34,6 +34,7 @@
#include "src/tint/lang/core/type/f16.h"
#include "src/tint/lang/core/type/f32.h"
#include "src/tint/lang/core/type/i32.h"
+#include "src/tint/lang/core/type/i8.h"
#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/lang/core/type/pointer.h"
#include "src/tint/lang/core/type/reference.h"
@@ -41,6 +42,7 @@
#include "src/tint/lang/core/type/struct.h"
#include "src/tint/lang/core/type/texture.h"
#include "src/tint/lang/core/type/u32.h"
+#include "src/tint/lang/core/type/u8.h"
#include "src/tint/lang/core/type/vector.h"
#include "src/tint/utils/rtti/switch.h"
@@ -109,7 +111,7 @@
}
bool Type::IsIntegerScalar() const {
- return IsAnyOf<U32, I32>();
+ return IsAnyOf<U32, I32, U8, I8>();
}
bool Type::IsIntegerVector() const {