Import Tint changes from Dawn
Changes:
- aa5e77b68bce04114250a15bdf2df2c1a413e2e1 [ir][spirv-writer] Remove logic for depth textures by James Price <jrprice@google.com>
- 225479f0a5f2050a8199d4533ceb2c58acd6f41f [ir][spirv-writer] Fix texture builtin return type by James Price <jrprice@google.com>
- c6223a1b5d2955ef99528ba3ebcebe40b8741cdd [ir][spirv-writer] Implement all builtin by James Price <jrprice@google.com>
- b26348e892ee5cb30b4dfab7078575427085e29c [ir] Fix duplicate StoreVectorElement disassembly. by dan sinclair <dsinclair@chromium.org>
- db2743d288cc4dd9996916af6fa70c01b8b61b63 [ir] Add accessor test by dan sinclair <dsinclair@chromium.org>
- 937670a0ed00021d0e31dd5f0b58facc7d40c232 [ir][spirv-writer] Handle mix builtin by James Price <jrprice@google.com>
- ef8671ca3ec6ed29eca8ca658b584514f08a5fda [ir][spirv-writer] Implement pow builtin by James Price <jrprice@google.com>
- 41090b502f341659fff3be7a45f298d13832be63 [ir][spirv-writer] Implement step and smoothstep by James Price <jrprice@google.com>
- ee9abc3495095550a4d6bce3ac4d185baa573cbd [ir][spirv-writer] Handle matrix arithmetic by James Price <jrprice@google.com>
- 91bbbc4977f14bff6ad86fbde7e4dd40d36f1a2d [ir][spirv-writer] Expand implicit vector splats by James Price <jrprice@google.com>
- c5c8f6356450b097c9ac50836874eda48b8c8d04 [ir][spirv-writer] Fix binary and|or for bools by James Price <jrprice@google.com>
- 41da0ee63647af1db7fff5709644645d833eac5f [ir][spirv-writer] Implement fma builtin by James Price <jrprice@google.com>
- b5996c8db8b2b6ae1af46e0e9f7fec1c43c111e2 [ir][spirv-writer] Implement barrier builtins by James Price <jrprice@google.com>
- 6c4de8b23b6406519774c658fa33df45a74b1f1e Tint: Implement bitcast for f16 by Zhaoming Jiang <zhaoming.jiang@intel.com>
- e4952f38995194c02e633463ab5ddbb9a18607ad [ir][validation] Move Validator to header file by dan sinclair <dsinclair@chromium.org>
- 57c462971bf976c703db46d061186ffd28a43c40 [ir] Fix Builder::Not() for vector types by James Price <jrprice@google.com>
- 05c6fd3a29790612b9667e4e21e6f401819e78a2 [ir][spirv-writer] Handle texture sample builtins by James Price <jrprice@google.com>
- f7386737af831594412c56decfc76ada1163abb6 [ir][spirv-writer] De-dup depth texture types by James Price <jrprice@google.com>
GitOrigin-RevId: aa5e77b68bce04114250a15bdf2df2c1a413e2e1
Change-Id: I173cb706352f7726982f1ecb864e692fc1376e67
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/141400
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 467d112..bc6530d 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -540,6 +540,10 @@
sources += [
"ir/transform/builtin_polyfill_spirv.cc",
"ir/transform/builtin_polyfill_spirv.h",
+ "ir/transform/expand_implicit_splats.cc",
+ "ir/transform/expand_implicit_splats.h",
+ "ir/transform/handle_matrix_arithmetic.cc",
+ "ir/transform/handle_matrix_arithmetic.h",
"ir/transform/merge_return.cc",
"ir/transform/merge_return.h",
"ir/transform/shader_io_spirv.cc",
@@ -1908,6 +1912,8 @@
if (tint_build_spv_writer) {
sources += [
"ir/transform/builtin_polyfill_spirv_test.cc",
+ "ir/transform/expand_implicit_splats_test.cc",
+ "ir/transform/handle_matrix_arithmetic_test.cc",
"ir/transform/merge_return_test.cc",
"ir/transform/shader_io_spirv_test.cc",
"ir/transform/var_for_dynamic_index_test.cc",
@@ -2088,6 +2094,7 @@
"writer/spirv/ir/generator_impl_ir_switch_test.cc",
"writer/spirv/ir/generator_impl_ir_swizzle_test.cc",
"writer/spirv/ir/generator_impl_ir_test.cc",
+ "writer/spirv/ir/generator_impl_ir_texture_builtin_test.cc",
"writer/spirv/ir/generator_impl_ir_type_test.cc",
"writer/spirv/ir/generator_impl_ir_unary_test.cc",
"writer/spirv/ir/generator_impl_ir_var_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index b53cffd..8c6ea8a 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -679,6 +679,10 @@
list(APPEND TINT_LIB_SRCS
ir/transform/builtin_polyfill_spirv.cc
ir/transform/builtin_polyfill_spirv.h
+ ir/transform/expand_implicit_splats.cc
+ ir/transform/expand_implicit_splats.h
+ ir/transform/handle_matrix_arithmetic.cc
+ ir/transform/handle_matrix_arithmetic.h
ir/transform/merge_return.cc
ir/transform/merge_return.h
ir/transform/shader_io_spirv.cc
@@ -1332,6 +1336,8 @@
if(${TINT_BUILD_IR})
list(APPEND TINT_TEST_SRCS
ir/transform/builtin_polyfill_spirv_test.cc
+ ir/transform/handle_matrix_arithmetic_test.cc
+ ir/transform/expand_implicit_splats_test.cc
ir/transform/merge_return_test.cc
ir/transform/shader_io_spirv_test.cc
ir/transform/var_for_dynamic_index_test.cc
@@ -1350,6 +1356,7 @@
writer/spirv/ir/generator_impl_ir_switch_test.cc
writer/spirv/ir/generator_impl_ir_swizzle_test.cc
writer/spirv/ir/generator_impl_ir_test.cc
+ writer/spirv/ir/generator_impl_ir_texture_builtin_test.cc
writer/spirv/ir/generator_impl_ir_type_test.cc
writer/spirv/ir/generator_impl_ir_unary_test.cc
writer/spirv/ir/generator_impl_ir_var_test.cc
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 6d75a37..21bda1f 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -504,7 +504,13 @@
/// @returns the operation
template <typename VAL>
ir::Binary* Not(const type::Type* type, VAL&& val) {
- return Equal(type, std::forward<VAL>(val), Constant(false));
+ if (auto* vec = type->As<type::Vector>()) {
+ return Equal(type, std::forward<VAL>(val),
+ Constant(ir.constant_values.Splat(vec, ir.constant_values.Get(false),
+ vec->Width())));
+ } else {
+ return Equal(type, std::forward<VAL>(val), Constant(false));
+ }
}
/// Creates a bitcast instruction
diff --git a/src/tint/ir/call.h b/src/tint/ir/call.h
index 2d1ff35..ea7e14e 100644
--- a/src/tint/ir/call.h
+++ b/src/tint/ir/call.h
@@ -28,6 +28,10 @@
/// @returns the call arguments
virtual utils::Slice<Value*> Args() { return operands_.Slice(); }
+ /// Append a new argument to the argument list for this call instruction.
+ /// @param arg the argument value to append
+ void AppendArg(ir::Value* arg) { AddOperand(operands_.Length(), arg); }
+
protected:
/// Constructor
Call();
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 946671e..f458585 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -477,8 +477,6 @@
[&](StoreVectorElement* s) {
EmitInstructionName("store_vector_element", s);
out_ << " ";
- EmitValue(s->To());
- out_ << " ";
EmitOperandList(s);
EmitLine();
},
diff --git a/src/tint/ir/from_program_accessor_test.cc b/src/tint/ir/from_program_accessor_test.cc
index be6e58c..cbf6ac5 100644
--- a/src/tint/ir/from_program_accessor_test.cc
+++ b/src/tint/ir/from_program_accessor_test.cc
@@ -52,6 +52,32 @@
)");
}
+TEST_F(IR_FromProgramAccessorTest, Accessor_Multiple) {
+ // let a: vec3<u32> = vec3();
+ // let b = a[2]
+ // let c = a[1]
+
+ auto* a = Let("a", ty.vec3<u32>(), vec(ty.u32(), 3));
+ auto* expr = Decl(Let("b", IndexAccessor(a, 2_u)));
+ auto* expr2 = Decl(Let("c", IndexAccessor(a, 1_u)));
+
+ WrapInFunction(Decl(a), expr, expr2);
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:vec3<u32> = let vec3<u32>(0u)
+ %b:u32 = access %a, 2u
+ %c:u32 = access %a, 1u
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_FromProgramAccessorTest, Accessor_Var_VectorSingleIndex) {
// var a: vec3<u32>
// let b = a[2]
diff --git a/src/tint/ir/from_program_unary_test.cc b/src/tint/ir/from_program_unary_test.cc
index e7eeebb..3335523 100644
--- a/src/tint/ir/from_program_unary_test.cc
+++ b/src/tint/ir/from_program_unary_test.cc
@@ -48,6 +48,30 @@
)");
}
+TEST_F(IR_FromProgramUnaryTest, EmitExpression_Unary_Not_Vector) {
+ Func("my_func", utils::Empty, ty.vec4<bool>(),
+ utils::Vector{Return(vec(ty.bool_(), 4, false))});
+ auto* expr = Not(Call("my_func"));
+ WrapInFunction(expr);
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()), R"(%my_func = func():vec4<bool> -> %b1 {
+ %b1 = block {
+ ret vec4<bool>(false)
+ }
+}
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:vec4<bool> = call %my_func
+ %tint_symbol:vec4<bool> = eq %3, vec4<bool>(false)
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_FromProgramUnaryTest, EmitExpression_Unary_Complement) {
Func("my_func", utils::Empty, ty.u32(), utils::Vector{Return(1_u)});
auto* expr = Complement(Call("my_func"));
diff --git a/src/tint/ir/from_program_var_test.cc b/src/tint/ir/from_program_var_test.cc
index 23bbf1c..7f9f7bc 100644
--- a/src/tint/ir/from_program_var_test.cc
+++ b/src/tint/ir/from_program_var_test.cc
@@ -235,7 +235,7 @@
%11:ptr<function, vec4<f32>, read_write> = access %a, %9, %10
%12:i32 = call %f, 6i
%13:f32 = load_vector_element %11, %12
- store_vector_element %7 %7, %8, %13
+ store_vector_element %7, %8, %13
ret
}
}
@@ -360,7 +360,7 @@
%13:f32 = load_vector_element %11, %12
%14:f32 = load_vector_element %7, %8
%15:f32 = add %14, %13
- store_vector_element %7 %7, %8, %15
+ store_vector_element %7, %8, %15
ret
}
}
diff --git a/src/tint/ir/intrinsic_call.cc b/src/tint/ir/intrinsic_call.cc
index 4063213..a075260 100644
--- a/src/tint/ir/intrinsic_call.cc
+++ b/src/tint/ir/intrinsic_call.cc
@@ -37,9 +37,39 @@
case IntrinsicCall::Kind::kSpirvDot:
out << "spirv.dot";
break;
+ case IntrinsicCall::Kind::kSpirvImageSampleImplicitLod:
+ out << "spirv.image_sample_implicit_lod";
+ break;
+ case IntrinsicCall::Kind::kSpirvImageSampleExplicitLod:
+ out << "spirv.image_sample_explicit_lod";
+ break;
+ case IntrinsicCall::Kind::kSpirvImageSampleDrefImplicitLod:
+ out << "spirv.image_sample_dref_implicit_lod";
+ break;
+ case IntrinsicCall::Kind::kSpirvImageSampleDrefExplicitLod:
+ out << "spirv.image_sample_dref_explicit_lod";
+ break;
+ case IntrinsicCall::Kind::kSpirvMatrixTimesMatrix:
+ out << "spirv.matrix_times_matrix";
+ break;
+ case IntrinsicCall::Kind::kSpirvMatrixTimesScalar:
+ out << "spirv.matrix_times_scalar";
+ break;
+ case IntrinsicCall::Kind::kSpirvMatrixTimesVector:
+ out << "spirv.matrix_times_vector";
+ break;
+ case IntrinsicCall::Kind::kSpirvSampledImage:
+ out << "spirv.sampled_image";
+ break;
case IntrinsicCall::Kind::kSpirvSelect:
out << "spirv.select";
break;
+ case IntrinsicCall::Kind::kSpirvVectorTimesScalar:
+ out << "spirv.vector_times_scalar";
+ break;
+ case IntrinsicCall::Kind::kSpirvVectorTimesMatrix:
+ out << "spirv.vector_times_matrix";
+ break;
}
return out;
}
diff --git a/src/tint/ir/intrinsic_call.h b/src/tint/ir/intrinsic_call.h
index 33cbb03..3a5bef5 100644
--- a/src/tint/ir/intrinsic_call.h
+++ b/src/tint/ir/intrinsic_call.h
@@ -30,7 +30,17 @@
enum class Kind {
// SPIR-V backend intrinsics.
kSpirvDot,
+ kSpirvImageSampleImplicitLod,
+ kSpirvImageSampleExplicitLod,
+ kSpirvImageSampleDrefImplicitLod,
+ kSpirvImageSampleDrefExplicitLod,
+ kSpirvMatrixTimesMatrix,
+ kSpirvMatrixTimesScalar,
+ kSpirvMatrixTimesVector,
+ kSpirvSampledImage,
kSpirvSelect,
+ kSpirvVectorTimesMatrix,
+ kSpirvVectorTimesScalar,
};
/// Constructor
diff --git a/src/tint/ir/transform/builtin_polyfill_spirv.cc b/src/tint/ir/transform/builtin_polyfill_spirv.cc
index 1aa82e5..4833847 100644
--- a/src/tint/ir/transform/builtin_polyfill_spirv.cc
+++ b/src/tint/ir/transform/builtin_polyfill_spirv.cc
@@ -16,10 +16,17 @@
#include <utility>
+#include "spirv/unified1/spirv.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/module.h"
+#include "src/tint/type/depth_multisampled_texture.h"
+#include "src/tint/type/depth_texture.h"
+#include "src/tint/type/sampled_texture.h"
+#include "src/tint/type/texture.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::BuiltinPolyfillSpirv);
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::BuiltinPolyfillSpirv::LiteralOperand);
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::BuiltinPolyfillSpirv::SampledImage);
using namespace tint::number_suffixes; // NOLINT
@@ -52,6 +59,12 @@
switch (builtin->Func()) {
case builtin::Function::kDot:
case builtin::Function::kSelect:
+ case builtin::Function::kTextureSample:
+ case builtin::Function::kTextureSampleBias:
+ case builtin::Function::kTextureSampleCompare:
+ case builtin::Function::kTextureSampleCompareLevel:
+ case builtin::Function::kTextureSampleGrad:
+ case builtin::Function::kTextureSampleLevel:
worklist.Push(builtin);
break;
default:
@@ -70,6 +83,14 @@
case builtin::Function::kSelect:
replacement = Select(builtin);
break;
+ case builtin::Function::kTextureSample:
+ case builtin::Function::kTextureSampleBias:
+ case builtin::Function::kTextureSampleCompare:
+ case builtin::Function::kTextureSampleCompareLevel:
+ case builtin::Function::kTextureSampleGrad:
+ case builtin::Function::kTextureSampleLevel:
+ replacement = TextureSample(builtin);
+ break;
default:
break;
}
@@ -152,10 +173,178 @@
call->InsertBefore(builtin);
return call->Result();
}
+
+ /// Handle a textureSample*() builtin.
+ /// @param builtin the builtin call instruction
+ /// @returns the replacement value
+ Value* TextureSample(CoreBuiltinCall* builtin) {
+ // Helper to get the next argument from the call, or nullptr if there are no more arguments.
+ uint32_t arg_idx = 0;
+ auto next_arg = [&]() {
+ return arg_idx < builtin->Args().Length() ? builtin->Args()[arg_idx++] : nullptr;
+ };
+
+ auto* texture = next_arg();
+ auto* sampler = next_arg();
+ auto* coords = next_arg();
+ auto* texture_ty = texture->Type()->As<type::Texture>();
+ auto* array_idx = IsTextureArray(texture_ty->dim()) ? next_arg() : nullptr;
+ Value* depth = nullptr;
+
+ // Use OpSampledImage to create an OpTypeSampledImage object.
+ auto* sampled_image =
+ b.Call(ty.Get<SampledImage>(texture_ty), IntrinsicCall::Kind::kSpirvSampledImage,
+ utils::Vector{texture, sampler});
+ sampled_image->InsertBefore(builtin);
+
+ // Append the array index to the coordinates if provided.
+ if (array_idx) {
+ // Convert the index to an f32.
+ auto* array_idx_f32 = b.Convert(ty.f32(), array_idx);
+ array_idx_f32->InsertBefore(builtin);
+
+ // Construct a new coordinate vector.
+ auto num_coords = coords->Type()->As<type::Vector>()->Width();
+ auto* coord_ty = ty.vec(ty.f32(), num_coords + 1);
+ auto* construct = b.Construct(coord_ty, utils::Vector{coords, array_idx_f32->Result()});
+ construct->InsertBefore(builtin);
+ coords = construct->Result();
+ }
+
+ // Determine which SPIR-V intrinsic to use and which optional image operands are needed.
+ enum IntrinsicCall::Kind intrinsic;
+ struct ImageOperands {
+ Value* bias = nullptr;
+ Value* lod = nullptr;
+ Value* ddx = nullptr;
+ Value* ddy = nullptr;
+ Value* offset = nullptr;
+ Value* sample = nullptr;
+ } operands;
+ switch (builtin->Func()) {
+ case builtin::Function::kTextureSample:
+ intrinsic = IntrinsicCall::Kind::kSpirvImageSampleImplicitLod;
+ operands.offset = next_arg();
+ break;
+ case builtin::Function::kTextureSampleBias:
+ intrinsic = IntrinsicCall::Kind::kSpirvImageSampleImplicitLod;
+ operands.bias = next_arg();
+ operands.offset = next_arg();
+ break;
+ case builtin::Function::kTextureSampleCompare:
+ intrinsic = IntrinsicCall::Kind::kSpirvImageSampleDrefImplicitLod;
+ depth = next_arg();
+ operands.offset = next_arg();
+ break;
+ case builtin::Function::kTextureSampleCompareLevel:
+ intrinsic = IntrinsicCall::Kind::kSpirvImageSampleDrefExplicitLod;
+ depth = next_arg();
+ operands.lod = b.Constant(0_f);
+ operands.offset = next_arg();
+ break;
+ case builtin::Function::kTextureSampleGrad:
+ intrinsic = IntrinsicCall::Kind::kSpirvImageSampleExplicitLod;
+ operands.ddx = next_arg();
+ operands.ddy = next_arg();
+ operands.offset = next_arg();
+ break;
+ case builtin::Function::kTextureSampleLevel:
+ intrinsic = IntrinsicCall::Kind::kSpirvImageSampleExplicitLod;
+ operands.lod = next_arg();
+ operands.offset = next_arg();
+ break;
+ default:
+ return nullptr;
+ }
+
+ // Start building the argument list for the intrinsic.
+ // The first two operands are always the sampled image and then the coordinates, followed by
+ // the depth reference if used.
+ utils::Vector<Value*, 8> intrinsic_args;
+ intrinsic_args.Push(sampled_image->Result());
+ intrinsic_args.Push(coords);
+ if (depth) {
+ intrinsic_args.Push(depth);
+ }
+
+ // Add a placeholder argument for the image operand mask, which we'll fill in when we've
+ // processed the image operands.
+ uint32_t image_operand_mask = 0u;
+ size_t mask_idx = intrinsic_args.Length();
+ intrinsic_args.Push(nullptr);
+
+ // Add each of the optional image operands if used, updating the image operand mask.
+ if (operands.bias) {
+ image_operand_mask |= SpvImageOperandsBiasMask;
+ intrinsic_args.Push(operands.bias);
+ }
+ if (operands.lod) {
+ image_operand_mask |= SpvImageOperandsLodMask;
+ if (operands.lod->Type()->is_integer_scalar()) {
+ // Some builtins take the lod as an integer, but SPIR-V always requires an f32.
+ auto* convert = b.Convert(ty.f32(), operands.lod);
+ convert->InsertBefore(builtin);
+ operands.lod = convert->Result();
+ }
+ intrinsic_args.Push(operands.lod);
+ }
+ if (operands.ddx) {
+ image_operand_mask |= SpvImageOperandsGradMask;
+ intrinsic_args.Push(operands.ddx);
+ intrinsic_args.Push(operands.ddy);
+ }
+ if (operands.offset) {
+ image_operand_mask |= SpvImageOperandsConstOffsetMask;
+ intrinsic_args.Push(operands.offset);
+ }
+ if (operands.sample) {
+ image_operand_mask |= SpvImageOperandsSampleMask;
+ intrinsic_args.Push(operands.sample);
+ }
+
+ // Replace the image operand mask with the final mask value, as a literal operand.
+ auto* literal = ir->constant_values.Get(u32(image_operand_mask));
+ intrinsic_args[mask_idx] = ir->values.Create<LiteralOperand>(literal);
+
+ // Call the intrinsic.
+ // If this is a depth comparison, the result is always f32, otherwise vec4f.
+ auto* result_ty = depth ? static_cast<const type::Type*>(ty.f32()) : ty.vec4<f32>();
+ auto* texture_call = b.Call(result_ty, intrinsic, std::move(intrinsic_args));
+ texture_call->InsertBefore(builtin);
+
+ auto* result = texture_call->Result();
+
+ // If this is not a depth comparison but we are sampling a depth texture, extract the first
+ // component to get the scalar f32 that SPIR-V expects.
+ if (!depth && texture_ty->IsAnyOf<type::DepthTexture, type::DepthMultisampledTexture>()) {
+ auto* extract = b.Access(ty.f32(), result, 0_u);
+ extract->InsertBefore(builtin);
+ result = extract->Result();
+ }
+
+ return result;
+ }
};
void BuiltinPolyfillSpirv::Run(ir::Module* ir, const DataMap&, DataMap&) const {
State{ir}.Process();
}
+BuiltinPolyfillSpirv::LiteralOperand::LiteralOperand(const constant::Value* value) : Base(value) {}
+
+BuiltinPolyfillSpirv::LiteralOperand::~LiteralOperand() = default;
+
+BuiltinPolyfillSpirv::SampledImage::SampledImage(const type::Type* image)
+ : Base(static_cast<size_t>(
+ utils::Hash(utils::TypeInfo::Of<BuiltinPolyfillSpirv::SampledImage>().full_hashcode,
+ image)),
+ type::Flags{}),
+ image_(image) {}
+
+BuiltinPolyfillSpirv::SampledImage* BuiltinPolyfillSpirv::SampledImage::Clone(
+ type::CloneContext& ctx) const {
+ auto* image = image_->Clone(ctx);
+ return ctx.dst.mgr->Get<BuiltinPolyfillSpirv::SampledImage>(image);
+}
+
} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/builtin_polyfill_spirv.h b/src/tint/ir/transform/builtin_polyfill_spirv.h
index 8909263..d323f68 100644
--- a/src/tint/ir/transform/builtin_polyfill_spirv.h
+++ b/src/tint/ir/transform/builtin_polyfill_spirv.h
@@ -15,7 +15,16 @@
#ifndef SRC_TINT_IR_TRANSFORM_BUILTIN_POLYFILL_SPIRV_H_
#define SRC_TINT_IR_TRANSFORM_BUILTIN_POLYFILL_SPIRV_H_
+#include <string>
+
+#include "src/tint/ir/constant.h"
#include "src/tint/ir/transform/transform.h"
+#include "src/tint/type/type.h"
+
+// Forward declarations
+namespace tint::type {
+class Texture;
+} // namespace tint::type
namespace tint::ir::transform {
@@ -31,6 +40,44 @@
/// @copydoc Transform::Run
void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+ /// LiteralOperand is a type of constant value that is intended to be emitted as a literal in
+ /// the SPIR-V instruction stream.
+ class LiteralOperand final : public utils::Castable<LiteralOperand, ir::Constant> {
+ public:
+ /// Constructor
+ /// @param value the operand value
+ explicit LiteralOperand(const constant::Value* value);
+ /// Destructor
+ ~LiteralOperand() override;
+ };
+
+ /// SampledImage represents an OpTypeSampledImage in SPIR-V.
+ class SampledImage final : public utils::Castable<SampledImage, type::Type> {
+ public:
+ /// Constructor
+ /// @param image the image type
+ explicit SampledImage(const type::Type* image);
+
+ /// @param other the other node to compare against
+ /// @returns true if the this type is equal to @p other
+ bool Equals(const UniqueNode& other) const override {
+ return &other.TypeInfo() == &TypeInfo();
+ }
+
+ /// @returns the friendly name for this type
+ std::string FriendlyName() const override { return "spirv.sampled_image"; }
+
+ /// @param ctx the clone context
+ /// @returns a clone of this type
+ SampledImage* Clone(type::CloneContext& ctx) const override;
+
+ /// @returns the image type
+ const type::Type* Image() const { return image_; }
+
+ private:
+ const type::Type* image_;
+ };
+
private:
struct State;
};
diff --git a/src/tint/ir/transform/builtin_polyfill_spirv_test.cc b/src/tint/ir/transform/builtin_polyfill_spirv_test.cc
index 007a702..9d2c8e2 100644
--- a/src/tint/ir/transform/builtin_polyfill_spirv_test.cc
+++ b/src/tint/ir/transform/builtin_polyfill_spirv_test.cc
@@ -17,6 +17,8 @@
#include <utility>
#include "src/tint/ir/transform/test_helper.h"
+#include "src/tint/type/depth_texture.h"
+#include "src/tint/type/sampled_texture.h"
namespace tint::ir::transform {
namespace {
@@ -260,5 +262,787 @@
EXPECT_EQ(expect, str());
}
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSample_1D) {
+ auto* t =
+ b.FunctionParam("t", ty.Get<type::SampledTexture>(type::TextureDimension::k1d, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.f32(), builtin::Function::kTextureSample, t, s, coords);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_1d<f32>, %s:sampler, %coords:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %5:f32 = textureSample %t, %s, %coords
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_1d<f32>, %s:sampler, %coords:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %5:spirv.sampled_image = spirv.sampled_image %t, %s
+ %6:vec4<f32> = spirv.image_sample_implicit_lod %5, %coords, 0u
+ ret %6
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSample_2D) {
+ auto* t =
+ b.FunctionParam("t", ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.vec4<f32>(), builtin::Function::kTextureSample, t, s, coords);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %5:vec4<f32> = textureSample %t, %s, %coords
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %5:spirv.sampled_image = spirv.sampled_image %t, %s
+ %6:vec4<f32> = spirv.image_sample_implicit_lod %5, %coords, 0u
+ ret %6
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSample_2D_Offset) {
+ auto* t =
+ b.FunctionParam("t", ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.vec4<f32>(), builtin::Function::kTextureSample, t, s, coords,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %5:vec4<f32> = textureSample %t, %s, %coords, vec2<i32>(1i)
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %5:spirv.sampled_image = spirv.sampled_image %t, %s
+ %6:vec4<f32> = spirv.image_sample_implicit_lod %5, %coords, 8u, vec2<i32>(1i)
+ ret %6
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSample_2DArray_Offset) {
+ auto* t = b.FunctionParam(
+ "t", ty.Get<type::SampledTexture>(type::TextureDimension::k2dArray, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* array_idx = b.FunctionParam("array_idx", ty.i32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, array_idx});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.vec4<f32>(), builtin::Function::kTextureSample, t, s, coords, array_idx,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d_array<f32>, %s:sampler, %coords:vec2<f32>, %array_idx:i32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:vec4<f32> = textureSample %t, %s, %coords, %array_idx, vec2<i32>(1i)
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d_array<f32>, %s:sampler, %coords:vec2<f32>, %array_idx:i32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:spirv.sampled_image = spirv.sampled_image %t, %s
+ %7:f32 = convert %array_idx
+ %8:vec3<f32> = construct %coords, %7
+ %9:vec4<f32> = spirv.image_sample_implicit_lod %6, %8, 8u, vec2<i32>(1i)
+ ret %9
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleBias_2D) {
+ auto* t =
+ b.FunctionParam("t", ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* bias = b.FunctionParam("bias", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, bias});
+
+ b.With(func->Block(), [&] {
+ auto* result =
+ b.Call(ty.vec4<f32>(), builtin::Function::kTextureSampleBias, t, s, coords, bias);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %bias:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:vec4<f32> = textureSampleBias %t, %s, %coords, %bias
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %bias:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:spirv.sampled_image = spirv.sampled_image %t, %s
+ %7:vec4<f32> = spirv.image_sample_implicit_lod %6, %coords, 1u, %bias
+ ret %7
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleBias_2D_Offset) {
+ auto* t =
+ b.FunctionParam("t", ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* bias = b.FunctionParam("bias", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, bias});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.vec4<f32>(), builtin::Function::kTextureSampleBias, t, s, coords, bias,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %bias:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:vec4<f32> = textureSampleBias %t, %s, %coords, %bias, vec2<i32>(1i)
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %bias:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:spirv.sampled_image = spirv.sampled_image %t, %s
+ %7:vec4<f32> = spirv.image_sample_implicit_lod %6, %coords, 9u, %bias, vec2<i32>(1i)
+ ret %7
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleBias_2DArray_Offset) {
+ auto* t = b.FunctionParam(
+ "t", ty.Get<type::SampledTexture>(type::TextureDimension::k2dArray, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* array_idx = b.FunctionParam("array_idx", ty.i32());
+ auto* bias = b.FunctionParam("bias", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, array_idx, bias});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.vec4<f32>(), builtin::Function::kTextureSampleBias, t, s, coords, array_idx, bias,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d_array<f32>, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %bias:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %7:vec4<f32> = textureSampleBias %t, %s, %coords, %array_idx, %bias, vec2<i32>(1i)
+ ret %7
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d_array<f32>, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %bias:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %7:spirv.sampled_image = spirv.sampled_image %t, %s
+ %8:f32 = convert %array_idx
+ %9:vec3<f32> = construct %coords, %8
+ %10:vec4<f32> = spirv.image_sample_implicit_lod %7, %9, 9u, %bias, vec2<i32>(1i)
+ ret %10
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleCompare_2D) {
+ auto* t = b.FunctionParam("t", ty.Get<type::DepthTexture>(type::TextureDimension::k2d));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* dref = b.FunctionParam("dref", ty.f32());
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({t, s, coords, dref});
+
+ b.With(func->Block(), [&] {
+ auto* result =
+ b.Call(ty.f32(), builtin::Function::kTextureSampleCompare, t, s, coords, dref); // dref becomes a separate instruction operand, not an image operand
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_depth_2d, %s:sampler, %coords:vec2<f32>, %dref:f32):f32 -> %b1 {
+ %b1 = block {
+ %6:f32 = textureSampleCompare %t, %s, %coords, %dref
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str()); // sanity-check the input IR disassembly before transforming
+
+ auto* expect = R"(
+%foo = func(%t:texture_depth_2d, %s:sampler, %coords:vec2<f32>, %dref:f32):f32 -> %b1 {
+ %b1 = block {
+ %6:spirv.sampled_image = spirv.sampled_image %t, %s
+ %7:f32 = spirv.image_sample_dref_implicit_lod %6, %coords, %dref, 0u
+ ret %7
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // expects mask 0u: no SPIR-V image operands
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleCompare_2D_Offset) {
+ auto* t = b.FunctionParam("t", ty.Get<type::DepthTexture>(type::TextureDimension::k2d));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* dref = b.FunctionParam("dref", ty.f32());
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({t, s, coords, dref});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.f32(), builtin::Function::kTextureSampleCompare, t, s, coords, dref,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_depth_2d, %s:sampler, %coords:vec2<f32>, %dref:f32):f32 -> %b1 {
+ %b1 = block {
+ %6:f32 = textureSampleCompare %t, %s, %coords, %dref, vec2<i32>(1i)
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_depth_2d, %s:sampler, %coords:vec2<f32>, %dref:f32):f32 -> %b1 {
+ %b1 = block {
+ %6:spirv.sampled_image = spirv.sampled_image %t, %s
+ %7:f32 = spirv.image_sample_dref_implicit_lod %6, %coords, %dref, 8u, vec2<i32>(1i)
+ ret %7
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // expects mask 8u = ConstOffset (SPIR-V image-operand bit 0x8)
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleCompare_2DArray_Offset) {
+ auto* t = b.FunctionParam("t", ty.Get<type::DepthTexture>(type::TextureDimension::k2dArray));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* array_idx = b.FunctionParam("array_idx", ty.i32());
+ auto* bias = b.FunctionParam("bias", ty.f32()); // NOTE(review): this is the depth-reference operand of textureSampleCompare; "bias" looks like a copy-paste name — consider renaming to dref upstream
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({t, s, coords, array_idx, bias});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.f32(), builtin::Function::kTextureSampleCompare, t, s, coords, array_idx, bias,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_depth_2d_array, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %bias:f32):f32 -> %b1 {
+ %b1 = block {
+ %7:f32 = textureSampleCompare %t, %s, %coords, %array_idx, %bias, vec2<i32>(1i)
+ ret %7
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_depth_2d_array, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %bias:f32):f32 -> %b1 {
+ %b1 = block {
+ %7:spirv.sampled_image = spirv.sampled_image %t, %s
+ %8:f32 = convert %array_idx
+ %9:vec3<f32> = construct %coords, %8
+ %10:f32 = spirv.image_sample_dref_implicit_lod %7, %9, %bias, 8u, vec2<i32>(1i)
+ ret %10
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // array index is converted to f32 and folded into the coordinate vector; mask 8u = ConstOffset
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleCompareLevel_2D) {
+ auto* t = b.FunctionParam("t", ty.Get<type::DepthTexture>(type::TextureDimension::k2d));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* dref = b.FunctionParam("dref", ty.f32());
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({t, s, coords, dref});
+
+ b.With(func->Block(), [&] {
+ auto* result =
+ b.Call(ty.f32(), builtin::Function::kTextureSampleCompareLevel, t, s, coords, dref);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_depth_2d, %s:sampler, %coords:vec2<f32>, %dref:f32):f32 -> %b1 {
+ %b1 = block {
+ %6:f32 = textureSampleCompareLevel %t, %s, %coords, %dref
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_depth_2d, %s:sampler, %coords:vec2<f32>, %dref:f32):f32 -> %b1 {
+ %b1 = block {
+ %6:spirv.sampled_image = spirv.sampled_image %t, %s
+ %7:f32 = spirv.image_sample_dref_implicit_lod %6, %coords, %dref, 2u, 0.0f
+ ret %7
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // mask 2u = Lod, with lod pinned to 0.0f for CompareLevel; NOTE(review): SPIR-V requires the *explicit*_lod form with a Lod operand — confirm the writer remaps this
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleCompareLevel_2D_Offset) {
+ auto* t = b.FunctionParam("t", ty.Get<type::DepthTexture>(type::TextureDimension::k2d));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* dref = b.FunctionParam("dref", ty.f32());
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({t, s, coords, dref});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.f32(), builtin::Function::kTextureSampleCompareLevel, t, s, coords, dref,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_depth_2d, %s:sampler, %coords:vec2<f32>, %dref:f32):f32 -> %b1 {
+ %b1 = block {
+ %6:f32 = textureSampleCompareLevel %t, %s, %coords, %dref, vec2<i32>(1i)
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_depth_2d, %s:sampler, %coords:vec2<f32>, %dref:f32):f32 -> %b1 {
+ %b1 = block {
+ %6:spirv.sampled_image = spirv.sampled_image %t, %s
+ %7:f32 = spirv.image_sample_dref_implicit_lod %6, %coords, %dref, 10u, 0.0f, vec2<i32>(1i)
+ ret %7
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // mask 10u = Lod|ConstOffset; operand values follow mask-bit order (lod, then offset)
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleCompareLevel_2DArray_Offset) {
+ auto* t = b.FunctionParam("t", ty.Get<type::DepthTexture>(type::TextureDimension::k2dArray));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* array_idx = b.FunctionParam("array_idx", ty.i32());
+ auto* bias = b.FunctionParam("bias", ty.f32()); // NOTE(review): depth-reference operand; "bias" name appears copy-pasted
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({t, s, coords, array_idx, bias});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.f32(), builtin::Function::kTextureSampleCompareLevel, t, s, coords, array_idx, bias,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_depth_2d_array, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %bias:f32):f32 -> %b1 {
+ %b1 = block {
+ %7:f32 = textureSampleCompareLevel %t, %s, %coords, %array_idx, %bias, vec2<i32>(1i)
+ ret %7
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_depth_2d_array, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %bias:f32):f32 -> %b1 {
+ %b1 = block {
+ %7:spirv.sampled_image = spirv.sampled_image %t, %s
+ %8:f32 = convert %array_idx
+ %9:vec3<f32> = construct %coords, %8
+ %10:f32 = spirv.image_sample_dref_implicit_lod %7, %9, %bias, 10u, 0.0f, vec2<i32>(1i)
+ ret %10
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // array index folded into coords; mask 10u = Lod|ConstOffset
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleGrad_2D) {
+ auto* t =
+ b.FunctionParam("t", ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* ddx = b.FunctionParam("ddx", ty.vec2<f32>());
+ auto* ddy = b.FunctionParam("ddy", ty.vec2<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, ddx, ddy});
+
+ b.With(func->Block(), [&] {
+ auto* result =
+ b.Call(ty.vec4<f32>(), builtin::Function::kTextureSampleGrad, t, s, coords, ddx, ddy); // fixed: was kTextureSampleBias (copy-paste; test name says Grad)
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %ddx:vec2<f32>, %ddy:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %7:vec4<f32> = textureSampleGrad %t, %s, %coords, %ddx, %ddy
+ ret %7
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %ddx:vec2<f32>, %ddy:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %7:spirv.sampled_image = spirv.sampled_image %t, %s
+ %8:vec4<f32> = spirv.image_sample_explicit_lod %7, %coords, 4u, %ddx, %ddy
+ ret %8
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // mask 4u = Grad(ddx, ddy); Grad requires the explicit-lod sample form in SPIR-V
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleGrad_2D_Offset) {
+ auto* t =
+ b.FunctionParam("t", ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* ddx = b.FunctionParam("ddx", ty.vec2<f32>());
+ auto* ddy = b.FunctionParam("ddy", ty.vec2<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, ddx, ddy});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.vec4<f32>(), builtin::Function::kTextureSampleGrad, t, s, coords, ddx, ddy, // fixed: was kTextureSampleBias (copy-paste; test name says Grad)
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %ddx:vec2<f32>, %ddy:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %7:vec4<f32> = textureSampleGrad %t, %s, %coords, %ddx, %ddy, vec2<i32>(1i)
+ ret %7
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %ddx:vec2<f32>, %ddy:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %7:spirv.sampled_image = spirv.sampled_image %t, %s
+ %8:vec4<f32> = spirv.image_sample_explicit_lod %7, %coords, 12u, %ddx, %ddy, vec2<i32>(1i)
+ ret %8
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // mask 12u = Grad|ConstOffset; the original expect dropped the offset entirely
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleGrad_2DArray_Offset) {
+ auto* t = b.FunctionParam(
+ "t", ty.Get<type::SampledTexture>(type::TextureDimension::k2dArray, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* array_idx = b.FunctionParam("array_idx", ty.i32());
+ auto* ddx = b.FunctionParam("ddx", ty.vec2<f32>());
+ auto* ddy = b.FunctionParam("ddy", ty.vec2<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, array_idx, ddx, ddy});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.vec4<f32>(), builtin::Function::kTextureSampleGrad, t, s, coords, array_idx, ddx, // fixed: was kTextureSampleBias (copy-paste; test name says Grad)
+ ddy,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d_array<f32>, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %ddx:vec2<f32>, %ddy:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %8:vec4<f32> = textureSampleGrad %t, %s, %coords, %array_idx, %ddx, %ddy, vec2<i32>(1i)
+ ret %8
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d_array<f32>, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %ddx:vec2<f32>, %ddy:vec2<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %8:spirv.sampled_image = spirv.sampled_image %t, %s
+ %9:f32 = convert %array_idx
+ %10:vec3<f32> = construct %coords, %9
+ %11:vec4<f32> = spirv.image_sample_explicit_lod %8, %10, 12u, %ddx, %ddy, vec2<i32>(1i)
+ ret %11
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // array index folded into coords; mask 12u = Grad|ConstOffset, explicit-lod form
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleLevel_2D) {
+ auto* t =
+ b.FunctionParam("t", ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* lod = b.FunctionParam("lod", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, lod});
+
+ b.With(func->Block(), [&] {
+ auto* result =
+ b.Call(ty.vec4<f32>(), builtin::Function::kTextureSampleLevel, t, s, coords, lod);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %lod:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:vec4<f32> = textureSampleLevel %t, %s, %coords, %lod
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %lod:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:spirv.sampled_image = spirv.sampled_image %t, %s
+ %7:vec4<f32> = spirv.image_sample_explicit_lod %6, %coords, 2u, %lod
+ ret %7
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // mask 2u = Lod; explicit level selection maps to the explicit-lod sample form
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleLevel_2D_Offset) {
+ auto* t =
+ b.FunctionParam("t", ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* lod = b.FunctionParam("lod", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, lod});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.vec4<f32>(), builtin::Function::kTextureSampleLevel, t, s, coords, lod,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %lod:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:vec4<f32> = textureSampleLevel %t, %s, %coords, %lod, vec2<i32>(1i)
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d<f32>, %s:sampler, %coords:vec2<f32>, %lod:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %6:spirv.sampled_image = spirv.sampled_image %t, %s
+ %7:vec4<f32> = spirv.image_sample_explicit_lod %6, %coords, 10u, %lod, vec2<i32>(1i)
+ ret %7
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // mask 10u = Lod|ConstOffset; operand values follow mask-bit order
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, TextureSampleLevel_2DArray_Offset) {
+ auto* t = b.FunctionParam(
+ "t", ty.Get<type::SampledTexture>(type::TextureDimension::k2dArray, ty.f32()));
+ auto* s = b.FunctionParam("s", ty.sampler());
+ auto* coords = b.FunctionParam("coords", ty.vec2<f32>());
+ auto* array_idx = b.FunctionParam("array_idx", ty.i32());
+ auto* lod = b.FunctionParam("lod", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({t, s, coords, array_idx, lod});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(
+ ty.vec4<f32>(), builtin::Function::kTextureSampleLevel, t, s, coords, array_idx, lod,
+ b.Constant(mod.constant_values.Splat(ty.vec2<i32>(), mod.constant_values.Get(1_i), 2)));
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%t:texture_2d_array<f32>, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %lod:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %7:vec4<f32> = textureSampleLevel %t, %s, %coords, %array_idx, %lod, vec2<i32>(1i)
+ ret %7
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%t:texture_2d_array<f32>, %s:sampler, %coords:vec2<f32>, %array_idx:i32, %lod:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %7:spirv.sampled_image = spirv.sampled_image %t, %s
+ %8:f32 = convert %array_idx
+ %9:vec3<f32> = construct %coords, %8
+ %10:vec4<f32> = spirv.image_sample_explicit_lod %7, %9, 10u, %lod, vec2<i32>(1i)
+ ret %10
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>(); // array index converted to f32 and folded into the coordinate vector
+
+ EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/expand_implicit_splats.cc b/src/tint/ir/transform/expand_implicit_splats.cc
new file mode 100644
index 0000000..79b43cf
--- /dev/null
+++ b/src/tint/ir/transform/expand_implicit_splats.cc
@@ -0,0 +1,132 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/expand_implicit_splats.h"
+
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::ExpandImplicitSplats);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ir::transform {
+
+ExpandImplicitSplats::ExpandImplicitSplats() = default;
+
+ExpandImplicitSplats::~ExpandImplicitSplats() = default;
+
+void ExpandImplicitSplats::Run(ir::Module* ir, const DataMap&, DataMap&) const {
+ ir::Builder b(*ir);
+
+ // Find the instructions that use implicit splats and either modify them in place or record them
+ // to be replaced in a second pass.
+ utils::Vector<Binary*, 4> binary_worklist;
+ utils::Vector<CoreBuiltinCall*, 4> builtin_worklist;
+ for (auto* inst : ir->instructions.Objects()) {
+ if (!inst->Alive()) {
+ continue; // skip instructions that have already been destroyed
+ }
+ if (auto* construct = inst->As<Construct>()) {
+ // A vector constructor with a single scalar argument needs to be modified to replicate
+ // the argument N times.
+ auto* vec = construct->Result()->Type()->As<type::Vector>();
+ if (vec && //
+ construct->Args().Length() == 1 &&
+ construct->Args()[0]->Type()->Is<type::Scalar>()) {
+ for (uint32_t i = 1; i < vec->Width(); i++) { // arg 0 is already in place
+ construct->AppendArg(construct->Args()[0]);
+ }
+ }
+ } else if (auto* binary = inst->As<Binary>()) {
+ // A binary instruction that mixes vector and scalar operands needs to have the scalar
+ // operand replaced with an explicit vector constructor.
+ if (binary->Result()->Type()->Is<type::Vector>()) {
+ if (binary->LHS()->Type()->Is<type::Scalar>() ||
+ binary->RHS()->Type()->Is<type::Scalar>()) {
+ binary_worklist.Push(binary);
+ }
+ }
+ } else if (auto* builtin = inst->As<CoreBuiltinCall>()) {
+ // A mix builtin call that mixes vector and scalar operands needs to have the scalar
+ // operand replaced with an explicit vector constructor.
+ if (builtin->Func() == builtin::Function::kMix) {
+ if (builtin->Result()->Type()->Is<type::Vector>()) {
+ if (builtin->Args()[2]->Type()->Is<type::Scalar>()) { // arg 2 is mix's factor operand
+ builtin_worklist.Push(builtin);
+ }
+ }
+ }
+ }
+ }
+
+ // Helper to expand a scalar operand of an instruction by replacing it with an explicitly
+ // constructed vector that matches the result type.
+ auto expand_operand = [&](Instruction* inst, size_t operand_idx) {
+ auto* vec = inst->Result()->Type()->As<type::Vector>();
+
+ utils::Vector<Value*, 4> args;
+ args.Resize(vec->Width(), inst->Operands()[operand_idx]); // Width() copies of the scalar
+
+ auto* construct = b.Construct(vec, std::move(args));
+ construct->InsertBefore(inst);
+ inst->SetOperand(operand_idx, construct->Result());
+ };
+
+ // Replace scalar operands to binary instructions that produce vectors.
+ for (auto* binary : binary_worklist) {
+ auto* result_ty = binary->Result()->Type();
+ if (result_ty->is_float_vector() && binary->Kind() == Binary::Kind::kMultiply) {
+ // Use OpVectorTimesScalar for floating point multiply.
+ auto* vts = b.Call(result_ty, IntrinsicCall::Kind::kSpirvVectorTimesScalar); // args appended below: vector first, then scalar
+ if (binary->LHS()->Type()->Is<type::Scalar>()) {
+ vts->AppendArg(binary->RHS());
+ vts->AppendArg(binary->LHS());
+ } else {
+ vts->AppendArg(binary->LHS());
+ vts->AppendArg(binary->RHS());
+ }
+ if (auto name = ir->NameOf(binary)) {
+ ir->SetName(vts->Result(), name); // preserve any debug name on the replaced result
+ }
+ binary->Result()->ReplaceAllUsesWith(vts->Result());
+ binary->ReplaceWith(vts);
+ binary->Destroy();
+ } else {
+ // Expand the scalar argument into an explicitly constructed vector.
+ if (binary->LHS()->Type()->Is<type::Scalar>()) {
+ expand_operand(binary, Binary::kLhsOperandOffset);
+ } else if (binary->RHS()->Type()->Is<type::Scalar>()) {
+ expand_operand(binary, Binary::kRhsOperandOffset);
+ }
+ }
+ }
+
+ // Replace scalar arguments to builtin calls that produce vectors.
+ for (auto* builtin : builtin_worklist) {
+ switch (builtin->Func()) {
+ case builtin::Function::kMix:
+ // Expand the scalar argument into an explicitly constructed vector.
+ expand_operand(builtin, CoreBuiltinCall::kArgsOperandOffset + 2);
+ break;
+ default:
+ TINT_ASSERT(Transform, false && "unhandled builtin call"); // worklist only ever holds kMix today
+ break;
+ }
+ }
+}
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/expand_implicit_splats.h b/src/tint/ir/transform/expand_implicit_splats.h
new file mode 100644
index 0000000..ec1ae18
--- /dev/null
+++ b/src/tint/ir/transform/expand_implicit_splats.h
@@ -0,0 +1,37 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_EXPAND_IMPLICIT_SPLATS_H_
+#define SRC_TINT_IR_TRANSFORM_EXPAND_IMPLICIT_SPLATS_H_
+
+#include "src/tint/ir/transform/transform.h"
+
+namespace tint::ir::transform {
+
+/// ExpandImplicitSplats is a transform that expands implicit vector-splat operands of construct,
+/// binary, and builtin-call (e.g. `mix`) instructions into explicit forms representable in SPIR-V.
+class ExpandImplicitSplats final : public utils::Castable<ExpandImplicitSplats, Transform> {
+ public:
+ /// Constructor
+ ExpandImplicitSplats();
+ /// Destructor
+ ~ExpandImplicitSplats() override;
+
+ /// @copydoc Transform::Run
+ void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_EXPAND_IMPLICIT_SPLATS_H_
diff --git a/src/tint/ir/transform/expand_implicit_splats_test.cc b/src/tint/ir/transform/expand_implicit_splats_test.cc
new file mode 100644
index 0000000..2e36c98
--- /dev/null
+++ b/src/tint/ir/transform/expand_implicit_splats_test.cc
@@ -0,0 +1,672 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/expand_implicit_splats.h"
+
+#include <utility>
+
+#include "src/tint/ir/transform/test_helper.h"
+
+namespace tint::ir::transform {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+using IR_ExpandImplicitSplatsTest = TransformTest;
+
+TEST_F(IR_ExpandImplicitSplatsTest, NoModify_Construct_VectorIdentity) {
+ auto* vector = b.FunctionParam("vector", ty.vec2<i32>());
+ auto* func = b.Function("foo", ty.vec2<i32>());
+ func->SetParams({vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(ty.vec2<i32>(), vector);
+ b.Return(func, result);
+ });
+
+ auto* expect = R"(
+%foo = func(%vector:vec2<i32>):vec2<i32> -> %b1 {
+ %b1 = block {
+ %3:vec2<i32> = construct %vector
+ ret %3
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>(); // single vector arg is not a scalar splat: IR must be unchanged
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, NoModify_Construct_MixedScalarVector) {
+ auto* scalar = b.FunctionParam("scalar", ty.i32());
+ auto* vector = b.FunctionParam("vector", ty.vec2<i32>());
+ auto* func = b.Function("foo", ty.vec3<i32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(ty.vec3<i32>(), scalar, vector);
+ b.Return(func, result);
+ });
+
+ auto* expect = R"(
+%foo = func(%scalar:i32, %vector:vec2<i32>):vec3<i32> -> %b1 {
+ %b1 = block {
+ %4:vec3<i32> = construct %scalar, %vector
+ ret %4
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>(); // two args, so not a splat: IR must be unchanged
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, NoModify_Construct_AllScalars) {
+ auto* scalar = b.FunctionParam("scalar", ty.i32());
+ auto* func = b.Function("foo", ty.vec3<i32>());
+ func->SetParams({scalar});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(ty.vec3<i32>(), scalar, scalar, scalar);
+ b.Return(func, result);
+ });
+
+ auto* expect = R"(
+%foo = func(%scalar:i32):vec3<i32> -> %b1 {
+ %b1 = block {
+ %3:vec3<i32> = construct %scalar, %scalar, %scalar
+ ret %3
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>(); // already fully expanded: IR must be unchanged
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, Construct_Splat_Vec2i) {
+ auto* scalar = b.FunctionParam("scalar", ty.i32());
+ auto* func = b.Function("foo", ty.vec2<i32>());
+ func->SetParams({scalar});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(ty.vec2<i32>(), scalar);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:i32):vec2<i32> -> %b1 {
+ %b1 = block {
+ %3:vec2<i32> = construct %scalar
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:i32):vec2<i32> -> %b1 {
+ %b1 = block {
+ %3:vec2<i32> = construct %scalar, %scalar
+ ret %3
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>(); // single scalar arg is replicated in place to the vector width
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, Construct_Splat_Vec3u) {
+ auto* scalar = b.FunctionParam("scalar", ty.u32());
+ auto* func = b.Function("foo", ty.vec3<u32>());
+ func->SetParams({scalar});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(ty.vec3<u32>(), scalar);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:u32):vec3<u32> -> %b1 {
+ %b1 = block {
+ %3:vec3<u32> = construct %scalar
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:u32):vec3<u32> -> %b1 {
+ %b1 = block {
+ %3:vec3<u32> = construct %scalar, %scalar, %scalar
+ ret %3
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, Construct_Splat_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(ty.vec4<f32>(), scalar);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %3:vec4<f32> = construct %scalar
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %3:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+ ret %3
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryAdd_VectorScalar_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Add(ty.vec4<f32>(), vector, scalar);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = add %vector, %scalar
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<f32> = add %vector, %4
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>(); // non-multiply ops: the scalar operand is replaced by an explicit construct
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryAdd_ScalarVector_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Add(ty.vec4<f32>(), scalar, vector);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = add %scalar, %vector
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<f32> = add %4, %vector
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>(); // scalar on the LHS: operand order is preserved after expansion
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinarySubtract_VectorScalar_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Subtract(ty.vec4<f32>(), vector, scalar);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = sub %vector, %scalar
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<f32> = sub %vector, %4
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinarySubtract_ScalarVector_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Subtract(ty.vec4<f32>(), scalar, vector);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = sub %scalar, %vector
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<f32> = sub %4, %vector
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryDivide_VectorScalar_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Divide(ty.vec4<f32>(), vector, scalar);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = div %vector, %scalar
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<f32> = div %vector, %4
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>(); // div has no vector*scalar SPIR-V form, so the scalar is splatted
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryDivide_ScalarVector_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Divide(ty.vec4<f32>(), scalar, vector);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = div %scalar, %vector
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<f32> = div %4, %vector
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryModulo_VectorScalar_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Modulo(ty.vec4<f32>(), vector, scalar);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = mod %vector, %scalar
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<f32> = mod %vector, %4
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryModulo_ScalarVector_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Modulo(ty.vec4<f32>(), scalar, vector);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = mod %scalar, %vector
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<f32> = mod %4, %vector
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryMultiply_VectorScalar_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec4<f32>(), vector, scalar);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = mul %vector, %scalar
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = spirv.vector_times_scalar %vector, %scalar
+ ret %4
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryMultiply_ScalarVector_Vec4f) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec4<f32>(), scalar, vector);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = mul %scalar, %vector
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = spirv.vector_times_scalar %vector, %scalar
+ ret %4
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryMultiply_VectorScalar_Vec4i) {
+ auto* scalar = b.FunctionParam("scalar", ty.i32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<i32>());
+ auto* func = b.Function("foo", ty.vec4<i32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec4<i32>(), vector, scalar);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:i32, %vector:vec4<i32>):vec4<i32> -> %b1 {
+ %b1 = block {
+ %4:vec4<i32> = mul %vector, %scalar
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:i32, %vector:vec4<i32>):vec4<i32> -> %b1 {
+ %b1 = block {
+ %4:vec4<i32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<i32> = mul %vector, %4
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryMultiply_ScalarVector_Vec4i) {
+ auto* scalar = b.FunctionParam("scalar", ty.i32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<i32>());
+ auto* func = b.Function("foo", ty.vec4<i32>());
+ func->SetParams({scalar, vector});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec4<i32>(), scalar, vector);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%scalar:i32, %vector:vec4<i32>):vec4<i32> -> %b1 {
+ %b1 = block {
+ %4:vec4<i32> = mul %scalar, %vector
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%scalar:i32, %vector:vec4<i32>):vec4<i32> -> %b1 {
+ %b1 = block {
+ %4:vec4<i32> = construct %scalar, %scalar, %scalar, %scalar
+ %5:vec4<i32> = mul %4, %vector
+ ret %5
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, Mix_VectorOperands_ScalarFactor) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec4<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec4<f32>());
+ auto* factor = b.FunctionParam("factor", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({arg1, arg2, factor});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.vec4<f32>(), builtin::Function::kMix, arg1, arg2, factor);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:vec4<f32>, %arg2:vec4<f32>, %factor:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %5:vec4<f32> = mix %arg1, %arg2, %factor
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:vec4<f32>, %arg2:vec4<f32>, %factor:f32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %5:vec4<f32> = construct %factor, %factor, %factor, %factor
+ %6:vec4<f32> = mix %arg1, %arg2, %5
+ ret %6
+ }
+}
+)";
+
+ Run<ExpandImplicitSplats>();
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/handle_matrix_arithmetic.cc b/src/tint/ir/transform/handle_matrix_arithmetic.cc
new file mode 100644
index 0000000..6e91153
--- /dev/null
+++ b/src/tint/ir/transform/handle_matrix_arithmetic.cc
@@ -0,0 +1,118 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/handle_matrix_arithmetic.h"
+
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+#include "src/tint/type/matrix.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::HandleMatrixArithmetic);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ir::transform {
+
+HandleMatrixArithmetic::HandleMatrixArithmetic() = default;
+
+HandleMatrixArithmetic::~HandleMatrixArithmetic() = default;
+
+void HandleMatrixArithmetic::Run(ir::Module* ir, const DataMap&, DataMap&) const {
+ ir::Builder b(*ir);
+
+    // Find the instructions that need to be modified.
+ utils::Vector<Binary*, 4> worklist;
+ for (auto* inst : ir->instructions.Objects()) {
+ if (!inst->Alive()) {
+ continue;
+ }
+ if (auto* binary = inst->As<Binary>()) {
+ TINT_ASSERT(Transform, binary->Operands().Length() == 2);
+ if (binary->LHS()->Type()->Is<type::Matrix>() ||
+ binary->RHS()->Type()->Is<type::Matrix>()) {
+ worklist.Push(binary);
+ }
+ }
+ }
+
+ // Replace the matrix arithmetic instructions that we found.
+ for (auto* binary : worklist) {
+ auto* lhs = binary->LHS();
+ auto* rhs = binary->RHS();
+ auto* lhs_ty = lhs->Type();
+ auto* rhs_ty = rhs->Type();
+ auto* ty = binary->Result()->Type();
+
+ // Helper to replace the instruction with a new one.
+ auto replace = [&](Instruction* inst) {
+ if (auto name = ir->NameOf(binary)) {
+ ir->SetName(inst->Result(), name);
+ }
+ binary->Result()->ReplaceAllUsesWith(inst->Result());
+ binary->ReplaceWith(inst);
+ binary->Destroy();
+ };
+
+ // Helper to replace the instruction with a column-wise operation.
+ auto column_wise = [&](enum Binary::Kind op) {
+ auto* mat = ty->As<type::Matrix>();
+ utils::Vector<Value*, 4> args;
+ for (uint32_t col = 0; col < mat->columns(); col++) {
+ auto* lhs_col = b.Access(mat->ColumnType(), lhs, u32(col));
+ lhs_col->InsertBefore(binary);
+ auto* rhs_col = b.Access(mat->ColumnType(), rhs, u32(col));
+ rhs_col->InsertBefore(binary);
+ auto* add = b.Binary(op, mat->ColumnType(), lhs_col, rhs_col);
+ add->InsertBefore(binary);
+ args.Push(add->Result());
+ }
+ replace(b.Construct(ty, std::move(args)));
+ };
+
+ switch (binary->Kind()) {
+ case Binary::Kind::kAdd:
+ column_wise(Binary::Kind::kAdd);
+ break;
+ case Binary::Kind::kSubtract:
+ column_wise(Binary::Kind::kSubtract);
+ break;
+ case Binary::Kind::kMultiply:
+ // Select the SPIR-V intrinsic that corresponds to the operation being performed.
+ if (lhs_ty->Is<type::Matrix>()) {
+ if (rhs_ty->Is<type::Scalar>()) {
+ replace(b.Call(ty, IntrinsicCall::Kind::kSpirvMatrixTimesScalar, lhs, rhs));
+ } else if (rhs_ty->Is<type::Vector>()) {
+ replace(b.Call(ty, IntrinsicCall::Kind::kSpirvMatrixTimesVector, lhs, rhs));
+ } else if (rhs_ty->Is<type::Matrix>()) {
+ replace(b.Call(ty, IntrinsicCall::Kind::kSpirvMatrixTimesMatrix, lhs, rhs));
+ }
+ } else {
+ if (lhs_ty->Is<type::Scalar>()) {
+ replace(b.Call(ty, IntrinsicCall::Kind::kSpirvMatrixTimesScalar, rhs, lhs));
+ } else if (lhs_ty->Is<type::Vector>()) {
+ replace(b.Call(ty, IntrinsicCall::Kind::kSpirvVectorTimesMatrix, lhs, rhs));
+ }
+ }
+ break;
+
+ default:
+ TINT_ASSERT(Transform, false && "unhandled matrix arithmetic instruction");
+ break;
+ }
+ }
+}
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/handle_matrix_arithmetic.h b/src/tint/ir/transform/handle_matrix_arithmetic.h
new file mode 100644
index 0000000..3129dd3
--- /dev/null
+++ b/src/tint/ir/transform/handle_matrix_arithmetic.h
@@ -0,0 +1,37 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_HANDLE_MATRIX_ARITHMETIC_H_
+#define SRC_TINT_IR_TRANSFORM_HANDLE_MATRIX_ARITHMETIC_H_
+
+#include "src/tint/ir/transform/transform.h"
+
+namespace tint::ir::transform {
+
+/// HandleMatrixArithmetic is a transform that converts arithmetic instructions that use matrices
+/// into SPIR-V intrinsics or polyfills.
+class HandleMatrixArithmetic final : public utils::Castable<HandleMatrixArithmetic, Transform> {
+ public:
+ /// Constructor
+ HandleMatrixArithmetic();
+ /// Destructor
+ ~HandleMatrixArithmetic() override;
+
+ /// @copydoc Transform::Run
+ void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_HANDLE_MATRIX_ARITHMETIC_H_
diff --git a/src/tint/ir/transform/handle_matrix_arithmetic_test.cc b/src/tint/ir/transform/handle_matrix_arithmetic_test.cc
new file mode 100644
index 0000000..759fdab
--- /dev/null
+++ b/src/tint/ir/transform/handle_matrix_arithmetic_test.cc
@@ -0,0 +1,414 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/handle_matrix_arithmetic.h"
+
+#include <utility>
+
+#include "src/tint/ir/transform/test_helper.h"
+#include "src/tint/type/matrix.h"
+
+namespace tint::ir::transform {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+using IR_HandleMatrixArithmeticTest = TransformTest;
+
+TEST_F(IR_HandleMatrixArithmeticTest, Add_Mat2x3f) {
+ auto* arg1 = b.FunctionParam("arg1", ty.mat2x3<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.mat2x3<f32>());
+ auto* func = b.Function("foo", ty.mat2x3<f32>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Add(ty.mat2x3<f32>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:mat2x3<f32>, %arg2:mat2x3<f32>):mat2x3<f32> -> %b1 {
+ %b1 = block {
+ %4:mat2x3<f32> = add %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:mat2x3<f32>, %arg2:mat2x3<f32>):mat2x3<f32> -> %b1 {
+ %b1 = block {
+ %4:vec3<f32> = access %arg1, 0u
+ %5:vec3<f32> = access %arg2, 0u
+ %6:vec3<f32> = add %4, %5
+ %7:vec3<f32> = access %arg1, 1u
+ %8:vec3<f32> = access %arg2, 1u
+ %9:vec3<f32> = add %7, %8
+ %10:mat2x3<f32> = construct %6, %9
+ ret %10
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_HandleMatrixArithmeticTest, Add_Mat4x2h) {
+ auto* arg1 = b.FunctionParam("arg1", ty.mat4x2<f16>());
+ auto* arg2 = b.FunctionParam("arg2", ty.mat4x2<f16>());
+ auto* func = b.Function("foo", ty.mat4x2<f16>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Add(ty.mat4x2<f16>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:mat4x2<f16>, %arg2:mat4x2<f16>):mat4x2<f16> -> %b1 {
+ %b1 = block {
+ %4:mat4x2<f16> = add %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:mat4x2<f16>, %arg2:mat4x2<f16>):mat4x2<f16> -> %b1 {
+ %b1 = block {
+ %4:vec2<f16> = access %arg1, 0u
+ %5:vec2<f16> = access %arg2, 0u
+ %6:vec2<f16> = add %4, %5
+ %7:vec2<f16> = access %arg1, 1u
+ %8:vec2<f16> = access %arg2, 1u
+ %9:vec2<f16> = add %7, %8
+ %10:vec2<f16> = access %arg1, 2u
+ %11:vec2<f16> = access %arg2, 2u
+ %12:vec2<f16> = add %10, %11
+ %13:vec2<f16> = access %arg1, 3u
+ %14:vec2<f16> = access %arg2, 3u
+ %15:vec2<f16> = add %13, %14
+ %16:mat4x2<f16> = construct %6, %9, %12, %15
+ ret %16
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_HandleMatrixArithmeticTest, Subtract_Mat3x2f) {
+ auto* arg1 = b.FunctionParam("arg1", ty.mat3x2<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.mat3x2<f32>());
+ auto* func = b.Function("foo", ty.mat3x2<f32>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Subtract(ty.mat3x2<f32>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:mat3x2<f32>, %arg2:mat3x2<f32>):mat3x2<f32> -> %b1 {
+ %b1 = block {
+ %4:mat3x2<f32> = sub %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:mat3x2<f32>, %arg2:mat3x2<f32>):mat3x2<f32> -> %b1 {
+ %b1 = block {
+ %4:vec2<f32> = access %arg1, 0u
+ %5:vec2<f32> = access %arg2, 0u
+ %6:vec2<f32> = sub %4, %5
+ %7:vec2<f32> = access %arg1, 1u
+ %8:vec2<f32> = access %arg2, 1u
+ %9:vec2<f32> = sub %7, %8
+ %10:vec2<f32> = access %arg1, 2u
+ %11:vec2<f32> = access %arg2, 2u
+ %12:vec2<f32> = sub %10, %11
+ %13:mat3x2<f32> = construct %6, %9, %12
+ ret %13
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_HandleMatrixArithmeticTest, Subtract_Mat2x4h) {
+ auto* arg1 = b.FunctionParam("arg1", ty.mat2x4<f16>());
+ auto* arg2 = b.FunctionParam("arg2", ty.mat2x4<f16>());
+ auto* func = b.Function("foo", ty.mat2x4<f16>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Subtract(ty.mat2x4<f16>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:mat2x4<f16>, %arg2:mat2x4<f16>):mat2x4<f16> -> %b1 {
+ %b1 = block {
+ %4:mat2x4<f16> = sub %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:mat2x4<f16>, %arg2:mat2x4<f16>):mat2x4<f16> -> %b1 {
+ %b1 = block {
+ %4:vec4<f16> = access %arg1, 0u
+ %5:vec4<f16> = access %arg2, 0u
+ %6:vec4<f16> = sub %4, %5
+ %7:vec4<f16> = access %arg1, 1u
+ %8:vec4<f16> = access %arg2, 1u
+ %9:vec4<f16> = sub %7, %8
+ %10:mat2x4<f16> = construct %6, %9
+ ret %10
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_HandleMatrixArithmeticTest, Mul_Mat2x3f_Scalar) {
+ auto* arg1 = b.FunctionParam("arg1", ty.mat2x3<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.f32());
+ auto* func = b.Function("foo", ty.mat2x3<f32>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.mat2x3<f32>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:mat2x3<f32>, %arg2:f32):mat2x3<f32> -> %b1 {
+ %b1 = block {
+ %4:mat2x3<f32> = mul %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:mat2x3<f32>, %arg2:f32):mat2x3<f32> -> %b1 {
+ %b1 = block {
+ %4:mat2x3<f32> = spirv.matrix_times_scalar %arg1, %arg2
+ ret %4
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_HandleMatrixArithmeticTest, Mul_Mat3x4f_Vector) {
+ auto* arg1 = b.FunctionParam("arg1", ty.mat3x4<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec3<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec4<f32>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:mat3x4<f32>, %arg2:vec3<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = mul %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:mat3x4<f32>, %arg2:vec3<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %4:vec4<f32> = spirv.matrix_times_vector %arg1, %arg2
+ ret %4
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_HandleMatrixArithmeticTest, Mul_Mat4x2f_Mat2x4) {
+ auto* arg1 = b.FunctionParam("arg1", ty.mat4x2<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.mat2x4<f32>());
+ auto* func = b.Function("foo", ty.mat2x2<f32>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.mat2x2<f32>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:mat4x2<f32>, %arg2:mat2x4<f32>):mat2x2<f32> -> %b1 {
+ %b1 = block {
+ %4:mat2x2<f32> = mul %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:mat4x2<f32>, %arg2:mat2x4<f32>):mat2x2<f32> -> %b1 {
+ %b1 = block {
+ %4:mat2x2<f32> = spirv.matrix_times_matrix %arg1, %arg2
+ ret %4
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_HandleMatrixArithmeticTest, Mul_Scalar_Mat3x2h) {
+ auto* arg1 = b.FunctionParam("arg1", ty.f16());
+ auto* arg2 = b.FunctionParam("arg2", ty.mat3x2<f16>());
+ auto* func = b.Function("foo", ty.mat3x2<f16>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.mat3x2<f16>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:f16, %arg2:mat3x2<f16>):mat3x2<f16> -> %b1 {
+ %b1 = block {
+ %4:mat3x2<f16> = mul %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:f16, %arg2:mat3x2<f16>):mat3x2<f16> -> %b1 {
+ %b1 = block {
+ %4:mat3x2<f16> = spirv.matrix_times_scalar %arg2, %arg1
+ ret %4
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_HandleMatrixArithmeticTest, Mul_Vector_Mat3x4f) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec3<f16>());
+ auto* arg2 = b.FunctionParam("arg2", ty.mat4x3<f16>());
+ auto* func = b.Function("foo", ty.vec4<f16>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec4<f16>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:vec3<f16>, %arg2:mat4x3<f16>):vec4<f16> -> %b1 {
+ %b1 = block {
+ %4:vec4<f16> = mul %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:vec3<f16>, %arg2:mat4x3<f16>):vec4<f16> -> %b1 {
+ %b1 = block {
+ %4:vec4<f16> = spirv.vector_times_matrix %arg1, %arg2
+ ret %4
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_HandleMatrixArithmeticTest, Mul_Mat3x3f_Mat3x3) {
+ auto* arg1 = b.FunctionParam("arg1", ty.mat3x3<f16>());
+ auto* arg2 = b.FunctionParam("arg2", ty.mat3x3<f16>());
+ auto* func = b.Function("foo", ty.mat3x3<f16>());
+ func->SetParams({arg1, arg2});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.mat3x3<f16>(), arg1, arg2);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg1:mat3x3<f16>, %arg2:mat3x3<f16>):mat3x3<f16> -> %b1 {
+ %b1 = block {
+ %4:mat3x3<f16> = mul %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arg1:mat3x3<f16>, %arg2:mat3x3<f16>):mat3x3<f16> -> %b1 {
+ %b1 = block {
+ %4:mat3x3<f16> = spirv.matrix_times_matrix %arg1, %arg2
+ ret %4
+ }
+}
+)";
+
+ Run<HandleMatrixArithmetic>();
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/validator.cc b/src/tint/ir/validator.cc
index decc98f..86c8bc9 100644
--- a/src/tint/ir/validator.cc
+++ b/src/tint/ir/validator.cc
@@ -58,570 +58,552 @@
#include "src/tint/utils/scoped_assignment.h"
namespace tint::ir {
-namespace {
-class Validator {
- public:
- explicit Validator(Module& mod) : mod_(mod) {}
+Validator::Validator(Module& mod) : mod_(mod) {}
- ~Validator() {}
+Validator::~Validator() = default;
- utils::Result<Success, diag::List> IsValid() {
- CheckRootBlock(mod_.root_block);
+void Validator::DisassembleIfNeeded() {
+ if (mod_.disassembly_file) {
+ return;
+ }
+ mod_.disassembly_file = std::make_unique<Source::File>("", dis_.Disassemble());
+}
- for (auto* func : mod_.functions) {
- CheckFunction(func);
- }
+utils::Result<Success, diag::List> Validator::IsValid() {
+ CheckRootBlock(mod_.root_block);
- if (diagnostics_.contains_errors()) {
- DisassembleIfNeeded();
- diagnostics_.add_note(tint::diag::System::IR,
- "# Disassembly\n" + mod_.disassembly_file->content.data, {});
- return std::move(diagnostics_);
- }
- return Success{};
+ for (auto* func : mod_.functions) {
+ CheckFunction(func);
}
- private:
- Module& mod_;
- diag::List diagnostics_;
- Disassembler dis_{mod_};
-
- Block* current_block_ = nullptr;
- utils::Hashset<Function*, 4> seen_functions_;
- utils::Vector<ControlInstruction*, 8> control_stack_;
-
- void DisassembleIfNeeded() {
- if (mod_.disassembly_file) {
- return;
- }
- mod_.disassembly_file = std::make_unique<Source::File>("", dis_.Disassemble());
- }
-
- std::string InstError(Instruction* inst, std::string err) {
- return std::string(inst->FriendlyName()) + ": " + err;
- }
-
- void AddError(Instruction* inst, std::string err) {
+ if (diagnostics_.contains_errors()) {
DisassembleIfNeeded();
- auto src = dis_.InstructionSource(inst);
- src.file = mod_.disassembly_file.get();
- AddError(std::move(err), src);
+ diagnostics_.add_note(tint::diag::System::IR,
+ "# Disassembly\n" + mod_.disassembly_file->content.data, {});
+ return std::move(diagnostics_);
+ }
+ return Success{};
+}
- if (current_block_) {
- AddNote(current_block_, "In block");
- }
+std::string Validator::InstError(Instruction* inst, std::string err) {
+ return std::string(inst->FriendlyName()) + ": " + err;
+}
+
+void Validator::AddError(Instruction* inst, std::string err) {
+ DisassembleIfNeeded();
+ auto src = dis_.InstructionSource(inst);
+ src.file = mod_.disassembly_file.get();
+ AddError(std::move(err), src);
+
+ if (current_block_) {
+ AddNote(current_block_, "In block");
+ }
+}
+
+void Validator::AddError(Instruction* inst, size_t idx, std::string err) {
+ DisassembleIfNeeded();
+ auto src = dis_.OperandSource(Usage{inst, static_cast<uint32_t>(idx)});
+ src.file = mod_.disassembly_file.get();
+ AddError(std::move(err), src);
+
+ if (current_block_) {
+ AddNote(current_block_, "In block");
+ }
+}
+
+void Validator::AddResultError(Instruction* inst, size_t idx, std::string err) {
+ DisassembleIfNeeded();
+ auto src = dis_.ResultSource(Usage{inst, static_cast<uint32_t>(idx)});
+ src.file = mod_.disassembly_file.get();
+ AddError(std::move(err), src);
+
+ if (current_block_) {
+ AddNote(current_block_, "In block");
+ }
+}
+
+void Validator::AddError(Block* blk, std::string err) {
+ DisassembleIfNeeded();
+ auto src = dis_.BlockSource(blk);
+ src.file = mod_.disassembly_file.get();
+ AddError(std::move(err), src);
+}
+
+void Validator::AddNote(Instruction* inst, std::string err) {
+ DisassembleIfNeeded();
+ auto src = dis_.InstructionSource(inst);
+ src.file = mod_.disassembly_file.get();
+ AddNote(std::move(err), src);
+}
+
+void Validator::AddNote(Instruction* inst, size_t idx, std::string err) {
+ DisassembleIfNeeded();
+ auto src = dis_.OperandSource(Usage{inst, static_cast<uint32_t>(idx)});
+ src.file = mod_.disassembly_file.get();
+ AddNote(std::move(err), src);
+}
+
+void Validator::AddNote(Block* blk, std::string err) {
+ DisassembleIfNeeded();
+ auto src = dis_.BlockSource(blk);
+ src.file = mod_.disassembly_file.get();
+ AddNote(std::move(err), src);
+}
+
+void Validator::AddError(std::string err, Source src) {
+ diagnostics_.add_error(tint::diag::System::IR, std::move(err), src);
+}
+
+void Validator::AddNote(std::string note, Source src) {
+ diagnostics_.add_note(tint::diag::System::IR, std::move(note), src);
+}
+
+std::string Validator::Name(Value* v) {
+ return mod_.NameOf(v).Name();
+}
+
+void Validator::CheckOperandNotNull(ir::Instruction* inst, ir::Value* operand, size_t idx) {
+ if (operand == nullptr) {
+ AddError(inst, idx, InstError(inst, "operand is undefined"));
+ }
+}
+
+void Validator::CheckOperandsNotNull(ir::Instruction* inst,
+ size_t start_operand,
+ size_t end_operand) {
+ auto operands = inst->Operands();
+ for (size_t i = start_operand; i <= end_operand; i++) {
+ CheckOperandNotNull(inst, operands[i], i);
+ }
+}
+
+void Validator::CheckRootBlock(Block* blk) {
+ if (!blk) {
+ return;
}
- void AddError(Instruction* inst, size_t idx, std::string err) {
- DisassembleIfNeeded();
- auto src = dis_.OperandSource(Usage{inst, static_cast<uint32_t>(idx)});
- src.file = mod_.disassembly_file.get();
- AddError(std::move(err), src);
+ TINT_SCOPED_ASSIGNMENT(current_block_, blk);
- if (current_block_) {
- AddNote(current_block_, "In block");
+ for (auto* inst : *blk) {
+ auto* var = inst->As<ir::Var>();
+ if (!var) {
+ AddError(inst,
+ std::string("root block: invalid instruction: ") + inst->TypeInfo().name);
+ continue;
}
+ CheckInstruction(var);
+ }
+}
+
+void Validator::CheckFunction(Function* func) {
+ if (!seen_functions_.Add(func)) {
+ AddError("function '" + Name(func) + "' added to module multiple times");
}
- void AddResultError(Instruction* inst, size_t idx, std::string err) {
- DisassembleIfNeeded();
- auto src = dis_.ResultSource(Usage{inst, static_cast<uint32_t>(idx)});
- src.file = mod_.disassembly_file.get();
- AddError(std::move(err), src);
+ CheckBlock(func->Block());
+}
- if (current_block_) {
- AddNote(current_block_, "In block");
- }
+void Validator::CheckBlock(Block* blk) {
+ TINT_SCOPED_ASSIGNMENT(current_block_, blk);
+
+ if (!blk->HasTerminator()) {
+ AddError(blk, "block: does not end in a terminator instruction");
}
- void AddError(Block* blk, std::string err) {
- DisassembleIfNeeded();
- auto src = dis_.BlockSource(blk);
- src.file = mod_.disassembly_file.get();
- AddError(std::move(err), src);
+ for (auto* inst : *blk) {
+ if (inst->Is<ir::Terminator>() && inst != blk->Terminator()) {
+ AddError(inst, "block: terminator which isn't the final instruction");
+ continue;
+ }
+
+ CheckInstruction(inst);
}
+}
- void AddNote(Instruction* inst, std::string err) {
- DisassembleIfNeeded();
- auto src = dis_.InstructionSource(inst);
- src.file = mod_.disassembly_file.get();
- AddNote(std::move(err), src);
+void Validator::CheckInstruction(Instruction* inst) {
+ if (!inst->Alive()) {
+ AddError(inst, InstError(inst, "destroyed instruction found in instruction list"));
+ return;
}
-
- void AddNote(Instruction* inst, size_t idx, std::string err) {
- DisassembleIfNeeded();
- auto src = dis_.OperandSource(Usage{inst, static_cast<uint32_t>(idx)});
- src.file = mod_.disassembly_file.get();
- AddNote(std::move(err), src);
- }
-
- void AddNote(Block* blk, std::string err) {
- DisassembleIfNeeded();
- auto src = dis_.BlockSource(blk);
- src.file = mod_.disassembly_file.get();
- AddNote(std::move(err), src);
- }
-
- void AddError(std::string err, Source src = {}) {
- diagnostics_.add_error(tint::diag::System::IR, std::move(err), src);
- }
-
- void AddNote(std::string note, Source src = {}) {
- diagnostics_.add_note(tint::diag::System::IR, std::move(note), src);
- }
-
- std::string Name(Value* v) { return mod_.NameOf(v).Name(); }
-
- void CheckOperandNotNull(ir::Instruction* inst, ir::Value* operand, size_t idx) {
- if (operand == nullptr) {
- AddError(inst, idx, InstError(inst, "operand is undefined"));
- }
- }
-
- void CheckOperandsNotNull(ir::Instruction* inst, size_t start_operand, size_t end_operand) {
- auto operands = inst->Operands();
- for (size_t i = start_operand; i <= end_operand; i++) {
- CheckOperandNotNull(inst, operands[i], i);
- }
- }
-
- void CheckRootBlock(Block* blk) {
- if (!blk) {
- return;
- }
-
- TINT_SCOPED_ASSIGNMENT(current_block_, blk);
-
- for (auto* inst : *blk) {
- auto* var = inst->As<ir::Var>();
- if (!var) {
- AddError(inst,
- std::string("root block: invalid instruction: ") + inst->TypeInfo().name);
- continue;
- }
- CheckInstruction(var);
- }
- }
-
- void CheckFunction(Function* func) {
- if (!seen_functions_.Add(func)) {
- AddError("function '" + Name(func) + "' added to module multiple times");
- }
-
- CheckBlock(func->Block());
- }
-
- void CheckBlock(Block* blk) {
- TINT_SCOPED_ASSIGNMENT(current_block_, blk);
-
- if (!blk->HasTerminator()) {
- AddError(blk, "block: does not end in a terminator instruction");
- }
-
- for (auto* inst : *blk) {
- if (inst->Is<ir::Terminator>() && inst != blk->Terminator()) {
- AddError(inst, "block: terminator which isn't the final instruction");
- continue;
- }
-
- CheckInstruction(inst);
- }
- }
-
- void CheckInstruction(Instruction* inst) {
- if (!inst->Alive()) {
- AddError(inst, InstError(inst, "destroyed instruction found in instruction list"));
- return;
- }
- if (inst->HasResults()) {
- auto results = inst->Results();
- for (size_t i = 0; i < results.Length(); ++i) {
- auto* res = results[i];
- if (!res) {
- AddResultError(inst, i, InstError(inst, "instruction result is undefined"));
- continue;
- }
-
- if (res->Source() == nullptr) {
- AddResultError(inst, i,
- InstError(inst, "instruction result source is undefined"));
- } else if (res->Source() != inst) {
- AddResultError(
- inst, i,
- InstError(inst, "instruction result source has wrong instruction"));
- }
- }
- }
-
- auto ops = inst->Operands();
- for (size_t i = 0; i < ops.Length(); ++i) {
- auto* op = ops[i];
- if (!op) {
- continue;
- }
-
- // Note, a `nullptr` is a valid operand in some cases, like `var` so we can't just check
- // for `nullptr` here.
- if (!op->Alive()) {
- AddError(inst, i, InstError(inst, "instruction has operand which is not alive"));
- }
-
- if (!op->Usages().Contains({inst, i})) {
- AddError(inst, i, InstError(inst, "instruction operand missing usage"));
- }
- }
-
- tint::Switch(
- inst, //
- [&](Access* a) { CheckAccess(a); }, //
- [&](Binary* b) { CheckBinary(b); }, //
- [&](Call* c) { CheckCall(c); }, //
- [&](If* if_) { CheckIf(if_); }, //
- [&](Let* let) { CheckLet(let); }, //
- [&](Load*) {}, //
- [&](LoadVectorElement* l) { CheckLoadVectorElement(l); }, //
- [&](Loop* l) { CheckLoop(l); }, //
- [&](Store*) {}, //
- [&](StoreVectorElement* s) { CheckStoreVectorElement(s); }, //
- [&](Switch* s) { CheckSwitch(s); }, //
- [&](Swizzle*) {}, //
- [&](Terminator* b) { CheckTerminator(b); }, //
- [&](Unary* u) { CheckUnary(u); }, //
- [&](Var* var) { CheckVar(var); }, //
- [&](Default) { AddError(inst, InstError(inst, "missing validation")); });
- }
-
- void CheckVar(Var* var) {
- if (var->Result() && var->Initializer()) {
- if (var->Initializer()->Type() != var->Result()->Type()->UnwrapPtr()) {
- AddError(var, InstError(var, "initializer has incorrect type"));
- }
- }
- }
-
- void CheckLet(Let* let) {
- CheckOperandNotNull(let, let->Value(), Let::kValueOperandOffset);
-
- if (let->Result() && let->Value()) {
- if (let->Result()->Type() != let->Value()->Type()) {
- AddError(let, InstError(let, "result type does not match value type"));
- }
- }
- }
-
- void CheckCall(Call* call) {
- tint::Switch(
- call, //
- [&](Bitcast*) {}, //
- [&](CoreBuiltinCall*) {}, //
- [&](IntrinsicCall*) {}, //
- [&](Construct*) {}, //
- [&](Convert*) {}, //
- [&](Discard*) {}, //
- [&](UserCall*) {}, //
- [&](Default) { AddError(call, InstError(call, "missing validation")); });
- }
-
- void CheckAccess(ir::Access* a) {
- bool is_ptr = a->Object()->Type()->Is<type::Pointer>();
- auto* ty = a->Object()->Type()->UnwrapPtr();
-
- auto current = [&] {
- return is_ptr ? "ptr<" + ty->FriendlyName() + ">" : ty->FriendlyName();
- };
-
- for (size_t i = 0; i < a->Indices().Length(); i++) {
- auto err = [&](std::string msg) {
- AddError(a, i + Access::kIndicesOperandOffset, InstError(a, msg));
- };
- auto note = [&](std::string msg) {
- AddNote(a, i + Access::kIndicesOperandOffset, msg);
- };
-
- auto* index = a->Indices()[i];
- if (TINT_UNLIKELY(!index->Type()->is_integer_scalar())) {
- err("index must be integer, got " + index->Type()->FriendlyName());
- return;
- }
-
- if (is_ptr && ty->Is<type::Vector>()) {
- err("cannot obtain address of vector element");
- return;
- }
-
- if (auto* const_index = index->As<ir::Constant>()) {
- auto* value = const_index->Value();
- if (value->Type()->is_signed_integer_scalar()) {
- // index is a signed integer scalar. Check that the index isn't negative.
- // If the index is unsigned, we can skip this.
- auto idx = value->ValueAs<AInt>();
- if (TINT_UNLIKELY(idx < 0)) {
- err("constant index must be positive, got " + std::to_string(idx));
- return;
- }
- }
-
- auto idx = value->ValueAs<uint32_t>();
- auto* el = ty->Element(idx);
- if (TINT_UNLIKELY(!el)) {
- // Is index in bounds?
- if (auto el_count = ty->Elements().count; el_count != 0 && idx >= el_count) {
- err("index out of bounds for type " + current());
- note("acceptable range: [0.." + std::to_string(el_count - 1) + "]");
- return;
- }
- err("type " + current() + " cannot be indexed");
- return;
- }
- ty = el;
- } else {
- auto* el = ty->Elements().type;
- if (TINT_UNLIKELY(!el)) {
- err("type " + current() + " cannot be dynamically indexed");
- return;
- }
- ty = el;
- }
- }
-
- auto* want_ty = a->Result()->Type()->UnwrapPtr();
- bool want_ptr = a->Result()->Type()->Is<type::Pointer>();
- if (TINT_UNLIKELY(ty != want_ty || is_ptr != want_ptr)) {
- std::string want =
- want_ptr ? "ptr<" + want_ty->FriendlyName() + ">" : want_ty->FriendlyName();
- AddError(a, InstError(a, "result of access chain is type " + current() +
- " but instruction type is " + want));
- return;
- }
- }
-
- void CheckBinary(ir::Binary* b) {
- CheckOperandsNotNull(b, Binary::kLhsOperandOffset, Binary::kRhsOperandOffset);
- }
-
- void CheckUnary(ir::Unary* u) {
- CheckOperandNotNull(u, u->Val(), Unary::kValueOperandOffset);
-
- if (u->Result() && u->Val()) {
- if (u->Result()->Type() != u->Val()->Type()) {
- AddError(u, InstError(u, "result type must match value type"));
- }
- }
- }
-
- void CheckIf(If* if_) {
- CheckOperandNotNull(if_, if_->Condition(), If::kConditionOperandOffset);
-
- if (if_->Condition() && !if_->Condition()->Type()->Is<type::Bool>()) {
- AddError(if_, If::kConditionOperandOffset,
- InstError(if_, "condition must be a `bool` type"));
- }
-
- control_stack_.Push(if_);
- TINT_DEFER(control_stack_.Pop());
-
- CheckBlock(if_->True());
- if (!if_->False()->IsEmpty()) {
- CheckBlock(if_->False());
- }
- }
-
- void CheckLoop(Loop* l) {
- control_stack_.Push(l);
- TINT_DEFER(control_stack_.Pop());
-
- if (!l->Initializer()->IsEmpty()) {
- CheckBlock(l->Initializer());
- }
- CheckBlock(l->Body());
-
- if (!l->Continuing()->IsEmpty()) {
- CheckBlock(l->Continuing());
- }
- }
-
- void CheckSwitch(Switch* s) {
- control_stack_.Push(s);
- TINT_DEFER(control_stack_.Pop());
-
- for (auto& cse : s->Cases()) {
- CheckBlock(cse.block);
- }
- }
-
- void CheckTerminator(ir::Terminator* b) {
- // Note, transforms create `undef` terminator arguments (this is done in MergeReturn and
- // DemoteToHelper) so we can't add validation.
-
- tint::Switch(
- b, //
- [&](ir::BreakIf*) {}, //
- [&](ir::Continue*) {}, //
- [&](ir::Exit* e) { CheckExit(e); }, //
- [&](ir::NextIteration*) {}, //
- [&](ir::Return* ret) {
- if (ret->Func() == nullptr) {
- AddError(ret, InstError(ret, "undefined function"));
- }
- },
- [&](ir::TerminateInvocation*) {}, //
- [&](ir::Unreachable*) {}, //
- [&](Default) { AddError(b, InstError(b, "missing validation")); });
- }
-
- void CheckExit(ir::Exit* e) {
- if (e->ControlInstruction() == nullptr) {
- AddError(e, InstError(e, "has no parent control instruction"));
- return;
- }
-
- if (control_stack_.IsEmpty()) {
- AddError(e, InstError(e, "found outside all control instructions"));
- return;
- }
-
- auto results = e->ControlInstruction()->Results();
- auto args = e->Args();
- if (results.Length() != args.Length()) {
- AddError(e, InstError(e, std::string("args count (") + std::to_string(args.Length()) +
- ") does not match control instruction result count (" +
- std::to_string(results.Length()) + ")"));
- AddNote(e->ControlInstruction(), "control instruction");
- return;
- }
-
+ if (inst->HasResults()) {
+ auto results = inst->Results();
for (size_t i = 0; i < results.Length(); ++i) {
- if (results[i] && args[i] && results[i]->Type() != args[i]->Type()) {
- AddError(e, i,
- InstError(e, std::string("argument type (") +
- results[i]->Type()->FriendlyName() +
- ") does not match control instruction type (" +
- args[i]->Type()->FriendlyName() + ")"));
- AddNote(e->ControlInstruction(), "control instruction");
+ auto* res = results[i];
+ if (!res) {
+ AddResultError(inst, i, InstError(inst, "instruction result is undefined"));
+ continue;
}
- }
- tint::Switch(
- e, //
- [&](ir::ExitIf* i) { CheckExitIf(i); }, //
- [&](ir::ExitLoop* l) { CheckExitLoop(l); }, //
- [&](ir::ExitSwitch* s) { CheckExitSwitch(s); }, //
- [&](Default) { AddError(e, InstError(e, "missing validation")); });
- }
-
- void CheckExitIf(ExitIf* e) {
- if (control_stack_.Back() != e->If()) {
- AddError(e, InstError(e, "if target jumps over other control instructions"));
- AddNote(control_stack_.Back(), "first control instruction jumped");
+ if (res->Source() == nullptr) {
+ AddResultError(inst, i, InstError(inst, "instruction result source is undefined"));
+ } else if (res->Source() != inst) {
+ AddResultError(inst, i,
+ InstError(inst, "instruction result source has wrong instruction"));
+ }
}
}
- void CheckControlsAllowingIf(Exit* exit, Instruction* control, std::string_view name) {
- bool found = false;
- for (auto ctrl : utils::Reverse(control_stack_)) {
- if (ctrl == control) {
- found = true;
- break;
+ auto ops = inst->Operands();
+ for (size_t i = 0; i < ops.Length(); ++i) {
+ auto* op = ops[i];
+ if (!op) {
+ continue;
+ }
+
+ // Note, a `nullptr` is a valid operand in some cases, like `var` so we can't just check
+ // for `nullptr` here.
+ if (!op->Alive()) {
+ AddError(inst, i, InstError(inst, "instruction has operand which is not alive"));
+ }
+
+ if (!op->Usages().Contains({inst, i})) {
+ AddError(inst, i, InstError(inst, "instruction operand missing usage"));
+ }
+ }
+
+ tint::Switch(
+ inst, //
+ [&](Access* a) { CheckAccess(a); }, //
+ [&](Binary* b) { CheckBinary(b); }, //
+ [&](Call* c) { CheckCall(c); }, //
+ [&](If* if_) { CheckIf(if_); }, //
+ [&](Let* let) { CheckLet(let); }, //
+ [&](Load*) {}, //
+ [&](LoadVectorElement* l) { CheckLoadVectorElement(l); }, //
+ [&](Loop* l) { CheckLoop(l); }, //
+ [&](Store*) {}, //
+ [&](StoreVectorElement* s) { CheckStoreVectorElement(s); }, //
+ [&](Switch* s) { CheckSwitch(s); }, //
+ [&](Swizzle*) {}, //
+ [&](Terminator* b) { CheckTerminator(b); }, //
+ [&](Unary* u) { CheckUnary(u); }, //
+ [&](Var* var) { CheckVar(var); }, //
+ [&](Default) { AddError(inst, InstError(inst, "missing validation")); });
+}
+
+void Validator::CheckVar(Var* var) {
+ if (var->Result() && var->Initializer()) {
+ if (var->Initializer()->Type() != var->Result()->Type()->UnwrapPtr()) {
+ AddError(var, InstError(var, "initializer has incorrect type"));
+ }
+ }
+}
+
+void Validator::CheckLet(Let* let) {
+ CheckOperandNotNull(let, let->Value(), Let::kValueOperandOffset);
+
+ if (let->Result() && let->Value()) {
+ if (let->Result()->Type() != let->Value()->Type()) {
+ AddError(let, InstError(let, "result type does not match value type"));
+ }
+ }
+}
+
+void Validator::CheckCall(Call* call) {
+ tint::Switch(
+ call, //
+ [&](Bitcast*) {}, //
+ [&](CoreBuiltinCall*) {}, //
+ [&](IntrinsicCall*) {}, //
+ [&](Construct*) {}, //
+ [&](Convert*) {}, //
+ [&](Discard*) {}, //
+ [&](UserCall*) {}, //
+ [&](Default) { AddError(call, InstError(call, "missing validation")); });
+}
+
+void Validator::CheckAccess(ir::Access* a) {
+ bool is_ptr = a->Object()->Type()->Is<type::Pointer>();
+ auto* ty = a->Object()->Type()->UnwrapPtr();
+
+ auto current = [&] { return is_ptr ? "ptr<" + ty->FriendlyName() + ">" : ty->FriendlyName(); };
+
+ for (size_t i = 0; i < a->Indices().Length(); i++) {
+ auto err = [&](std::string msg) {
+ AddError(a, i + Access::kIndicesOperandOffset, InstError(a, msg));
+ };
+ auto note = [&](std::string msg) { AddNote(a, i + Access::kIndicesOperandOffset, msg); };
+
+ auto* index = a->Indices()[i];
+ if (TINT_UNLIKELY(!index->Type()->is_integer_scalar())) {
+ err("index must be integer, got " + index->Type()->FriendlyName());
+ return;
+ }
+
+ if (is_ptr && ty->Is<type::Vector>()) {
+ err("cannot obtain address of vector element");
+ return;
+ }
+
+ if (auto* const_index = index->As<ir::Constant>()) {
+ auto* value = const_index->Value();
+ if (value->Type()->is_signed_integer_scalar()) {
+ // index is a signed integer scalar. Check that the index isn't negative.
+ // If the index is unsigned, we can skip this.
+ auto idx = value->ValueAs<AInt>();
+ if (TINT_UNLIKELY(idx < 0)) {
+ err("constant index must be positive, got " + std::to_string(idx));
+ return;
+ }
}
- // A exit switch can step over if instructions, but no others.
- if (!ctrl->Is<ir::If>()) {
- AddError(exit,
- InstError(exit, std::string(name) +
- " target jumps over other control instructions"));
- AddNote(ctrl, "first control instruction jumped");
+
+ auto idx = value->ValueAs<uint32_t>();
+ auto* el = ty->Element(idx);
+ if (TINT_UNLIKELY(!el)) {
+ // Is index in bounds?
+ if (auto el_count = ty->Elements().count; el_count != 0 && idx >= el_count) {
+ err("index out of bounds for type " + current());
+ note("acceptable range: [0.." + std::to_string(el_count - 1) + "]");
+ return;
+ }
+ err("type " + current() + " cannot be indexed");
return;
}
- }
- if (!found) {
- AddError(exit, InstError(exit, std::string(name) +
- " not found in parent control instructions"));
- }
- }
-
- void CheckExitSwitch(ExitSwitch* s) {
- CheckControlsAllowingIf(s, s->ControlInstruction(), "switch");
- }
-
- void CheckExitLoop(ExitLoop* l) {
- CheckControlsAllowingIf(l, l->ControlInstruction(), "loop");
-
- Instruction* inst = l;
- Loop* control = l->Loop();
- while (inst) {
- // Found parent loop
- if (inst->Block()->Parent() == control) {
- if (inst->Block() == control->Continuing()) {
- AddError(l, InstError(l, "loop exit jumps out of continuing block"));
- if (control->Continuing() != l->Block()) {
- AddNote(control->Continuing(), "in continuing block");
- }
- } else if (inst->Block() == control->Initializer()) {
- AddError(l, InstError(l, "loop exit not permitted in loop initializer"));
- if (control->Initializer() != l->Block()) {
- AddNote(control->Initializer(), "in initializer block");
- }
- }
- break;
+ ty = el;
+ } else {
+ auto* el = ty->Elements().type;
+ if (TINT_UNLIKELY(!el)) {
+ err("type " + current() + " cannot be dynamically indexed");
+ return;
}
- inst = inst->Block()->Parent();
+ ty = el;
}
}
- void CheckLoadVectorElement(LoadVectorElement* l) {
- CheckOperandsNotNull(l, //
- LoadVectorElement::kFromOperandOffset,
- LoadVectorElement::kIndexOperandOffset);
+ auto* want_ty = a->Result()->Type()->UnwrapPtr();
+ bool want_ptr = a->Result()->Type()->Is<type::Pointer>();
+ if (TINT_UNLIKELY(ty != want_ty || is_ptr != want_ptr)) {
+ std::string want =
+ want_ptr ? "ptr<" + want_ty->FriendlyName() + ">" : want_ty->FriendlyName();
+ AddError(a, InstError(a, "result of access chain is type " + current() +
+ " but instruction type is " + want));
+ return;
+ }
+}
- if (auto* res = l->Result()) {
- if (auto* el_ty = GetVectorPtrElementType(l, LoadVectorElement::kFromOperandOffset)) {
- if (res->Type() != el_ty) {
- AddResultError(l, 0, "result type does not match vector pointer element type");
+void Validator::CheckBinary(ir::Binary* b) {
+ CheckOperandsNotNull(b, Binary::kLhsOperandOffset, Binary::kRhsOperandOffset);
+}
+
+void Validator::CheckUnary(ir::Unary* u) {
+ CheckOperandNotNull(u, u->Val(), Unary::kValueOperandOffset);
+
+ if (u->Result() && u->Val()) {
+ if (u->Result()->Type() != u->Val()->Type()) {
+ AddError(u, InstError(u, "result type must match value type"));
+ }
+ }
+}
+
+void Validator::CheckIf(If* if_) {
+ CheckOperandNotNull(if_, if_->Condition(), If::kConditionOperandOffset);
+
+ if (if_->Condition() && !if_->Condition()->Type()->Is<type::Bool>()) {
+ AddError(if_, If::kConditionOperandOffset,
+ InstError(if_, "condition must be a `bool` type"));
+ }
+
+ control_stack_.Push(if_);
+ TINT_DEFER(control_stack_.Pop());
+
+ CheckBlock(if_->True());
+ if (!if_->False()->IsEmpty()) {
+ CheckBlock(if_->False());
+ }
+}
+
+void Validator::CheckLoop(Loop* l) {
+ control_stack_.Push(l);
+ TINT_DEFER(control_stack_.Pop());
+
+ if (!l->Initializer()->IsEmpty()) {
+ CheckBlock(l->Initializer());
+ }
+ CheckBlock(l->Body());
+
+ if (!l->Continuing()->IsEmpty()) {
+ CheckBlock(l->Continuing());
+ }
+}
+
+void Validator::CheckSwitch(Switch* s) {
+ control_stack_.Push(s);
+ TINT_DEFER(control_stack_.Pop());
+
+ for (auto& cse : s->Cases()) {
+ CheckBlock(cse.block);
+ }
+}
+
+void Validator::CheckTerminator(ir::Terminator* b) {
+ // Note, transforms create `undef` terminator arguments (this is done in MergeReturn and
+ // DemoteToHelper) so we can't add validation.
+
+ tint::Switch(
+ b, //
+ [&](ir::BreakIf*) {}, //
+ [&](ir::Continue*) {}, //
+ [&](ir::Exit* e) { CheckExit(e); }, //
+ [&](ir::NextIteration*) {}, //
+ [&](ir::Return* ret) {
+ if (ret->Func() == nullptr) {
+ AddError(ret, InstError(ret, "undefined function"));
+ }
+ },
+ [&](ir::TerminateInvocation*) {}, //
+ [&](ir::Unreachable*) {}, //
+ [&](Default) { AddError(b, InstError(b, "missing validation")); });
+}
+
+void Validator::CheckExit(ir::Exit* e) {
+ if (e->ControlInstruction() == nullptr) {
+ AddError(e, InstError(e, "has no parent control instruction"));
+ return;
+ }
+
+ if (control_stack_.IsEmpty()) {
+ AddError(e, InstError(e, "found outside all control instructions"));
+ return;
+ }
+
+ auto results = e->ControlInstruction()->Results();
+ auto args = e->Args();
+ if (results.Length() != args.Length()) {
+ AddError(e, InstError(e, std::string("args count (") + std::to_string(args.Length()) +
+ ") does not match control instruction result count (" +
+ std::to_string(results.Length()) + ")"));
+ AddNote(e->ControlInstruction(), "control instruction");
+ return;
+ }
+
+ for (size_t i = 0; i < results.Length(); ++i) {
+ if (results[i] && args[i] && results[i]->Type() != args[i]->Type()) {
+ AddError(
+ e, i,
+ InstError(e, std::string("argument type (") + results[i]->Type()->FriendlyName() +
+ ") does not match control instruction type (" +
+ args[i]->Type()->FriendlyName() + ")"));
+ AddNote(e->ControlInstruction(), "control instruction");
+ }
+ }
+
+ tint::Switch(
+ e, //
+ [&](ir::ExitIf* i) { CheckExitIf(i); }, //
+ [&](ir::ExitLoop* l) { CheckExitLoop(l); }, //
+ [&](ir::ExitSwitch* s) { CheckExitSwitch(s); }, //
+ [&](Default) { AddError(e, InstError(e, "missing validation")); });
+}
+
+void Validator::CheckExitIf(ExitIf* e) {
+ if (control_stack_.Back() != e->If()) {
+ AddError(e, InstError(e, "if target jumps over other control instructions"));
+ AddNote(control_stack_.Back(), "first control instruction jumped");
+ }
+}
+
+void Validator::CheckControlsAllowingIf(Exit* exit, Instruction* control) {
+ bool found = false;
+ for (auto ctrl : utils::Reverse(control_stack_)) {
+ if (ctrl == control) {
+ found = true;
+ break;
+ }
+        // An exit switch can step over if instructions, but no others.
+ if (!ctrl->Is<ir::If>()) {
+ AddError(exit, InstError(exit, std::string(control->FriendlyName()) +
+ " target jumps over other control instructions"));
+ AddNote(ctrl, "first control instruction jumped");
+ return;
+ }
+ }
+ if (!found) {
+ AddError(exit, InstError(exit, std::string(control->FriendlyName()) +
+ " not found in parent control instructions"));
+ }
+}
+
+void Validator::CheckExitSwitch(ExitSwitch* s) {
+ CheckControlsAllowingIf(s, s->ControlInstruction());
+}
+
+void Validator::CheckExitLoop(ExitLoop* l) {
+ CheckControlsAllowingIf(l, l->ControlInstruction());
+
+ Instruction* inst = l;
+ Loop* control = l->Loop();
+ while (inst) {
+ // Found parent loop
+ if (inst->Block()->Parent() == control) {
+ if (inst->Block() == control->Continuing()) {
+ AddError(l, InstError(l, "loop exit jumps out of continuing block"));
+ if (control->Continuing() != l->Block()) {
+ AddNote(control->Continuing(), "in continuing block");
+ }
+ } else if (inst->Block() == control->Initializer()) {
+ AddError(l, InstError(l, "loop exit not permitted in loop initializer"));
+ if (control->Initializer() != l->Block()) {
+ AddNote(control->Initializer(), "in initializer block");
}
}
+ break;
}
+ inst = inst->Block()->Parent();
}
+}
- void CheckStoreVectorElement(StoreVectorElement* s) {
- CheckOperandsNotNull(s, //
- StoreVectorElement::kToOperandOffset,
- StoreVectorElement::kValueOperandOffset);
+void Validator::CheckLoadVectorElement(LoadVectorElement* l) {
+ CheckOperandsNotNull(l, //
+ LoadVectorElement::kFromOperandOffset,
+ LoadVectorElement::kIndexOperandOffset);
- if (auto* value = s->Value()) {
- if (auto* el_ty = GetVectorPtrElementType(s, StoreVectorElement::kToOperandOffset)) {
- if (value->Type() != el_ty) {
- AddError(s, StoreVectorElement::kValueOperandOffset,
- "value type does not match vector pointer element type");
- }
+ if (auto* res = l->Result()) {
+ if (auto* el_ty = GetVectorPtrElementType(l, LoadVectorElement::kFromOperandOffset)) {
+ if (res->Type() != el_ty) {
+ AddResultError(l, 0, "result type does not match vector pointer element type");
}
}
}
+}
- const type::Type* GetVectorPtrElementType(Instruction* inst, size_t idx) {
- auto* operand = inst->Operands()[idx];
- if (TINT_UNLIKELY(!operand)) {
- return nullptr;
- }
+void Validator::CheckStoreVectorElement(StoreVectorElement* s) {
+ CheckOperandsNotNull(s, //
+ StoreVectorElement::kToOperandOffset,
+ StoreVectorElement::kValueOperandOffset);
- auto* type = operand->Type();
- if (TINT_UNLIKELY(!type)) {
- return nullptr;
- }
-
- auto* vec_ptr_ty = type->As<type::Pointer>();
- if (TINT_LIKELY(vec_ptr_ty)) {
- auto* vec_ty = vec_ptr_ty->StoreType()->As<type::Vector>();
- if (TINT_LIKELY(vec_ty)) {
- return vec_ty->type();
+ if (auto* value = s->Value()) {
+ if (auto* el_ty = GetVectorPtrElementType(s, StoreVectorElement::kToOperandOffset)) {
+ if (value->Type() != el_ty) {
+ AddError(s, StoreVectorElement::kValueOperandOffset,
+ "value type does not match vector pointer element type");
}
}
+ }
+}
- AddError(inst, idx, "operand must be a pointer to vector, got " + type->FriendlyName());
+const type::Type* Validator::GetVectorPtrElementType(Instruction* inst, size_t idx) {
+ auto* operand = inst->Operands()[idx];
+ if (TINT_UNLIKELY(!operand)) {
return nullptr;
}
-};
-} // namespace
+ auto* type = operand->Type();
+ if (TINT_UNLIKELY(!type)) {
+ return nullptr;
+ }
+
+ auto* vec_ptr_ty = type->As<type::Pointer>();
+ if (TINT_LIKELY(vec_ptr_ty)) {
+ auto* vec_ty = vec_ptr_ty->StoreType()->As<type::Vector>();
+ if (TINT_LIKELY(vec_ty)) {
+ return vec_ty->type();
+ }
+ }
+
+ AddError(inst, idx, "operand must be a pointer to vector, got " + type->FriendlyName());
+ return nullptr;
+}
utils::Result<Success, diag::List> Validate(Module& mod) {
Validator v(mod);
diff --git a/src/tint/ir/validator.h b/src/tint/ir/validator.h
index 705104a..b93676a 100644
--- a/src/tint/ir/validator.h
+++ b/src/tint/ir/validator.h
@@ -15,10 +15,25 @@
#ifndef SRC_TINT_IR_VALIDATOR_H_
#define SRC_TINT_IR_VALIDATOR_H_
+#include <string>
+
#include "src/tint/diagnostic/diagnostic.h"
+#include "src/tint/ir/disassembler.h"
#include "src/tint/ir/module.h"
#include "src/tint/utils/result.h"
+// Forward declarations
+namespace tint::ir {
+class Access;
+class ExitIf;
+class ExitLoop;
+class ExitSwitch;
+class Let;
+class LoadVectorElement;
+class StoreVectorElement;
+class Var;
+} // namespace tint::ir
+
namespace tint::ir {
/// Signifies the validation completed successfully
@@ -29,6 +44,196 @@
/// @returns true on success, an error result otherwise
utils::Result<Success, diag::List> Validate(Module& mod);
+/// The core IR validator.
+class Validator {
+ public:
+ /// Create a core validator
+ /// @param mod the module to be validated
+ explicit Validator(Module& mod);
+
+ /// Destructor
+ ~Validator();
+
+ /// Runs the validator over the module provided during construction
+ /// @returns the results of validation, either a success result object or the diagnostics of
+ /// validation failures.
+ utils::Result<Success, diag::List> IsValid();
+
+ protected:
+ /// @param inst the instruction
+ /// @param err the error message
+    /// @returns a string with the instruction name and error message formatted
+ std::string InstError(Instruction* inst, std::string err);
+
+ /// Adds an error for the @p inst and highlights the instruction in the disassembly
+ /// @param inst the instruction
+ /// @param err the error string
+ void AddError(Instruction* inst, std::string err);
+
+ /// Adds an error for the @p inst operand at @p idx and highlights the operand in the
+ /// disassembly
+    /// @param inst the instruction
+ /// @param idx the operand index
+ /// @param err the error string
+ void AddError(Instruction* inst, size_t idx, std::string err);
+
+    /// Adds an error for the @p inst result at @p idx and highlights the result in the disassembly
+ /// @param inst the instruction
+ /// @param idx the result index
+ /// @param err the error string
+ void AddResultError(Instruction* inst, size_t idx, std::string err);
+
+    /// Adds an error to the @p block and highlights the block header in the disassembly
+ /// @param blk the block
+ /// @param err the error string
+ void AddError(Block* blk, std::string err);
+
+ /// Adds a note to @p inst and highlights the instruction in the disassembly
+ /// @param inst the instruction
+ /// @param err the message to emit
+ void AddNote(Instruction* inst, std::string err);
+
+ /// Adds a note to @p inst for operand @p idx and highlights the operand in the
+ /// disassembly
+ /// @param inst the instruction
+ /// @param idx the operand index
+ /// @param err the message string
+ void AddNote(Instruction* inst, size_t idx, std::string err);
+
+ /// Adds a note to @p blk and highlights the block in the disassembly
+ /// @param blk the block
+ /// @param err the message to emit
+ void AddNote(Block* blk, std::string err);
+
+ /// Adds an error to the diagnostics
+ /// @param err the message to emit
+ /// @param src the source lines to highlight
+ void AddError(std::string err, Source src = {});
+
+ /// Adds a note to the diagnostics
+ /// @param note the note to emit
+ /// @param src the source lines to highlight
+ void AddNote(std::string note, Source src = {});
+
+ /// @param v the value to get the name for
+ /// @returns the name for the given value
+ std::string Name(Value* v);
+
+ /// Checks the given operand is not null
+    /// @param inst the instruction
+ /// @param operand the operand
+ /// @param idx the operand index
+ void CheckOperandNotNull(ir::Instruction* inst, ir::Value* operand, size_t idx);
+
+ /// Checks all operands in the given range (inclusive) for @p inst are not null
+ /// @param inst the instruction
+ /// @param start_operand the first operand to check
+ /// @param end_operand the last operand to check
+ void CheckOperandsNotNull(ir::Instruction* inst, size_t start_operand, size_t end_operand);
+
+ /// Validates the root block
+ /// @param blk the block
+ void CheckRootBlock(Block* blk);
+
+ /// Validates the given function
+    /// @param func the function to validate
+ void CheckFunction(Function* func);
+
+ /// Validates the given block
+ /// @param blk the block to validate
+ void CheckBlock(Block* blk);
+
+ /// Validates the given instruction
+ /// @param inst the instruction to validate
+ void CheckInstruction(Instruction* inst);
+
+ /// Validates the given var
+ /// @param var the var to validate
+ void CheckVar(Var* var);
+
+ /// Validates the given let
+ /// @param let the let to validate
+ void CheckLet(Let* let);
+
+ /// Validates the given call
+ /// @param call the call to validate
+ void CheckCall(Call* call);
+
+ /// Validates the given access
+ /// @param a the access to validate
+ void CheckAccess(ir::Access* a);
+
+ /// Validates the given binary
+ /// @param b the binary to validate
+ void CheckBinary(ir::Binary* b);
+
+ /// Validates the given unary
+ /// @param u the unary to validate
+ void CheckUnary(ir::Unary* u);
+
+ /// Validates the given if
+ /// @param if_ the if to validate
+ void CheckIf(If* if_);
+
+ /// Validates the given loop
+ /// @param l the loop to validate
+ void CheckLoop(Loop* l);
+
+ /// Validates the given switch
+ /// @param s the switch to validate
+ void CheckSwitch(Switch* s);
+
+ /// Validates the given terminator
+ /// @param b the terminator to validate
+ void CheckTerminator(ir::Terminator* b);
+
+ /// Validates the given exit
+ /// @param e the exit to validate
+ void CheckExit(ir::Exit* e);
+
+ /// Validates the given exit if
+ /// @param e the exit if to validate
+ void CheckExitIf(ExitIf* e);
+
+ /// Validates the @p exit targets a valid @p control instruction where the instruction may jump
+ /// over if control instructions.
+ /// @param exit the exit to validate
+ /// @param control the control instruction targeted
+ void CheckControlsAllowingIf(Exit* exit, Instruction* control);
+
+ /// Validates the given exit switch
+ /// @param s the exit switch to validate
+ void CheckExitSwitch(ExitSwitch* s);
+
+ /// Validates the given exit loop
+ /// @param l the exit loop to validate
+ void CheckExitLoop(ExitLoop* l);
+
+ /// Validates the given load vector element
+ /// @param l the load vector element to validate
+ void CheckLoadVectorElement(LoadVectorElement* l);
+
+ /// Validates the given store vector element
+ /// @param s the store vector element to validate
+ void CheckStoreVectorElement(StoreVectorElement* s);
+
+ /// @param inst the instruction
+ /// @param idx the operand index
+ /// @returns the vector pointer type for the given instruction operand
+ const type::Type* GetVectorPtrElementType(Instruction* inst, size_t idx);
+
+ private:
+ Module& mod_;
+ diag::List diagnostics_;
+ Disassembler dis_{mod_};
+
+ Block* current_block_ = nullptr;
+ utils::Hashset<Function*, 4> seen_functions_;
+ utils::Vector<ControlInstruction*, 8> control_stack_;
+
+ void DisassembleIfNeeded();
+};
+
} // namespace tint::ir
#endif // SRC_TINT_IR_VALIDATOR_H_
diff --git a/src/tint/ir/validator_test.cc b/src/tint/ir/validator_test.cc
index a0d0b92..881de2c 100644
--- a/src/tint/ir/validator_test.cc
+++ b/src/tint/ir/validator_test.cc
@@ -2716,9 +2716,9 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:3:32 error: store-vector-element: operand is undefined
- store_vector_element undef undef, 1i, 2i
- ^^^^^
+ EXPECT_EQ(res.Failure().str(), R"(:3:26 error: store-vector-element: operand is undefined
+ store_vector_element undef, 1i, 2i
+ ^^^^^
:2:3 note: In block
%b1 = block {
@@ -2727,7 +2727,7 @@
note: # Disassembly
%my_func = func():void -> %b1 {
%b1 = block {
- store_vector_element undef undef, 1i, 2i
+ store_vector_element undef, 1i, 2i
ret
}
}
@@ -2746,17 +2746,17 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:4:33 error: store-vector-element: operand is undefined
- store_vector_element %2 %2, undef, 2i
- ^^^^^
+ EXPECT_EQ(res.Failure().str(), R"(:4:30 error: store-vector-element: operand is undefined
+ store_vector_element %2, undef, 2i
+ ^^^^^
:2:3 note: In block
%b1 = block {
^^^^^^^^^^^
-:4:40 error: value type does not match vector pointer element type
- store_vector_element %2 %2, undef, 2i
- ^^
+:4:37 error: value type does not match vector pointer element type
+ store_vector_element %2, undef, 2i
+ ^^
:2:3 note: In block
%b1 = block {
@@ -2766,7 +2766,7 @@
%my_func = func():void -> %b1 {
%b1 = block {
%2:ptr<function, vec3<f32>, read_write> = var
- store_vector_element %2 %2, undef, 2i
+ store_vector_element %2, undef, 2i
ret
}
}
@@ -2785,9 +2785,9 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:4:37 error: store-vector-element: operand is undefined
- store_vector_element %2 %2, 1i, undef
- ^^^^^
+ EXPECT_EQ(res.Failure().str(), R"(:4:34 error: store-vector-element: operand is undefined
+ store_vector_element %2, 1i, undef
+ ^^^^^
:2:3 note: In block
%b1 = block {
@@ -2797,7 +2797,7 @@
%my_func = func():void -> %b1 {
%b1 = block {
%2:ptr<function, vec3<f32>, read_write> = var
- store_vector_element %2 %2, 1i, undef
+ store_vector_element %2, 1i, undef
ret
}
}
diff --git a/src/tint/resolver/bitcast_validation_test.cc b/src/tint/resolver/bitcast_validation_test.cc
index c979c55..631c9c8 100644
--- a/src/tint/resolver/bitcast_validation_test.cc
+++ b/src/tint/resolver/bitcast_validation_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <type_traits>
+
#include "src/tint/ast/bitcast_expression.h"
#include "src/tint/resolver/resolver.h"
#include "src/tint/resolver/resolver_test_helper.h"
@@ -22,39 +24,61 @@
namespace {
using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
struct Type {
+ template <typename T, std::enable_if_t<IsVector<T>, bool> = true>
+ static constexpr bool UsedF16() {
+ return std::is_same_v<typename T::type, f16>;
+ }
+
+ template <typename T, std::enable_if_t<!IsVector<T>, bool> = true>
+ static constexpr bool UsedF16() {
+ return std::is_same_v<T, f16>;
+ }
+
template <typename T>
static constexpr Type Create() {
return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
- builder::DataType<T>::ExprFromDouble};
+ builder::DataType<T>::ExprFromDouble, UsedF16<T>()};
}
builder::ast_type_func_ptr ast;
builder::sem_type_func_ptr sem;
builder::ast_expr_from_double_func_ptr expr;
+ bool used_f16;
};
-static constexpr Type kNumericScalars[] = {
+// Valid numeric scalar and vector types of all bit widths
+static constexpr Type k16BitsNumericTypes[] = {
+ Type::Create<f16>(),
+};
+static constexpr Type k32BitsNumericTypes[] = {
Type::Create<f32>(),
Type::Create<i32>(),
Type::Create<u32>(),
+ Type::Create<vec2<f16>>(),
};
-static constexpr Type kVec2NumericScalars[] = {
+static constexpr Type k48BitsNumericTypes[] = {
+ Type::Create<vec3<f16>>(),
+};
+static constexpr Type k64BitsNumericTypes[] = {
Type::Create<vec2<f32>>(),
Type::Create<vec2<i32>>(),
Type::Create<vec2<u32>>(),
+ Type::Create<vec4<f16>>(),
};
-static constexpr Type kVec3NumericScalars[] = {
+static constexpr Type k96BitsNumericTypes[] = {
Type::Create<vec3<f32>>(),
Type::Create<vec3<i32>>(),
Type::Create<vec3<u32>>(),
};
-static constexpr Type kVec4NumericScalars[] = {
+static constexpr Type k128BitsNumericTypes[] = {
Type::Create<vec4<f32>>(),
Type::Create<vec4<i32>>(),
Type::Create<vec4<u32>>(),
};
+
static constexpr Type kInvalid[] = {
// A non-exhaustive selection of uncastable types
Type::Create<bool>(),
@@ -83,28 +107,40 @@
auto src = std::get<0>(GetParam());
auto dst = std::get<1>(GetParam());
+ if (src.used_f16 || dst.used_f16) {
+ Enable(builtin::Extension::kF16);
+ }
+
auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0));
WrapInFunction(cast);
ASSERT_TRUE(r()->Resolve()) << r()->error();
EXPECT_EQ(TypeOf(cast), dst.sem(*this));
}
-INSTANTIATE_TEST_SUITE_P(Scalars,
+INSTANTIATE_TEST_SUITE_P(16Bits,
ResolverBitcastValidationTestPass,
- testing::Combine(testing::ValuesIn(kNumericScalars),
- testing::ValuesIn(kNumericScalars)));
-INSTANTIATE_TEST_SUITE_P(Vec2,
+ testing::Combine(testing::ValuesIn(k16BitsNumericTypes),
+ testing::ValuesIn(k16BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(32Bits,
ResolverBitcastValidationTestPass,
- testing::Combine(testing::ValuesIn(kVec2NumericScalars),
- testing::ValuesIn(kVec2NumericScalars)));
-INSTANTIATE_TEST_SUITE_P(Vec3,
+ testing::Combine(testing::ValuesIn(k32BitsNumericTypes),
+ testing::ValuesIn(k32BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(48Bits,
ResolverBitcastValidationTestPass,
- testing::Combine(testing::ValuesIn(kVec3NumericScalars),
- testing::ValuesIn(kVec3NumericScalars)));
-INSTANTIATE_TEST_SUITE_P(Vec4,
+ testing::Combine(testing::ValuesIn(k48BitsNumericTypes),
+ testing::ValuesIn(k48BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(64Bits,
ResolverBitcastValidationTestPass,
- testing::Combine(testing::ValuesIn(kVec4NumericScalars),
- testing::ValuesIn(kVec4NumericScalars)));
+ testing::Combine(testing::ValuesIn(k64BitsNumericTypes),
+ testing::ValuesIn(k64BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(96Bits,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(k96BitsNumericTypes),
+ testing::ValuesIn(k96BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(128Bits,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(k128BitsNumericTypes),
+ testing::ValuesIn(k128BitsNumericTypes)));
////////////////////////////////////////////////////////////////////////////////
// Invalid source type for bitcasts
@@ -114,6 +150,10 @@
auto src = std::get<0>(GetParam());
auto dst = std::get<1>(GetParam());
+ if (src.used_f16 || dst.used_f16) {
+ Enable(builtin::Extension::kF16);
+ }
+
auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src"));
WrapInFunction(Let("src", src.expr(*this, 0)), cast);
@@ -122,22 +162,30 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), expected);
}
-INSTANTIATE_TEST_SUITE_P(Scalars,
+INSTANTIATE_TEST_SUITE_P(16Bits,
ResolverBitcastValidationTestInvalidSrcTy,
testing::Combine(testing::ValuesIn(kInvalid),
- testing::ValuesIn(kNumericScalars)));
-INSTANTIATE_TEST_SUITE_P(Vec2,
+ testing::ValuesIn(k16BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(32Bits,
ResolverBitcastValidationTestInvalidSrcTy,
testing::Combine(testing::ValuesIn(kInvalid),
- testing::ValuesIn(kVec2NumericScalars)));
-INSTANTIATE_TEST_SUITE_P(Vec3,
+ testing::ValuesIn(k32BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(48Bits,
ResolverBitcastValidationTestInvalidSrcTy,
testing::Combine(testing::ValuesIn(kInvalid),
- testing::ValuesIn(kVec3NumericScalars)));
-INSTANTIATE_TEST_SUITE_P(Vec4,
+ testing::ValuesIn(k48BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(64Bits,
ResolverBitcastValidationTestInvalidSrcTy,
testing::Combine(testing::ValuesIn(kInvalid),
- testing::ValuesIn(kVec4NumericScalars)));
+ testing::ValuesIn(k64BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(96Bits,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(k96BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(128Bits,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(k128BitsNumericTypes)));
////////////////////////////////////////////////////////////////////////////////
// Invalid target type for bitcasts
@@ -147,6 +195,10 @@
auto src = std::get<0>(GetParam());
auto dst = std::get<1>(GetParam());
+ if (src.used_f16 || dst.used_f16) {
+ Enable(builtin::Extension::kF16);
+ }
+
// Use an alias so we can put a Source on the bitcast type
Alias("T", dst.ast(*this));
WrapInFunction(Bitcast(ty(Source{{12, 34}}, "T"), src.expr(*this, 0)));
@@ -156,21 +208,29 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), expected);
}
-INSTANTIATE_TEST_SUITE_P(Scalars,
+INSTANTIATE_TEST_SUITE_P(16Bits,
ResolverBitcastValidationTestInvalidDstTy,
- testing::Combine(testing::ValuesIn(kNumericScalars),
+ testing::Combine(testing::ValuesIn(k16BitsNumericTypes),
testing::ValuesIn(kInvalid)));
-INSTANTIATE_TEST_SUITE_P(Vec2,
+INSTANTIATE_TEST_SUITE_P(32Bits,
ResolverBitcastValidationTestInvalidDstTy,
- testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+ testing::Combine(testing::ValuesIn(k32BitsNumericTypes),
testing::ValuesIn(kInvalid)));
-INSTANTIATE_TEST_SUITE_P(Vec3,
+INSTANTIATE_TEST_SUITE_P(48Bits,
ResolverBitcastValidationTestInvalidDstTy,
- testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+ testing::Combine(testing::ValuesIn(k48BitsNumericTypes),
testing::ValuesIn(kInvalid)));
-INSTANTIATE_TEST_SUITE_P(Vec4,
+INSTANTIATE_TEST_SUITE_P(64Bits,
ResolverBitcastValidationTestInvalidDstTy,
- testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+ testing::Combine(testing::ValuesIn(k64BitsNumericTypes),
+ testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(96Bits,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(k96BitsNumericTypes),
+ testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(128Bits,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(k128BitsNumericTypes),
testing::ValuesIn(kInvalid)));
////////////////////////////////////////////////////////////////////////////////
@@ -181,6 +241,10 @@
auto src = std::get<0>(GetParam());
auto dst = std::get<1>(GetParam());
+ if (src.used_f16 || dst.used_f16) {
+ Enable(builtin::Extension::kF16);
+ }
+
WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0)));
auto expected = "12:34 error: cannot bitcast from '" + src.sem(*this)->FriendlyName() +
@@ -189,22 +253,182 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), expected);
}
-INSTANTIATE_TEST_SUITE_P(ScalarToVec2,
+INSTANTIATE_TEST_SUITE_P(16BitsTo32Bits,
ResolverBitcastValidationTestIncompatible,
- testing::Combine(testing::ValuesIn(kNumericScalars),
- testing::ValuesIn(kVec2NumericScalars)));
-INSTANTIATE_TEST_SUITE_P(Vec2ToVec3,
+ testing::Combine(testing::ValuesIn(k16BitsNumericTypes),
+ testing::ValuesIn(k32BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(16BitsTo48Bits,
ResolverBitcastValidationTestIncompatible,
- testing::Combine(testing::ValuesIn(kVec2NumericScalars),
- testing::ValuesIn(kVec3NumericScalars)));
-INSTANTIATE_TEST_SUITE_P(Vec3ToVec4,
+ testing::Combine(testing::ValuesIn(k16BitsNumericTypes),
+ testing::ValuesIn(k48BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(16BitsTo64Bits,
ResolverBitcastValidationTestIncompatible,
- testing::Combine(testing::ValuesIn(kVec3NumericScalars),
- testing::ValuesIn(kVec4NumericScalars)));
-INSTANTIATE_TEST_SUITE_P(Vec4ToScalar,
+ testing::Combine(testing::ValuesIn(k16BitsNumericTypes),
+ testing::ValuesIn(k64BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(16BitsTo96Bits,
ResolverBitcastValidationTestIncompatible,
- testing::Combine(testing::ValuesIn(kVec4NumericScalars),
- testing::ValuesIn(kNumericScalars)));
+ testing::Combine(testing::ValuesIn(k16BitsNumericTypes),
+ testing::ValuesIn(k96BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(16BitsTo128Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k16BitsNumericTypes),
+ testing::ValuesIn(k128BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(32BitsTo16Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k32BitsNumericTypes),
+ testing::ValuesIn(k16BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(32BitsTo48Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k32BitsNumericTypes),
+ testing::ValuesIn(k48BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(32BitsTo64Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k32BitsNumericTypes),
+ testing::ValuesIn(k64BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(32BitsTo96Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k32BitsNumericTypes),
+ testing::ValuesIn(k96BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(32BitsTo128Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k32BitsNumericTypes),
+ testing::ValuesIn(k128BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(48BitsTo16Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k48BitsNumericTypes),
+ testing::ValuesIn(k16BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(48BitsTo32Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k48BitsNumericTypes),
+ testing::ValuesIn(k32BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(48BitsTo64Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k48BitsNumericTypes),
+ testing::ValuesIn(k64BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(48BitsTo96Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k48BitsNumericTypes),
+ testing::ValuesIn(k96BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(48BitsTo128Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k48BitsNumericTypes),
+ testing::ValuesIn(k128BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(64BitsTo16Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k64BitsNumericTypes),
+ testing::ValuesIn(k16BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(64BitsTo32Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k64BitsNumericTypes),
+ testing::ValuesIn(k32BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(64BitsTo48Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k64BitsNumericTypes),
+ testing::ValuesIn(k48BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(64BitsTo96Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k64BitsNumericTypes),
+ testing::ValuesIn(k96BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(64BitsTo128Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k64BitsNumericTypes),
+ testing::ValuesIn(k128BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(96BitsTo16Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k96BitsNumericTypes),
+ testing::ValuesIn(k16BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(96BitsTo32Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k96BitsNumericTypes),
+ testing::ValuesIn(k32BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(96BitsTo48Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k96BitsNumericTypes),
+ testing::ValuesIn(k48BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(96BitsTo64Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k96BitsNumericTypes),
+ testing::ValuesIn(k64BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(96BitsTo128Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k96BitsNumericTypes),
+ testing::ValuesIn(k128BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(128BitsTo16Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k128BitsNumericTypes),
+ testing::ValuesIn(k16BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(128BitsTo32Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k128BitsNumericTypes),
+ testing::ValuesIn(k32BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(128BitsTo48Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k128BitsNumericTypes),
+ testing::ValuesIn(k48BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(128BitsTo64Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k128BitsNumericTypes),
+ testing::ValuesIn(k64BitsNumericTypes)));
+INSTANTIATE_TEST_SUITE_P(128BitsTo96Bits,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(k128BitsNumericTypes),
+ testing::ValuesIn(k96BitsNumericTypes)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Compile-time bitcasts to NaN or Inf are invalid
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestInvalidConst = tint::resolver::ResolverTest;
+TEST_F(ResolverBitcastValidationTestInvalidConst, ConstBitcastToF16NaN) {
+ Enable(builtin::Extension::kF16);
+
+ // Lower 16 bits of const u32 0x7e10 is NaN in f16.
+ auto* a = Const("a", Expr(u32(0x00007e10)));
+ auto* b = Let("b", Bitcast(Source{{12, 34}}, ty.Of<vec2<f16>>(), Expr("a")));
+ WrapInFunction(a, b);
+
+ auto expected = "12:34 error: value nan cannot be represented as 'f16'";
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), expected);
+}
+
+TEST_F(ResolverBitcastValidationTestInvalidConst, ConstBitcastToF16Inf) {
+ Enable(builtin::Extension::kF16);
+
+ // 0xfc00 is -Inf in f16.
+ auto* a = Const("a", Call<vec2<u32>>(u32(0x00007010), u32(0xfc008000)));
+ auto* b = Let("b", Bitcast(Source{{12, 34}}, ty.Of<vec4<f16>>(), Expr("a")));
+ WrapInFunction(a, b);
+
+ auto expected = "12:34 error: value -inf cannot be represented as 'f16'";
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), expected);
+}
+
+TEST_F(ResolverBitcastValidationTestInvalidConst, ConstBitcastToF32NaN) {
+ // 0xffc00000 is NaN in f32.
+ auto* a = Const("a", Expr(u32(0xffc00000)));
+ auto* b = Let("b", Bitcast(Source{{12, 34}}, ty.Of<f32>(), Expr("a")));
+ WrapInFunction(a, b);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), ::testing::HasSubstr("cannot be represented as 'f32'"));
+}
+
+TEST_F(ResolverBitcastValidationTestInvalidConst, ConstBitcastToF32Inf) {
+ Enable(builtin::Extension::kF16);
+
+ // 0x7f800000 is Inf in f32.
+ auto* a = Const("a", Call<vec3<u32>>(u32(0xA0008000), u32(0x7f800000), u32(0x40000000)));
+ auto* b = Let("b", Bitcast(Source{{12, 34}}, ty.Of<vec3<f32>>(), Expr("a")));
+ WrapInFunction(a, b);
+
+ auto expected = "12:34 error: value inf cannot be represented as 'f32'";
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), expected);
+}
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 73e274b..0a0391d 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -73,17 +73,6 @@
/// Helper that calls `f` passing in the value of all `cs`.
/// Calls `f` with all constants cast to the type of the first `cs` argument.
template <typename F, typename... CONSTANTS>
-auto Dispatch_fiu32(F&& f, CONSTANTS&&... cs) {
- return Switch(
- First(cs...)->Type(), //
- [&](const type::F32*) { return f(cs->template ValueAs<f32>()...); },
- [&](const type::I32*) { return f(cs->template ValueAs<i32>()...); },
- [&](const type::U32*) { return f(cs->template ValueAs<u32>()...); });
-}
-
-/// Helper that calls `f` passing in the value of all `cs`.
-/// Calls `f` with all constants cast to the type of the first `cs` argument.
-template <typename F, typename... CONSTANTS>
auto Dispatch_ia_iu32(F&& f, CONSTANTS&&... cs) {
return Switch(
First(cs...)->Type(), //
@@ -1462,27 +1451,122 @@
ConstEval::Result ConstEval::Bitcast(const type::Type* ty,
const constant::Value* value,
const Source& source) {
- auto* el_ty = ty->DeepestElement();
- auto transform = [&](const constant::Value* c0) {
- auto create = [&](auto e) {
- return Switch(
- el_ty,
- [&](const type::U32*) { //
- auto r = utils::Bitcast<u32>(e);
- return CreateScalar(source, el_ty, r);
- },
- [&](const type::I32*) { //
- auto r = utils::Bitcast<i32>(e);
- return CreateScalar(source, el_ty, r);
- },
- [&](const type::F32*) { //
- auto r = utils::Bitcast<f32>(e);
- return CreateScalar(source, el_ty, r);
- });
+ // Target type
+ auto dst_elements = ty->Elements(ty->DeepestElement(), 1u);
+ auto dst_el_ty = dst_elements.type;
+ auto dst_count = dst_elements.count;
+ // Source type
+ auto src_elements = value->Type()->Elements(value->Type()->DeepestElement(), 1u);
+ auto src_el_ty = src_elements.type;
+ auto src_count = src_elements.count;
+
+ TINT_ASSERT(Resolver, dst_count * dst_el_ty->Size() == src_count * src_el_ty->Size());
+ uint32_t total_bitwidth = dst_count * dst_el_ty->Size();
+ // Buffer holding the bits of the source value, from which the result value is reinterpreted.
+ utils::Vector<std::byte, 16> buffer;
+ buffer.Reserve(total_bitwidth);
+
+ // Ensure elements are of 32-bit or 16-bit numerical scalar type.
+ TINT_ASSERT(Resolver, (src_el_ty->IsAnyOf<type::F32, type::I32, type::U32, type::F16>()));
+ // Pushes bits from source value into the buffer.
+ auto push_src_element_bits = [&](const constant::Value* element) {
+ auto push_32_bits = [&](uint32_t v) {
+ buffer.Push(std::byte(v & 0xffu));
+ buffer.Push(std::byte((v >> 8) & 0xffu));
+ buffer.Push(std::byte((v >> 16) & 0xffu));
+ buffer.Push(std::byte((v >> 24) & 0xffu));
};
- return Dispatch_fiu32(create, c0);
+ auto push_16_bits = [&](uint16_t v) {
+ buffer.Push(std::byte(v & 0xffu));
+ buffer.Push(std::byte((v >> 8) & 0xffu));
+ };
+ Switch(
+ src_el_ty,
+ [&](const type::U32*) { //
+ uint32_t r = element->ValueAs<u32>();
+ push_32_bits(r);
+ },
+ [&](const type::I32*) { //
+ uint32_t r = utils::Bitcast<u32>(element->ValueAs<i32>());
+ push_32_bits(r);
+ },
+ [&](const type::F32*) { //
+ uint32_t r = utils::Bitcast<u32>(element->ValueAs<f32>());
+ push_32_bits(r);
+ },
+ [&](const type::F16*) { //
+ uint16_t r = element->ValueAs<f16>().BitsRepresentation();
+ push_16_bits(r);
+ });
};
- return TransformElements(builder, ty, transform, value);
+ if (src_count == 1) {
+ push_src_element_bits(value);
+ } else {
+ for (size_t i = 0; i < src_count; i++) {
+ push_src_element_bits(value->Index(i));
+ }
+ }
+
+ // Vector holding elements of return value
+ utils::Vector<const constant::Value*, 4> els;
+
+ // Reinterprets the buffer bits as destination element and push the result into the vector.
+ // Return false if an error occurred, otherwise return true.
+ auto push_dst_element = [&](size_t offset) -> bool {
+ uint32_t v;
+ if (dst_el_ty->Size() == 4) {
+ v = (std::to_integer<uint32_t>(buffer[offset])) |
+ (std::to_integer<uint32_t>(buffer[offset + 1]) << 8) |
+ (std::to_integer<uint32_t>(buffer[offset + 2]) << 16) |
+ (std::to_integer<uint32_t>(buffer[offset + 3]) << 24);
+ } else {
+ v = (std::to_integer<uint32_t>(buffer[offset])) |
+ (std::to_integer<uint32_t>(buffer[offset + 1]) << 8);
+ }
+
+ return Switch(
+ dst_el_ty,
+ [&](const type::U32*) { //
+ auto r = CreateScalar(source, dst_el_ty, u32(v));
+ if (r) {
+ els.Push(r.Get());
+ }
+ return r;
+ },
+ [&](const type::I32*) { //
+ auto r = CreateScalar(source, dst_el_ty, utils::Bitcast<i32>(v));
+ if (r) {
+ els.Push(r.Get());
+ }
+ return r;
+ },
+ [&](const type::F32*) { //
+ auto r = CreateScalar(source, dst_el_ty, utils::Bitcast<f32>(v));
+ if (r) {
+ els.Push(r.Get());
+ }
+ return r;
+ },
+ [&](const type::F16*) { //
+ auto r = CreateScalar(source, dst_el_ty, f16::FromBits(static_cast<uint16_t>(v)));
+ if (r) {
+ els.Push(r.Get());
+ }
+ return r;
+ });
+ };
+
+ TINT_ASSERT(Resolver, (buffer.Length() == total_bitwidth));
+ for (size_t i = 0; i < dst_count; i++) {
+ if (!push_dst_element(i * dst_el_ty->Size())) {
+ return utils::Failure;
+ }
+ }
+
+ if (dst_count == 1) {
+ return std::move(els[0]);
+ }
+ return builder.constants.Composite(ty, std::move(els));
}
ConstEval::Result ConstEval::OpComplement(const type::Type* ty,
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index c02e123..10d4076 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1411,14 +1411,8 @@
return false;
}
- auto width = [&](const type::Type* ty) {
- if (auto* vec = ty->As<type::Vector>()) {
- return vec->Width();
- }
- return 1u;
- };
-
- if (width(from) != width(to)) {
+ // Only bitcasts between scalar/vector types of the same bit width are allowed.
+ if (from->Size() != to->Size()) {
AddError(
"cannot bitcast from '" + sem_.TypeNameOf(from) + "' to '" + sem_.TypeNameOf(to) + "'",
cast->source);
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 92f182a..6faaeda 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -364,26 +364,154 @@
return;
}
+ // Handle identity bitcast.
if (src_type == dst_type) {
return EmitExpression(out, expr->expr);
}
- if (src_type->is_float_scalar_or_vector() && dst_type->is_signed_integer_scalar_or_vector()) {
- out << "floatBitsToInt";
- } else if (src_type->is_float_scalar_or_vector() &&
- dst_type->is_unsigned_integer_scalar_or_vector()) {
- out << "floatBitsToUint";
- } else if (src_type->is_signed_integer_scalar_or_vector() &&
- dst_type->is_float_scalar_or_vector()) {
- out << "intBitsToFloat";
- } else if (src_type->is_unsigned_integer_scalar_or_vector() &&
- dst_type->is_float_scalar_or_vector()) {
- out << "uintBitsToFloat";
+ // Use packFloat2x16 and unpackFloat2x16 for f16 types.
+ if (src_type->DeepestElement()->Is<type::F16>()) {
+ // Source type must be vec2<f16> or vec4<f16>, since type f16 and vec3<f16> can only have
+ // identity bitcast.
+ auto* src_vec = src_type->As<type::Vector>();
+ TINT_ASSERT(Writer, src_vec);
+ TINT_ASSERT(Writer, ((src_vec->Width() == 2u) || (src_vec->Width() == 4u)));
+ std::string fn = utils::GetOrCreate(
+ bitcast_funcs_, BinaryOperandType{{src_type, dst_type}}, [&]() -> std::string {
+ TextBuffer b;
+ TINT_DEFER(helpers_.Append(b));
+
+ auto fn_name = UniqueIdentifier("tint_bitcast_from_f16");
+ {
+ auto decl = Line(&b);
+ EmitTypeAndName(decl, dst_type, builtin::AddressSpace::kUndefined,
+ builtin::Access::kUndefined, fn_name);
+ {
+ ScopedParen sp(decl);
+ EmitTypeAndName(decl, src_type, builtin::AddressSpace::kUndefined,
+ builtin::Access::kUndefined, "src");
+ }
+ decl << " {";
+ }
+ {
+ ScopedIndent si(&b);
+ switch (src_vec->Width()) {
+ case 2: {
+ Line(&b) << "uint r = packFloat2x16(src);";
+ break;
+ }
+ case 4: {
+ Line(&b)
+ << "uvec2 r = uvec2(packFloat2x16(src.xy), packFloat2x16(src.zw));";
+ break;
+ }
+ }
+ auto s = Line(&b);
+ s << "return ";
+ if (dst_type->is_float_scalar_or_vector()) {
+ s << "uintBitsToFloat";
+ } else {
+ EmitType(s, dst_type, builtin::AddressSpace::kUndefined,
+ builtin::Access::kReadWrite, "");
+ }
+ s << "(r);";
+ }
+ Line(&b) << "}";
+ return fn_name;
+ });
+ // Call the helper
+ out << fn;
+ {
+ ScopedParen sp(out);
+ EmitExpression(out, expr->expr);
+ }
+ } else if (dst_type->DeepestElement()->Is<type::F16>()) {
+ // Destination type must be vec2<f16> or vec4<f16>.
+ auto* dst_vec = dst_type->As<type::Vector>();
+ TINT_ASSERT(Writer, dst_vec);
+ TINT_ASSERT(Writer, ((dst_vec->Width() == 2u) || (dst_vec->Width() == 4u)));
+ std::string fn = utils::GetOrCreate(
+ bitcast_funcs_, BinaryOperandType{{src_type, dst_type}}, [&]() -> std::string {
+ TextBuffer b;
+ TINT_DEFER(helpers_.Append(b));
+
+ auto fn_name = UniqueIdentifier("tint_bitcast_to_f16");
+ {
+ auto decl = Line(&b);
+ EmitTypeAndName(decl, dst_type, builtin::AddressSpace::kUndefined,
+ builtin::Access::kUndefined, fn_name);
+ {
+ ScopedParen sp(decl);
+ EmitTypeAndName(decl, src_type, builtin::AddressSpace::kUndefined,
+ builtin::Access::kUndefined, "src");
+ }
+ decl << " {";
+ }
+ {
+ ScopedIndent si(&b);
+ if (auto src_vec = src_type->As<type::Vector>()) {
+ // Source vector type must be vec2<f32/i32/u32>, destination type vec4<f16>.
+ TINT_ASSERT(Writer, (src_vec->DeepestElement()
+ ->IsAnyOf<type::I32, type::U32, type::F32>()));
+ TINT_ASSERT(Writer, (src_vec->Width() == 2u));
+ {
+ auto s = Line(&b);
+ s << "uvec2 r = ";
+ if (src_type->is_float_scalar_or_vector()) {
+ s << "floatBitsToUint";
+ } else {
+ s << "uvec2";
+ }
+ s << "(src);";
+ }
+ Line(&b) << "f16vec2 v_xy = unpackFloat2x16(r.x);";
+ Line(&b) << "f16vec2 v_zw = unpackFloat2x16(r.y);";
+ Line(&b) << "return f16vec4(v_xy.x, v_xy.y, v_zw.x, v_zw.y);";
+ } else {
+ // Source scalar type must be f32/i32/u32, destination type vec2<f16>.
+ TINT_ASSERT(Writer, (src_type->IsAnyOf<type::I32, type::U32, type::F32>()));
+ {
+ auto s = Line(&b);
+ s << "uint r = ";
+ if (src_type->is_float_scalar_or_vector()) {
+ s << "floatBitsToUint";
+ } else {
+ s << "uint";
+ }
+ s << "(src);";
+ }
+ Line(&b) << "return unpackFloat2x16(r);";
+ }
+ }
+ Line(&b) << "}";
+ return fn_name;
+ });
+ // Call the helper
+ out << fn;
+ {
+ ScopedParen sp(out);
+ EmitExpression(out, expr->expr);
+ }
} else {
- EmitType(out, dst_type, builtin::AddressSpace::kUndefined, builtin::Access::kReadWrite, "");
+ if (src_type->is_float_scalar_or_vector() &&
+ dst_type->is_signed_integer_scalar_or_vector()) {
+ out << "floatBitsToInt";
+ } else if (src_type->is_float_scalar_or_vector() &&
+ dst_type->is_unsigned_integer_scalar_or_vector()) {
+ out << "floatBitsToUint";
+ } else if (src_type->is_signed_integer_scalar_or_vector() &&
+ dst_type->is_float_scalar_or_vector()) {
+ out << "intBitsToFloat";
+ } else if (src_type->is_unsigned_integer_scalar_or_vector() &&
+ dst_type->is_float_scalar_or_vector()) {
+ out << "uintBitsToFloat";
+ } else {
+ EmitType(out, dst_type, builtin::AddressSpace::kUndefined, builtin::Access::kReadWrite,
+ "");
+ }
+ ScopedParen sp(out);
+ EmitExpression(out, expr->expr);
}
- ScopedParen sp(out);
- EmitExpression(out, expr->expr);
}
void GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h
index 8650647..87901d3 100644
--- a/src/tint/writer/glsl/generator_impl.h
+++ b/src/tint/writer/glsl/generator_impl.h
@@ -460,6 +460,9 @@
std::unordered_map<const type::Vector*, std::string> dynamic_vector_write_;
std::unordered_map<const type::Vector*, std::string> int_dot_funcs_;
std::unordered_map<BinaryOperandType, std::string> float_modulo_funcs_;
+ // Polyfill functions for bitcast expression, BinaryOperandType indicates the source type and
+ // the destination type
+ std::unordered_map<BinaryOperandType, std::string> bitcast_funcs_;
std::unordered_set<const type::Struct*> emitted_structs_;
bool requires_oes_sample_variables_ = false;
bool requires_default_precision_qualifier_ = false;
diff --git a/src/tint/writer/glsl/generator_impl_bitcast_test.cc b/src/tint/writer/glsl/generator_impl_bitcast_test.cc
index f3a7af0..b14c87e 100644
--- a/src/tint/writer/glsl/generator_impl_bitcast_test.cc
+++ b/src/tint/writer/glsl/generator_impl_bitcast_test.cc
@@ -18,6 +18,7 @@
#include "gmock/gmock.h"
+using namespace tint::builtin::fluent_types; // NOLINT
using namespace tint::number_suffixes; // NOLINT
namespace tint::writer::glsl {
@@ -64,5 +65,127 @@
EXPECT_EQ(out.str(), "uint(a)");
}
+TEST_F(GlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_F16_Vec2) {
+ Enable(builtin::Extension::kF16);
+
+ auto* a = Let("a", Call<vec2<f16>>(1_h, 2_h));
+ auto* b = Let("b", Bitcast<i32>(Expr("a")));
+ auto* c = Let("c", Bitcast<vec2<f16>>(Expr("b")));
+ auto* d = Let("d", Bitcast<f32>(Expr("c")));
+ auto* e = Let("e", Bitcast<vec2<f16>>(Expr("d")));
+ auto* f = Let("f", Bitcast<u32>(Expr("e")));
+ auto* g = Let("g", Bitcast<vec2<f16>>(Expr("f")));
+ WrapInFunction(a, b, c, d, e, f, g);
+
+ GeneratorImpl& gen = Build();
+
+ gen.Generate();
+ EXPECT_THAT(gen.Diagnostics(), testing::IsEmpty());
+ EXPECT_EQ(gen.Result(), R"(#version 310 es
+#extension GL_AMD_gpu_shader_half_float : require
+
+int tint_bitcast_from_f16(f16vec2 src) {
+ uint r = packFloat2x16(src);
+ return int(r);
+}
+f16vec2 tint_bitcast_to_f16(int src) {
+ uint r = uint(src);
+ return unpackFloat2x16(r);
+}
+float tint_bitcast_from_f16_1(f16vec2 src) {
+ uint r = packFloat2x16(src);
+ return uintBitsToFloat(r);
+}
+f16vec2 tint_bitcast_to_f16_1(float src) {
+ uint r = floatBitsToUint(src);
+ return unpackFloat2x16(r);
+}
+uint tint_bitcast_from_f16_2(f16vec2 src) {
+ uint r = packFloat2x16(src);
+ return uint(r);
+}
+f16vec2 tint_bitcast_to_f16_2(uint src) {
+ uint r = uint(src);
+ return unpackFloat2x16(r);
+}
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void test_function() {
+ f16vec2 a = f16vec2(1.0hf, 2.0hf);
+ int b = tint_bitcast_from_f16(a);
+ f16vec2 c = tint_bitcast_to_f16(b);
+ float d = tint_bitcast_from_f16_1(c);
+ f16vec2 e = tint_bitcast_to_f16_1(d);
+ uint f = tint_bitcast_from_f16_2(e);
+ f16vec2 g = tint_bitcast_to_f16_2(f);
+ return;
+}
+)");
+}
+
+TEST_F(GlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_F16_Vec4) {
+ Enable(builtin::Extension::kF16);
+
+ auto* a = Let("a", Call<vec4<f16>>(1_h, 2_h, 3_h, 4_h));
+ auto* b = Let("b", Bitcast<vec2<i32>>(Expr("a")));
+ auto* c = Let("c", Bitcast<vec4<f16>>(Expr("b")));
+ auto* d = Let("d", Bitcast<vec2<f32>>(Expr("c")));
+ auto* e = Let("e", Bitcast<vec4<f16>>(Expr("d")));
+ auto* f = Let("f", Bitcast<vec2<u32>>(Expr("e")));
+ auto* g = Let("g", Bitcast<vec4<f16>>(Expr("f")));
+ WrapInFunction(a, b, c, d, e, f, g);
+
+ GeneratorImpl& gen = Build();
+
+ gen.Generate();
+ EXPECT_THAT(gen.Diagnostics(), testing::IsEmpty());
+ EXPECT_EQ(gen.Result(), R"(#version 310 es
+#extension GL_AMD_gpu_shader_half_float : require
+
+ivec2 tint_bitcast_from_f16(f16vec4 src) {
+ uvec2 r = uvec2(packFloat2x16(src.xy), packFloat2x16(src.zw));
+ return ivec2(r);
+}
+f16vec4 tint_bitcast_to_f16(ivec2 src) {
+ uvec2 r = uvec2(src);
+ f16vec2 v_xy = unpackFloat2x16(r.x);
+ f16vec2 v_zw = unpackFloat2x16(r.y);
+ return f16vec4(v_xy.x, v_xy.y, v_zw.x, v_zw.y);
+}
+vec2 tint_bitcast_from_f16_1(f16vec4 src) {
+ uvec2 r = uvec2(packFloat2x16(src.xy), packFloat2x16(src.zw));
+ return uintBitsToFloat(r);
+}
+f16vec4 tint_bitcast_to_f16_1(vec2 src) {
+ uvec2 r = floatBitsToUint(src);
+ f16vec2 v_xy = unpackFloat2x16(r.x);
+ f16vec2 v_zw = unpackFloat2x16(r.y);
+ return f16vec4(v_xy.x, v_xy.y, v_zw.x, v_zw.y);
+}
+uvec2 tint_bitcast_from_f16_2(f16vec4 src) {
+ uvec2 r = uvec2(packFloat2x16(src.xy), packFloat2x16(src.zw));
+ return uvec2(r);
+}
+f16vec4 tint_bitcast_to_f16_2(uvec2 src) {
+ uvec2 r = uvec2(src);
+ f16vec2 v_xy = unpackFloat2x16(r.x);
+ f16vec2 v_zw = unpackFloat2x16(r.y);
+ return f16vec4(v_xy.x, v_xy.y, v_zw.x, v_zw.y);
+}
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void test_function() {
+ f16vec4 a = f16vec4(1.0hf, 2.0hf, 3.0hf, 4.0hf);
+ ivec2 b = tint_bitcast_from_f16(a);
+ f16vec4 c = tint_bitcast_to_f16(b);
+ vec2 d = tint_bitcast_from_f16_1(c);
+ f16vec4 e = tint_bitcast_to_f16_1(d);
+ uvec2 f = tint_bitcast_from_f16_2(e);
+ f16vec4 g = tint_bitcast_to_f16_2(f);
+ return;
+}
+)");
+}
+
} // namespace
} // namespace tint::writer::glsl
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 357dd3a..3ca76c4 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -669,19 +669,194 @@
}
bool GeneratorImpl::EmitBitcast(utils::StringStream& out, const ast::BitcastExpression* expr) {
- auto* type = TypeOf(expr);
- if (auto* vec = type->UnwrapRef()->As<type::Vector>()) {
- type = vec->type();
- }
+ auto* dst_type = TypeOf(expr)->UnwrapRef();
+ auto* src_type = TypeOf(expr->expr)->UnwrapRef();
- if (!type->is_integer_scalar() && !type->is_float_scalar()) {
+ auto* src_el_type = src_type->DeepestElement();
+ auto* dst_el_type = dst_type->DeepestElement();
+
+ if (!dst_el_type->is_integer_scalar() && !dst_el_type->is_float_scalar()) {
diagnostics_.add_error(diag::System::Writer,
- "Unable to do bitcast to type " + type->FriendlyName());
+ "Unable to do bitcast to type " + dst_el_type->FriendlyName());
return false;
}
+ // Handle identity bitcast.
+ if (src_type == dst_type) {
+ return EmitExpression(out, expr->expr);
+ }
+
+ // Handle the f16 types using polyfill functions
+ if (src_el_type->Is<type::F16>() || dst_el_type->Is<type::F16>()) {
+ auto f16_bitcast_polyfill = [&]() {
+ if (src_el_type->Is<type::F16>()) {
+ // Source type must be vec2<f16> or vec4<f16>, since type f16 and vec3<f16> can only
+ // have identity bitcast.
+ auto* src_vec = src_type->As<type::Vector>();
+ TINT_ASSERT(Writer, src_vec);
+ TINT_ASSERT(Writer, ((src_vec->Width() == 2u) || (src_vec->Width() == 4u)));
+
+ // Bitcast f16 types to others by converting the given f16 value to f32 and call
+                    // f32tof16 to get the bits. This should be safe, because the conversion is precise
+ // for finite and infinite f16 value as they are exactly representable by f32, and
+                    // the WGSL spec allows any result if the f16 value is NaN.
+ return utils::GetOrCreate(
+ bitcast_funcs_, BinaryType{{src_type, dst_type}}, [&]() -> std::string {
+ TextBuffer b;
+ TINT_DEFER(helpers_.Append(b));
+
+ auto fn_name = UniqueIdentifier(std::string("tint_bitcast_from_f16"));
+ {
+ auto decl = Line(&b);
+ if (!EmitTypeAndName(decl, dst_type, builtin::AddressSpace::kUndefined,
+ builtin::Access::kUndefined, fn_name)) {
+ return "";
+ }
+ {
+ ScopedParen sp(decl);
+ if (!EmitTypeAndName(decl, src_type,
+ builtin::AddressSpace::kUndefined,
+ builtin::Access::kUndefined, "src")) {
+ return "";
+ }
+ }
+ decl << " {";
+ }
+ {
+ ScopedIndent si(&b);
+ {
+ Line(&b) << "uint" << src_vec->Width() << " r = f32tof16(float"
+ << src_vec->Width() << "(src));";
+
+ {
+ auto s = Line(&b);
+ s << "return as";
+ if (!EmitType(s, dst_el_type, builtin::AddressSpace::kUndefined,
+ builtin::Access::kReadWrite, "")) {
+ return "";
+ }
+ s << "(";
+ switch (src_vec->Width()) {
+ case 2: {
+ s << "uint((r.x & 0xffff) | ((r.y & 0xffff) << 16))";
+ break;
+ }
+ case 4: {
+ s << "uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), "
+ "(r.z & 0xffff) | ((r.w & 0xffff) << 16))";
+ break;
+ }
+ }
+ s << ");";
+ }
+ }
+ }
+ Line(&b) << "}";
+ Line(&b);
+ return fn_name;
+ });
+ } else {
+ // Destination type must be vec2<f16> or vec4<f16>.
+ auto* dst_vec = dst_type->As<type::Vector>();
+ TINT_ASSERT(Writer,
+ (dst_vec && ((dst_vec->Width() == 2u) || (dst_vec->Width() == 4u)) &&
+ dst_el_type->Is<type::F16>()));
+ // Source type must be f32/i32/u32 or vec2<f32/i32/u32>.
+ auto* src_vec = src_type->As<type::Vector>();
+ TINT_ASSERT(Writer, (src_type->IsAnyOf<type::I32, type::U32, type::F32>() ||
+ (src_vec && src_vec->Width() == 2u &&
+ src_el_type->IsAnyOf<type::I32, type::U32, type::F32>())));
+ std::string src_type_suffix = (src_vec ? "2" : "");
+
+ // Bitcast other types to f16 types by reinterpreting their bits as f16 using
+ // f16tof32, and convert the result f32 to f16. This should be safe, because the
+                    // conversion is precise for finite and infinite f16 result values as they are
+                    // exactly representable by f32, and the WGSL spec allows any result if the f16 result value
+ // would be NaN.
+ return utils::GetOrCreate(
+ bitcast_funcs_, BinaryType{{src_type, dst_type}}, [&]() -> std::string {
+ TextBuffer b;
+ TINT_DEFER(helpers_.Append(b));
+
+ auto fn_name = UniqueIdentifier(std::string("tint_bitcast_to_f16"));
+ {
+ auto decl = Line(&b);
+ if (!EmitTypeAndName(decl, dst_type, builtin::AddressSpace::kUndefined,
+ builtin::Access::kUndefined, fn_name)) {
+ return "";
+ }
+ {
+ ScopedParen sp(decl);
+ if (!EmitTypeAndName(decl, src_type,
+ builtin::AddressSpace::kUndefined,
+ builtin::Access::kUndefined, "src")) {
+ return "";
+ }
+ }
+ decl << " {";
+ }
+ {
+ ScopedIndent si(&b);
+ {
+ // Convert the source to uint for f16tof32.
+ Line(&b) << "uint" << src_type_suffix << " v = asuint(src);";
+                                // Reinterpret the low 16 bits and the high 16 bits
+ Line(&b) << "float" << src_type_suffix
+ << " t_low = f16tof32(v & 0xffff);";
+ Line(&b) << "float" << src_type_suffix
+ << " t_high = f16tof32((v >> 16) & 0xffff);";
+ // Construct the result f16 vector
+ {
+ auto s = Line(&b);
+ s << "return ";
+ if (!EmitType(s, dst_type, builtin::AddressSpace::kUndefined,
+ builtin::Access::kReadWrite, "")) {
+ return "";
+ }
+ s << "(";
+ switch (dst_vec->Width()) {
+ case 2: {
+ s << "t_low.x, t_high.x";
+ break;
+ }
+ case 4: {
+ s << "t_low.x, t_high.x, t_low.y, t_high.y";
+ break;
+ }
+ }
+ s << ");";
+ }
+ }
+ }
+ Line(&b) << "}";
+ Line(&b);
+ return fn_name;
+ });
+ }
+ };
+
+ // Get or create the polyfill
+ auto fn = f16_bitcast_polyfill();
+ if (fn.empty()) {
+ return false;
+ }
+ // Call the polyfill
+ out << fn;
+ {
+ ScopedParen sp(out);
+ if (!EmitExpression(out, expr->expr)) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ // Otherwise, bitcasting between non-f16 types.
+ TINT_ASSERT(Writer, (!src_el_type->Is<type::F16>() && !dst_el_type->Is<type::F16>()));
out << "as";
- if (!EmitType(out, type, builtin::AddressSpace::kUndefined, builtin::Access::kReadWrite, "")) {
+ if (!EmitType(out, dst_el_type, builtin::AddressSpace::kUndefined, builtin::Access::kReadWrite,
+ "")) {
return false;
}
out << "(";
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index 1d8b59c..5119c45 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -16,6 +16,7 @@
#define SRC_TINT_WRITER_HLSL_GENERATOR_IMPL_H_
#include <string>
+#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
@@ -542,6 +543,9 @@
};
};
+ /// The map key for two semantic types.
+ using BinaryType = utils::UnorderedKeyWrapper<std::tuple<const type::Type*, const type::Type*>>;
+
/// CallBuiltinHelper will call the builtin helper function, creating it
/// if it hasn't been built already. If the builtin needs to be built then
/// CallBuiltinHelper will generate the function signature and will call
@@ -565,6 +569,9 @@
std::function<bool()> emit_continuing_;
std::unordered_map<const type::Matrix*, std::string> matrix_scalar_inits_;
std::unordered_map<const sem::Builtin*, std::string> builtins_;
+ // Polyfill functions for bitcast expression, BinaryType indicates the source type and the
+ // destination type.
+ std::unordered_map<BinaryType, std::string> bitcast_funcs_;
std::unordered_map<const type::Vector*, std::string> dynamic_vector_write_;
std::unordered_map<const type::Matrix*, std::string> dynamic_matrix_vector_write_;
std::unordered_map<const type::Matrix*, std::string> dynamic_matrix_scalar_write_;
diff --git a/src/tint/writer/hlsl/generator_impl_bitcast_test.cc b/src/tint/writer/hlsl/generator_impl_bitcast_test.cc
index 22d0fb8..7628373 100644
--- a/src/tint/writer/hlsl/generator_impl_bitcast_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_bitcast_test.cc
@@ -15,6 +15,9 @@
#include "src/tint/utils/string_stream.h"
#include "src/tint/writer/hlsl/test_helper.h"
+#include "gmock/gmock.h"
+
+using namespace tint::builtin::fluent_types; // NOLINT
using namespace tint::number_suffixes; // NOLINT
namespace tint::writer::hlsl {
@@ -58,5 +61,137 @@
EXPECT_EQ(out.str(), "asuint(a)");
}
+TEST_F(HlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_F16_Vec2) {
+ Enable(builtin::Extension::kF16);
+
+ auto* a = Let("a", Call<vec2<f16>>(1_h, 2_h));
+ auto* b = Let("b", Bitcast<i32>(Expr("a")));
+ auto* c = Let("c", Bitcast<vec2<f16>>(Expr("b")));
+ auto* d = Let("d", Bitcast<f32>(Expr("c")));
+ auto* e = Let("e", Bitcast<vec2<f16>>(Expr("d")));
+ auto* f = Let("f", Bitcast<u32>(Expr("e")));
+ auto* g = Let("g", Bitcast<vec2<f16>>(Expr("f")));
+ WrapInFunction(a, b, c, d, e, f, g);
+
+ GeneratorImpl& gen = Build();
+
+ gen.Generate();
+ EXPECT_THAT(gen.Diagnostics(), testing::IsEmpty());
+ EXPECT_EQ(gen.Result(), R"(int tint_bitcast_from_f16(vector<float16_t, 2> src) {
+ uint2 r = f32tof16(float2(src));
+ return asint(uint((r.x & 0xffff) | ((r.y & 0xffff) << 16)));
+}
+
+vector<float16_t, 2> tint_bitcast_to_f16(int src) {
+ uint v = asuint(src);
+ float t_low = f16tof32(v & 0xffff);
+ float t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 2>(t_low.x, t_high.x);
+}
+
+float tint_bitcast_from_f16_1(vector<float16_t, 2> src) {
+ uint2 r = f32tof16(float2(src));
+ return asfloat(uint((r.x & 0xffff) | ((r.y & 0xffff) << 16)));
+}
+
+vector<float16_t, 2> tint_bitcast_to_f16_1(float src) {
+ uint v = asuint(src);
+ float t_low = f16tof32(v & 0xffff);
+ float t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 2>(t_low.x, t_high.x);
+}
+
+uint tint_bitcast_from_f16_2(vector<float16_t, 2> src) {
+ uint2 r = f32tof16(float2(src));
+ return asuint(uint((r.x & 0xffff) | ((r.y & 0xffff) << 16)));
+}
+
+vector<float16_t, 2> tint_bitcast_to_f16_2(uint src) {
+ uint v = asuint(src);
+ float t_low = f16tof32(v & 0xffff);
+ float t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 2>(t_low.x, t_high.x);
+}
+
+[numthreads(1, 1, 1)]
+void test_function() {
+ const vector<float16_t, 2> a = vector<float16_t, 2>(float16_t(1.0h), float16_t(2.0h));
+ const int b = tint_bitcast_from_f16(a);
+ const vector<float16_t, 2> c = tint_bitcast_to_f16(b);
+ const float d = tint_bitcast_from_f16_1(c);
+ const vector<float16_t, 2> e = tint_bitcast_to_f16_1(d);
+ const uint f = tint_bitcast_from_f16_2(e);
+ const vector<float16_t, 2> g = tint_bitcast_to_f16_2(f);
+ return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_F16_Vec4) {
+ Enable(builtin::Extension::kF16);
+
+ auto* a = Let("a", Call<vec4<f16>>(1_h, 2_h, 3_h, 4_h));
+ auto* b = Let("b", Bitcast<vec2<i32>>(Expr("a")));
+ auto* c = Let("c", Bitcast<vec4<f16>>(Expr("b")));
+ auto* d = Let("d", Bitcast<vec2<f32>>(Expr("c")));
+ auto* e = Let("e", Bitcast<vec4<f16>>(Expr("d")));
+ auto* f = Let("f", Bitcast<vec2<u32>>(Expr("e")));
+ auto* g = Let("g", Bitcast<vec4<f16>>(Expr("f")));
+ WrapInFunction(a, b, c, d, e, f, g);
+
+ GeneratorImpl& gen = Build();
+
+ gen.Generate();
+ EXPECT_THAT(gen.Diagnostics(), testing::IsEmpty());
+ EXPECT_EQ(gen.Result(), R"(int2 tint_bitcast_from_f16(vector<float16_t, 4> src) {
+ uint4 r = f32tof16(float4(src));
+ return asint(uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), (r.z & 0xffff) | ((r.w & 0xffff) << 16)));
+}
+
+vector<float16_t, 4> tint_bitcast_to_f16(int2 src) {
+ uint2 v = asuint(src);
+ float2 t_low = f16tof32(v & 0xffff);
+ float2 t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+}
+
+float2 tint_bitcast_from_f16_1(vector<float16_t, 4> src) {
+ uint4 r = f32tof16(float4(src));
+ return asfloat(uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), (r.z & 0xffff) | ((r.w & 0xffff) << 16)));
+}
+
+vector<float16_t, 4> tint_bitcast_to_f16_1(float2 src) {
+ uint2 v = asuint(src);
+ float2 t_low = f16tof32(v & 0xffff);
+ float2 t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+}
+
+uint2 tint_bitcast_from_f16_2(vector<float16_t, 4> src) {
+ uint4 r = f32tof16(float4(src));
+ return asuint(uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), (r.z & 0xffff) | ((r.w & 0xffff) << 16)));
+}
+
+vector<float16_t, 4> tint_bitcast_to_f16_2(uint2 src) {
+ uint2 v = asuint(src);
+ float2 t_low = f16tof32(v & 0xffff);
+ float2 t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+}
+
+[numthreads(1, 1, 1)]
+void test_function() {
+ const vector<float16_t, 4> a = vector<float16_t, 4>(float16_t(1.0h), float16_t(2.0h), float16_t(3.0h), float16_t(4.0h));
+ const int2 b = tint_bitcast_from_f16(a);
+ const vector<float16_t, 4> c = tint_bitcast_to_f16(b);
+ const float2 d = tint_bitcast_from_f16_1(c);
+ const vector<float16_t, 4> e = tint_bitcast_to_f16_1(d);
+ const uint2 f = tint_bitcast_from_f16_2(e);
+ const vector<float16_t, 4> g = tint_bitcast_to_f16_2(f);
+ return;
+}
+)");
+}
+
} // namespace
} // namespace tint::writer::hlsl
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 63e7980..aa79bce 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -46,6 +46,8 @@
#include "src/tint/ir/transform/block_decorated_structs.h"
#include "src/tint/ir/transform/builtin_polyfill_spirv.h"
#include "src/tint/ir/transform/demote_to_helper.h"
+#include "src/tint/ir/transform/expand_implicit_splats.h"
+#include "src/tint/ir/transform/handle_matrix_arithmetic.h"
#include "src/tint/ir/transform/merge_return.h"
#include "src/tint/ir/transform/shader_io_spirv.h"
#include "src/tint/ir/transform/var_for_dynamic_index.h"
@@ -94,6 +96,8 @@
manager.Add<ir::transform::BlockDecoratedStructs>();
manager.Add<ir::transform::BuiltinPolyfillSpirv>();
manager.Add<ir::transform::DemoteToHelper>();
+ manager.Add<ir::transform::ExpandImplicitSplats>();
+ manager.Add<ir::transform::HandleMatrixArithmetic>();
manager.Add<ir::transform::MergeReturn>();
manager.Add<ir::transform::ShaderIOSpirv>();
manager.Add<ir::transform::VarForDynamicIndex>();
@@ -127,6 +131,38 @@
}
}
+const type::Type* DedupType(const type::Type* ty, type::Manager& types) {
+ return Switch(
+ ty,
+
+ // Depth textures are always declared as sampled textures.
+ [&](const type::DepthTexture* depth) {
+ return types.Get<type::SampledTexture>(depth->dim(), types.f32());
+ },
+ [&](const type::DepthMultisampledTexture* depth) {
+ return types.Get<type::MultisampledTexture>(depth->dim(), types.f32());
+ },
+
+ // Both sampler types are the same in SPIR-V.
+ [&](const type::Sampler* s) -> const type::Type* {
+ if (s->IsComparison()) {
+ return types.Get<type::Sampler>(type::SamplerKind::kSampler);
+ }
+ return s;
+ },
+
+ // Dedup a SampledImage if its underlying image will be deduped.
+ [&](const ir::transform::BuiltinPolyfillSpirv::SampledImage* si) -> const type::Type* {
+ auto* img = DedupType(si->Image(), types);
+ if (img != si->Image()) {
+ return types.Get<ir::transform::BuiltinPolyfillSpirv::SampledImage>(img);
+ }
+ return si;
+ },
+
+ [&](Default) { return ty; });
+}
+
} // namespace
GeneratorImplIr::GeneratorImplIr(ir::Module* module, bool zero_init_workgroup_mem)
@@ -213,6 +249,11 @@
}
uint32_t GeneratorImplIr::Constant(ir::Constant* constant) {
+ // If it is a literal operand, just return the value.
+ if (auto* literal = constant->As<ir::transform::BuiltinPolyfillSpirv::LiteralOperand>()) {
+ return literal->Value()->ValueAs<uint32_t>();
+ }
+
auto id = Constant(constant->Value());
// Set the name for the SPIR-V result ID if provided in the module.
@@ -303,6 +344,7 @@
uint32_t GeneratorImplIr::Type(const type::Type* ty,
builtin::AddressSpace addrspace /* = kUndefined */) {
+ ty = DedupType(ty, ir_->Types());
return types_.GetOrCreate(ty, [&] {
auto id = module_.NextId();
Switch(
@@ -351,16 +393,9 @@
},
[&](const type::Struct* str) { EmitStructType(id, str, addrspace); },
[&](const type::Texture* tex) { EmitTextureType(id, tex); },
- [&](const type::Sampler* s) {
- module_.PushType(spv::Op::OpTypeSampler, {id});
-
- // Register both of the sampler types, as they're the same in SPIR-V.
- if (s->kind() == type::SamplerKind::kSampler) {
- types_.Add(
- ir_->Types().Get<type::Sampler>(type::SamplerKind::kComparisonSampler), id);
- } else {
- types_.Add(ir_->Types().Get<type::Sampler>(type::SamplerKind::kSampler), id);
- }
+ [&](const type::Sampler*) { module_.PushType(spv::Op::OpTypeSampler, {id}); },
+ [&](const ir::transform::BuiltinPolyfillSpirv::SampledImage* s) {
+ module_.PushType(spv::Op::OpTypeSampledImage, {id, Type(s->Image())});
},
[&](Default) {
TINT_ICE(Writer, diagnostics_) << "unhandled type: " << ty->FriendlyName();
@@ -485,8 +520,6 @@
void GeneratorImplIr::EmitTextureType(uint32_t id, const type::Texture* texture) {
uint32_t sampled_type = Switch(
texture, //
- [&](const type::DepthTexture*) { return Type(ir_->Types().f32()); },
- [&](const type::DepthMultisampledTexture*) { return Type(ir_->Types().f32()); },
[&](const type::SampledTexture* t) { return Type(t->type()); },
[&](const type::MultisampledTexture* t) { return Type(t->type()); },
[&](const type::StorageTexture* t) { return Type(t->type()); });
@@ -526,7 +559,7 @@
case type::TextureDimension::kCubeArray: {
dim = SpvDimCube;
array = 1u;
- if (texture->IsAnyOf<type::SampledTexture, type::DepthTexture>()) {
+ if (texture->Is<type::SampledTexture>()) {
module_.PushCapability(SpvCapabilitySampledCubeArray);
}
break;
@@ -539,13 +572,12 @@
uint32_t depth = 0u;
uint32_t ms = 0u;
- if (texture->IsAnyOf<type::MultisampledTexture, type::DepthMultisampledTexture>()) {
+ if (texture->Is<type::MultisampledTexture>()) {
ms = 1u;
}
uint32_t sampled = 2u;
- if (texture->IsAnyOf<type::MultisampledTexture, type::SampledTexture, type::DepthTexture,
- type::DepthMultisampledTexture>()) {
+ if (texture->IsAnyOf<type::MultisampledTexture, type::SampledTexture>()) {
sampled = 1u;
}
@@ -919,7 +951,6 @@
auto rhs = Value(binary->RHS());
auto* ty = binary->Result()->Type();
auto* lhs_ty = binary->LHS()->Type();
- auto* rhs_ty = binary->RHS()->Type();
// Determine the opcode.
spv::Op op = spv::Op::Max;
@@ -940,37 +971,9 @@
}
case ir::Binary::Kind::kMultiply: {
if (ty->is_integer_scalar_or_vector()) {
- // If the result is an integer then we can only use OpIMul.
op = spv::Op::OpIMul;
- } else if (lhs_ty->is_float_scalar() && rhs_ty->is_float_scalar()) {
- // Two float scalars multiply with OpFMul.
+ } else if (ty->is_float_scalar_or_vector()) {
op = spv::Op::OpFMul;
- } else if (lhs_ty->is_float_vector() && rhs_ty->is_float_vector()) {
- // Two float vectors multiply with OpFMul.
- op = spv::Op::OpFMul;
- } else if (lhs_ty->is_float_scalar() && rhs_ty->is_float_vector()) {
- // Use OpVectorTimesScalar for scalar * vector, and swap the operand order.
- std::swap(lhs, rhs);
- op = spv::Op::OpVectorTimesScalar;
- } else if (lhs_ty->is_float_vector() && rhs_ty->is_float_scalar()) {
- // Use OpVectorTimesScalar for scalar * vector.
- op = spv::Op::OpVectorTimesScalar;
- } else if (lhs_ty->is_float_scalar() && rhs_ty->is_float_matrix()) {
- // Use OpMatrixTimesScalar for scalar * matrix, and swap the operand order.
- std::swap(lhs, rhs);
- op = spv::Op::OpMatrixTimesScalar;
- } else if (lhs_ty->is_float_matrix() && rhs_ty->is_float_scalar()) {
- // Use OpMatrixTimesScalar for scalar * matrix.
- op = spv::Op::OpMatrixTimesScalar;
- } else if (lhs_ty->is_float_vector() && rhs_ty->is_float_matrix()) {
- // Use OpVectorTimesMatrix for vector * matrix.
- op = spv::Op::OpVectorTimesMatrix;
- } else if (lhs_ty->is_float_matrix() && rhs_ty->is_float_vector()) {
- // Use OpMatrixTimesVector for matrix * vector.
- op = spv::Op::OpMatrixTimesVector;
- } else if (lhs_ty->is_float_matrix() && rhs_ty->is_float_matrix()) {
- // Use OpMatrixTimesMatrix for matrix * vector.
- op = spv::Op::OpMatrixTimesMatrix;
}
break;
}
@@ -990,11 +993,19 @@
}
case ir::Binary::Kind::kAnd: {
- op = spv::Op::OpBitwiseAnd;
+ if (ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpBitwiseAnd;
+ } else if (ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalAnd;
+ }
break;
}
case ir::Binary::Kind::kOr: {
- op = spv::Op::OpBitwiseOr;
+ if (ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpBitwiseOr;
+ } else if (ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalOr;
+ }
break;
}
case ir::Binary::Kind::kXor: {
@@ -1100,9 +1111,10 @@
values_.Add(builtin->Result(), Value(builtin->Args()[0]));
return;
}
- if (builtin->Func() == builtin::Function::kAny &&
+ if ((builtin->Func() == builtin::Function::kAll ||
+ builtin->Func() == builtin::Function::kAny) &&
builtin->Args()[0]->Type()->Is<type::Bool>()) {
- // any() is a passthrough for a scalar argument.
+ // all() and any() are passthroughs for scalar arguments.
values_.Add(builtin->Result(), Value(builtin->Args()[0]));
return;
}
@@ -1134,6 +1146,9 @@
glsl_ext_inst(GLSLstd450SAbs);
}
break;
+ case builtin::Function::kAll:
+ op = spv::Op::OpAll;
+ break;
case builtin::Function::kAny:
op = spv::Op::OpAny;
break;
@@ -1215,6 +1230,9 @@
case builtin::Function::kFloor:
glsl_ext_inst(GLSLstd450Floor);
break;
+ case builtin::Function::kFma:
+ glsl_ext_inst(GLSLstd450Fma);
+ break;
case builtin::Function::kFract:
glsl_ext_inst(GLSLstd450Fract);
break;
@@ -1251,21 +1269,42 @@
glsl_ext_inst(GLSLstd450UMin);
}
break;
+ case builtin::Function::kMix:
+ glsl_ext_inst(GLSLstd450FMix);
+ break;
case builtin::Function::kModf:
glsl_ext_inst(GLSLstd450ModfStruct);
break;
case builtin::Function::kNormalize:
glsl_ext_inst(GLSLstd450Normalize);
break;
+ case builtin::Function::kPow:
+ glsl_ext_inst(GLSLstd450Pow);
+ break;
case builtin::Function::kSin:
glsl_ext_inst(GLSLstd450Sin);
break;
case builtin::Function::kSinh:
glsl_ext_inst(GLSLstd450Sinh);
break;
+ case builtin::Function::kSmoothstep:
+ glsl_ext_inst(GLSLstd450SmoothStep);
+ break;
case builtin::Function::kSqrt:
glsl_ext_inst(GLSLstd450Sqrt);
break;
+ case builtin::Function::kStep:
+ glsl_ext_inst(GLSLstd450Step);
+ break;
+ case builtin::Function::kStorageBarrier:
+ op = spv::Op::OpControlBarrier;
+ operands.clear();
+ operands.push_back(Constant(ir_->constant_values.Get(u32(spv::Scope::Workgroup))));
+ operands.push_back(Constant(ir_->constant_values.Get(u32(spv::Scope::Workgroup))));
+ operands.push_back(
+ Constant(ir_->constant_values.Get(u32(spv::MemorySemanticsMask::UniformMemory |
+ spv::MemorySemanticsMask::AcquireRelease))));
+ break;
case builtin::Function::kTan:
glsl_ext_inst(GLSLstd450Tan);
break;
@@ -1275,6 +1314,15 @@
case builtin::Function::kTrunc:
glsl_ext_inst(GLSLstd450Trunc);
break;
+ case builtin::Function::kWorkgroupBarrier:
+ op = spv::Op::OpControlBarrier;
+ operands.clear();
+ operands.push_back(Constant(ir_->constant_values.Get(u32(spv::Scope::Workgroup))));
+ operands.push_back(Constant(ir_->constant_values.Get(u32(spv::Scope::Workgroup))));
+ operands.push_back(
+ Constant(ir_->constant_values.Get(u32(spv::MemorySemanticsMask::WorkgroupMemory |
+ spv::MemorySemanticsMask::AcquireRelease))));
+ break;
default:
TINT_ICE(Writer, diagnostics_) << "unimplemented builtin function: " << builtin->Func();
}
@@ -1387,9 +1435,39 @@
case ir::IntrinsicCall::Kind::kSpirvDot:
op = spv::Op::OpDot;
break;
+ case ir::IntrinsicCall::Kind::kSpirvImageSampleImplicitLod:
+ op = spv::Op::OpImageSampleImplicitLod;
+ break;
+ case ir::IntrinsicCall::Kind::kSpirvImageSampleExplicitLod:
+ op = spv::Op::OpImageSampleExplicitLod;
+ break;
+ case ir::IntrinsicCall::Kind::kSpirvImageSampleDrefImplicitLod:
+ op = spv::Op::OpImageSampleDrefImplicitLod;
+ break;
+ case ir::IntrinsicCall::Kind::kSpirvImageSampleDrefExplicitLod:
+ op = spv::Op::OpImageSampleDrefExplicitLod;
+ break;
+ case ir::IntrinsicCall::Kind::kSpirvMatrixTimesMatrix:
+ op = spv::Op::OpMatrixTimesMatrix;
+ break;
+ case ir::IntrinsicCall::Kind::kSpirvMatrixTimesScalar:
+ op = spv::Op::OpMatrixTimesScalar;
+ break;
+ case ir::IntrinsicCall::Kind::kSpirvMatrixTimesVector:
+ op = spv::Op::OpMatrixTimesVector;
+ break;
+ case ir::IntrinsicCall::Kind::kSpirvSampledImage:
+ op = spv::Op::OpSampledImage;
+ break;
case ir::IntrinsicCall::Kind::kSpirvSelect:
op = spv::Op::OpSelect;
break;
+ case ir::IntrinsicCall::Kind::kSpirvVectorTimesMatrix:
+ op = spv::Op::OpVectorTimesMatrix;
+ break;
+ case ir::IntrinsicCall::Kind::kSpirvVectorTimesScalar:
+ op = spv::Op::OpVectorTimesScalar;
+ break;
}
OperandList operands = {Type(call->Result()->Type()), id};
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index bdd542c..3a7047d 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -108,6 +108,11 @@
BinaryTestCase{kF16, ir::Binary::Kind::kMultiply, "OpFMul", "half"},
BinaryTestCase{kF16, ir::Binary::Kind::kDivide, "OpFDiv", "half"},
BinaryTestCase{kF16, ir::Binary::Kind::kModulo, "OpFRem", "half"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_Bool,
+ Arithmetic_Bitwise,
+ testing::Values(BinaryTestCase{kBool, ir::Binary::Kind::kAnd, "OpLogicalAnd", "bool"},
+ BinaryTestCase{kBool, ir::Binary::Kind::kOr, "OpLogicalOr", "bool"}));
TEST_F(SpvGeneratorImplTest, Binary_ScalarTimesVector_F32) {
auto* scalar = b.FunctionParam("scalar", ty.f32());
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
index 83b6cb8..759ac00 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
@@ -152,7 +152,35 @@
)");
}
-// Test that any of an scalar just folds away.
+// Test that all of a scalar just folds away.
+TEST_F(SpvGeneratorImplTest, Builtin_All_Scalar) {
+ auto* arg = b.FunctionParam("arg", ty.bool_());
+ auto* func = b.Function("foo", ty.bool_());
+ func->SetParams({arg});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.bool_(), builtin::Function::kAll, arg);
+ b.Return(func, result);
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpReturnValue %arg");
+}
+
+TEST_F(SpvGeneratorImplTest, Builtin_All_Vector) {
+ auto* arg = b.FunctionParam("arg", ty.vec4<bool>());
+ auto* func = b.Function("foo", ty.bool_());
+ func->SetParams({arg});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.bool_(), builtin::Function::kAll, arg);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpAll %bool %arg");
+}
+
+// Test that any of a scalar just folds away.
TEST_F(SpvGeneratorImplTest, Builtin_Any_Scalar) {
auto* arg = b.FunctionParam("arg", ty.bool_());
auto* func = b.Function("foo", ty.bool_());
@@ -364,7 +392,11 @@
BuiltinTestCase{kU32, builtin::Function::kMax, "UMax"},
BuiltinTestCase{kF32, builtin::Function::kMin, "FMin"},
BuiltinTestCase{kI32, builtin::Function::kMin, "SMin"},
- BuiltinTestCase{kU32, builtin::Function::kMin, "UMin"}));
+ BuiltinTestCase{kU32, builtin::Function::kMin, "UMin"},
+ BuiltinTestCase{kF32, builtin::Function::kPow, "Pow"},
+ BuiltinTestCase{kF16, builtin::Function::kPow, "Pow"},
+ BuiltinTestCase{kF32, builtin::Function::kStep, "Step"},
+ BuiltinTestCase{kF16, builtin::Function::kStep, "Step"}));
TEST_F(SpvGeneratorImplTest, Builtin_Cross_vec3f) {
auto* arg1 = b.FunctionParam("arg1", ty.vec3<f32>());
@@ -508,12 +540,53 @@
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST(params.spirv_inst);
}
-INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
- Builtin_3arg,
- testing::Values(BuiltinTestCase{kF32, builtin::Function::kClamp, "NClamp"},
- BuiltinTestCase{kI32, builtin::Function::kClamp, "SClamp"},
- BuiltinTestCase{kU32, builtin::Function::kClamp,
- "UClamp"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ Builtin_3arg,
+ testing::Values(BuiltinTestCase{kF32, builtin::Function::kClamp, "NClamp"},
+ BuiltinTestCase{kI32, builtin::Function::kClamp, "SClamp"},
+ BuiltinTestCase{kU32, builtin::Function::kClamp, "UClamp"},
+ BuiltinTestCase{kF32, builtin::Function::kFma, "Fma"},
+ BuiltinTestCase{kF16, builtin::Function::kFma, "Fma"},
+ BuiltinTestCase{kF32, builtin::Function::kMix, "Mix"},
+ BuiltinTestCase{kF16, builtin::Function::kMix, "Mix"},
+ BuiltinTestCase{kF32, builtin::Function::kSmoothstep, "SmoothStep"},
+ BuiltinTestCase{kF16, builtin::Function::kSmoothstep, "SmoothStep"}));
+
+TEST_F(SpvGeneratorImplTest, Builtin_Mix_VectorOperands_ScalarFactor) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec4<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec4<f32>());
+ auto* factor = b.FunctionParam("factor", ty.f32());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({arg1, arg2, factor});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.vec4<f32>(), builtin::Function::kMix, arg1, arg2, factor);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%9 = OpCompositeConstruct %v4float %factor %factor %factor %factor");
+ EXPECT_INST("%result = OpExtInst %v4float %11 FMix %arg1 %arg2 %9");
+}
+
+TEST_F(SpvGeneratorImplTest, Builtin_Mix_VectorOperands_VectorFactor) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec4<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec4<f32>());
+ auto* factor = b.FunctionParam("factor", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({arg1, arg2, factor});
+
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.vec4<f32>(), builtin::Function::kMix, arg1, arg2, factor);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpExtInst %v4float %10 FMix %arg1 %arg2 %factor");
+}
TEST_F(SpvGeneratorImplTest, Builtin_Select_ScalarCondition_ScalarOperands) {
auto* argf = b.FunctionParam("argf", ty.i32());
@@ -567,5 +640,27 @@
EXPECT_INST("%result = OpSelect %v4int %11 %argt %argf");
}
+TEST_F(SpvGeneratorImplTest, Builtin_StorageBarrier) {
+ auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ b.Call(ty.void_(), builtin::Function::kStorageBarrier);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpControlBarrier %uint_2 %uint_2 %uint_72");
+}
+
+TEST_F(SpvGeneratorImplTest, Builtin_WorkgroupBarrier) {
+ auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ b.Call(ty.void_(), builtin::Function::kWorkgroupBarrier);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpControlBarrier %uint_2 %uint_2 %uint_264");
+}
+
} // namespace
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_texture_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_texture_builtin_test.cc
new file mode 100644
index 0000000..90711e7
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_texture_builtin_test.cc
@@ -0,0 +1,597 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+#include "src/tint/builtin/function.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::writer::spirv {
+namespace {
+
+/// An additional argument to a texture builtin.
+struct Arg {
+ /// The argument name.
+ const char* name;
+ /// The vector width of the argument (1 means scalar).
+ uint32_t width;
+ /// The element type of the argument.
+ TestElementType type;
+};
+
+/// A parameterized texture builtin function test case.
+struct TextureBuiltinTestCase {
+ /// The builtin function.
+ enum builtin::Function function;
+ /// The builtin function arguments.
+ utils::Vector<Arg, 4> optional_args;
+ /// The expected SPIR-V instruction string for the texture call.
+ const char* texture_call;
+};
+
+std::string PrintCase(testing::TestParamInfo<TextureBuiltinTestCase> cc) {
+ utils::StringStream ss;
+ ss << cc.param.function;
+ for (const auto& arg : cc.param.optional_args) {
+ ss << "_" << arg.name;
+ }
+ return ss.str();
+}
+
+class TextureBuiltinTest : public SpvGeneratorImplTestWithParam<TextureBuiltinTestCase> {
+ protected:
+ void Run(const type::Texture* texture_ty,
+ const type::Sampler* sampler_ty,
+ const type::Type* coord_ty,
+ const type::Type* return_ty) {
+ auto params = GetParam();
+
+ auto* t = b.FunctionParam("t", texture_ty);
+ auto* s = b.FunctionParam("s", sampler_ty);
+ auto* coord = b.FunctionParam("coords", coord_ty);
+ auto* func = b.Function("foo", return_ty);
+ func->SetParams({t, s, coord});
+
+ b.With(func->Block(), [&] {
+ utils::Vector<ir::Value*, 4> args = {t, s, coord};
+ uint32_t arg_value = 1;
+ for (const auto& arg : params.optional_args) {
+ auto* value = MakeScalarValue(arg.type, arg_value++);
+ if (arg.width > 1) {
+ value = b.Constant(mod.constant_values.Splat(ty.vec(value->Type(), arg.width),
+ value->Value(), arg.width));
+ }
+ args.Push(value);
+ mod.SetName(value, arg.name);
+ }
+ auto* result = b.Call(return_ty, params.function, std::move(args));
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.texture_call);
+ }
+};
+
+using Texture1D = TextureBuiltinTest;
+TEST_P(Texture1D, Emit) {
+ Run(ty.Get<type::SampledTexture>(type::TextureDimension::k1d, ty.f32()),
+ ty.sampler(), // sampler type
+ ty.f32(), // coord type
+ ty.vec4<f32>() // return type
+ );
+ EXPECT_INST("%11 = OpSampledImage %12 %t %s");
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ Texture1D,
+ testing::Values(TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {},
+ "OpImageSampleImplicitLod %v4float %11 %coords None",
+ }),
+ PrintCase);
+
+using Texture2D = TextureBuiltinTest;
+TEST_P(Texture2D, Emit) {
+ Run(ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()),
+ ty.sampler(), // sampler type
+ ty.vec2<f32>(), // coord type
+ ty.vec4<f32>() // return type
+ );
+ EXPECT_INST("%12 = OpSampledImage %13 %t %s");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ Texture2D,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {},
+ "OpImageSampleImplicitLod %v4float %12 %coords None",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {{"offset", 2, kI32}},
+ "OpImageSampleImplicitLod %v4float %12 %coords ConstOffset %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleBias,
+ {{"bias", 1, kF32}},
+ "OpImageSampleImplicitLod %v4float %12 %coords Bias %bias",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleBias,
+ {{"bias", 1, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleImplicitLod %v4float %12 %coords Bias|ConstOffset %bias %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleGrad,
+ {{"ddx", 2, kF32}, {"ddy", 2, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Grad %ddx %ddy",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleGrad,
+ {{"ddx", 2, kF32}, {"ddy", 2, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Grad|ConstOffset %ddx %ddy %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"lod", 1, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Lod %lod",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"lod", 1, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Lod|ConstOffset %lod %offset",
+ }),
+ PrintCase);
+
+using Texture2DArray = TextureBuiltinTest;
+TEST_P(Texture2DArray, Emit) {
+ Run(ty.Get<type::SampledTexture>(type::TextureDimension::k2dArray, ty.f32()),
+ ty.sampler(), // sampler type
+ ty.vec2<f32>(), // coord type
+ ty.vec4<f32>() // return type
+ );
+ EXPECT_INST("%12 = OpSampledImage %13 %t %s");
+ EXPECT_INST("%14 = OpConvertSToF %float %array_idx");
+ EXPECT_INST("%18 = OpCompositeConstruct %v3float %coords %14");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ Texture2DArray,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {{"array_idx", 1, kI32}},
+ "OpImageSampleImplicitLod %v4float %12 %18 None",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {{"array_idx", 1, kI32}, {"offset", 2, kI32}},
+ "OpImageSampleImplicitLod %v4float %12 %18 ConstOffset %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleBias,
+ {{"array_idx", 1, kI32}, {"bias", 1, kF32}},
+ "OpImageSampleImplicitLod %v4float %12 %18 Bias %bias",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleBias,
+ {{"array_idx", 1, kI32}, {"bias", 1, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleImplicitLod %v4float %12 %18 Bias|ConstOffset %bias %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleGrad,
+ {{"array_idx", 1, kI32}, {"ddx", 2, kF32}, {"ddy", 2, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %18 Grad %ddx %ddy",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleGrad,
+ {{"array_idx", 1, kI32}, {"ddx", 2, kF32}, {"ddy", 2, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleExplicitLod %v4float %12 %18 Grad|ConstOffset %ddx %ddy %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"array_idx", 1, kI32}, {"lod", 1, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %18 Lod %lod",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"array_idx", 1, kI32}, {"lod", 1, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleExplicitLod %v4float %12 %18 Lod|ConstOffset %lod %offset",
+ }),
+ PrintCase);
+
+using Texture3D = TextureBuiltinTest;
+TEST_P(Texture3D, Emit) {
+ Run(ty.Get<type::SampledTexture>(type::TextureDimension::k3d, ty.f32()),
+ ty.sampler(), // sampler type
+ ty.vec3<f32>(), // coord type
+ ty.vec4<f32>() // return type
+ );
+ EXPECT_INST("%12 = OpSampledImage %13 %t %s");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ Texture3D,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {},
+ "OpImageSampleImplicitLod %v4float %12 %coords None",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {{"offset", 3, kI32}},
+ "OpImageSampleImplicitLod %v4float %12 %coords ConstOffset %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleBias,
+ {{"bias", 1, kF32}},
+ "OpImageSampleImplicitLod %v4float %12 %coords Bias %bias",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleBias,
+ {{"bias", 1, kF32}, {"offset", 3, kI32}},
+ "OpImageSampleImplicitLod %v4float %12 %coords Bias|ConstOffset %bias %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleGrad,
+ {{"ddx", 3, kF32}, {"ddy", 3, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Grad %ddx %ddy",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleGrad,
+ {{"ddx", 3, kF32}, {"ddy", 3, kF32}, {"offset", 3, kI32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Grad|ConstOffset %ddx %ddy %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"lod", 1, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Lod %lod",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"lod", 1, kF32}, {"offset", 3, kI32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Lod|ConstOffset %lod %offset",
+ }),
+ PrintCase);
+
+using TextureCube = TextureBuiltinTest;
+TEST_P(TextureCube, Emit) {
+ Run(ty.Get<type::SampledTexture>(type::TextureDimension::kCube, ty.f32()),
+ ty.sampler(), // sampler type
+ ty.vec3<f32>(), // coord type
+ ty.vec4<f32>() // return type
+ );
+ EXPECT_INST("%12 = OpSampledImage %13 %t %s");
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ TextureCube,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {},
+ "OpImageSampleImplicitLod %v4float %12 %coords None",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleBias,
+ {{"bias", 1, kF32}},
+ "OpImageSampleImplicitLod %v4float %12 %coords Bias %bias",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleGrad,
+ {{"ddx", 3, kF32}, {"ddy", 3, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Grad %ddx %ddy",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"lod", 1, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %coords Lod %lod",
+ }),
+ PrintCase);
+
+using TextureCubeArray = TextureBuiltinTest;
+TEST_P(TextureCubeArray, Emit) {
+ Run(ty.Get<type::SampledTexture>(type::TextureDimension::kCubeArray, ty.f32()),
+ ty.sampler(), // sampler type
+ ty.vec3<f32>(), // coord type
+ ty.vec4<f32>() // return type
+ );
+ EXPECT_INST("%12 = OpSampledImage %13 %t %s");
+ EXPECT_INST("%14 = OpConvertSToF %float %array_idx");
+ EXPECT_INST("%17 = OpCompositeConstruct %v4float %coords %14");
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ TextureCubeArray,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {{"array_idx", 1, kI32}},
+ "OpImageSampleImplicitLod %v4float %12 %17 None",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleBias,
+ {{"array_idx", 1, kI32}, {"bias", 1, kF32}},
+ "OpImageSampleImplicitLod %v4float %12 %17 Bias %bias",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleGrad,
+ {{"array_idx", 1, kI32}, {"ddx", 3, kF32}, {"ddy", 3, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %17 Grad %ddx %ddy",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"array_idx", 1, kI32}, {"lod", 1, kF32}},
+ "OpImageSampleExplicitLod %v4float %12 %17 Lod %lod",
+ }),
+ PrintCase);
+
+using TextureDepth2D = TextureBuiltinTest;
+TEST_P(TextureDepth2D, Emit) {
+ Run(ty.Get<type::DepthTexture>(type::TextureDimension::k2d),
+ ty.sampler(), // sampler type
+ ty.vec2<f32>(), // coord type
+ ty.f32() // return type
+ );
+ EXPECT_INST("%11 = OpSampledImage %12 %t %s");
+ EXPECT_INST("%result = OpCompositeExtract %float");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ TextureDepth2D,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {},
+ "OpImageSampleImplicitLod %v4float %11 %coords None",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {{"offset", 2, kI32}},
+ "OpImageSampleImplicitLod %v4float %11 %coords ConstOffset %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"lod", 1, kI32}},
+ "OpImageSampleExplicitLod %v4float %11 %coords Lod %13",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"lod", 1, kI32}, {"offset", 2, kI32}},
+ "OpImageSampleExplicitLod %v4float %11 %coords Lod|ConstOffset %13 %offset",
+ }),
+ PrintCase);
+
+using TextureDepth2D_DepthComparison = TextureBuiltinTest;
+TEST_P(TextureDepth2D_DepthComparison, Emit) {
+ Run(ty.Get<type::DepthTexture>(type::TextureDimension::k2d),
+ ty.comparison_sampler(), // sampler type
+ ty.vec2<f32>(), // coord type
+ ty.f32() // return type
+ );
+ EXPECT_INST("%11 = OpSampledImage %12 %t %s");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ TextureDepth2D_DepthComparison,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompare,
+ {{"depth", 1, kF32}},
+ "OpImageSampleDrefImplicitLod %float %11 %coords %depth",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompare,
+ {{"depth", 1, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleDrefImplicitLod %float %11 %coords %depth ConstOffset %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompareLevel,
+ {{"depth_l0", 1, kF32}},
+ "OpImageSampleDrefExplicitLod %float %11 %coords %depth_l0 Lod %float_0",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompareLevel,
+ {{"depth_l0", 1, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleDrefExplicitLod %float %11 %coords %depth_l0 Lod|ConstOffset %float_0 "
+ "%offset",
+ }),
+ PrintCase);
+
+using TextureDepth2DArray = TextureBuiltinTest;
+TEST_P(TextureDepth2DArray, Emit) {
+ Run(ty.Get<type::DepthTexture>(type::TextureDimension::k2dArray),
+ ty.sampler(), // sampler type
+ ty.vec2<f32>(), // coord type
+ ty.f32() // return type
+ );
+ EXPECT_INST("%11 = OpSampledImage %12 %t %s");
+ EXPECT_INST("%13 = OpConvertSToF %float %array_idx");
+ EXPECT_INST("%17 = OpCompositeConstruct %v3float %coords %13");
+ EXPECT_INST("%result = OpCompositeExtract %float");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ TextureDepth2DArray,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {{"array_idx", 1, kI32}},
+ "OpImageSampleImplicitLod %v4float %11 %17 None",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {{"array_idx", 1, kI32}, {"offset", 2, kI32}},
+ "OpImageSampleImplicitLod %v4float %11 %17 ConstOffset %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"array_idx", 1, kI32}, {"lod", 1, kI32}},
+ "OpImageSampleExplicitLod %v4float %11 %17 Lod %18",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"array_idx", 1, kI32}, {"lod", 1, kI32}, {"offset", 2, kI32}},
+ "OpImageSampleExplicitLod %v4float %11 %17 Lod|ConstOffset %18 %offset",
+ }),
+ PrintCase);
+
+using TextureDepth2DArray_DepthComparison = TextureBuiltinTest;
+TEST_P(TextureDepth2DArray_DepthComparison, Emit) {
+ Run(ty.Get<type::DepthTexture>(type::TextureDimension::k2dArray),
+ ty.comparison_sampler(), // sampler type
+ ty.vec2<f32>(), // coord type
+ ty.f32() // return type
+ );
+ EXPECT_INST("%11 = OpSampledImage %12 %t %s");
+ EXPECT_INST("%13 = OpConvertSToF %float %array_idx");
+ EXPECT_INST("%17 = OpCompositeConstruct %v3float %coords %13");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ TextureDepth2DArray_DepthComparison,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompare,
+ {{"array_idx", 1, kI32}, {"depth", 1, kF32}},
+ "OpImageSampleDrefImplicitLod %float %11 %17 %depth",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompare,
+ {{"array_idx", 1, kI32}, {"depth", 1, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleDrefImplicitLod %float %11 %17 %depth ConstOffset %offset",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompareLevel,
+ {{"array_idx", 1, kI32}, {"depth_l0", 1, kF32}},
+ "OpImageSampleDrefExplicitLod %float %11 %17 %depth_l0 Lod %float_0",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompareLevel,
+ {{"array_idx", 1, kI32}, {"depth_l0", 1, kF32}, {"offset", 2, kI32}},
+ "OpImageSampleDrefExplicitLod %float %11 %17 %depth_l0 Lod|ConstOffset %float_0 "
+ "%offset",
+ }),
+ PrintCase);
+
+using TextureDepthCube = TextureBuiltinTest;
+TEST_P(TextureDepthCube, Emit) {
+ Run(ty.Get<type::DepthTexture>(type::TextureDimension::kCube),
+ ty.sampler(), // sampler type
+ ty.vec3<f32>(), // coord type
+ ty.f32() // return type
+ );
+ EXPECT_INST("%11 = OpSampledImage %12 %t %s");
+ EXPECT_INST("%result = OpCompositeExtract %float");
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ TextureDepthCube,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {},
+ "OpImageSampleImplicitLod %v4float %11 %coords None",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"lod", 1, kI32}},
+ "OpImageSampleExplicitLod %v4float %11 %coords Lod %13",
+ }),
+ PrintCase);
+
+using TextureDepthCube_DepthComparison = TextureBuiltinTest;
+TEST_P(TextureDepthCube_DepthComparison, Emit) {
+ Run(ty.Get<type::DepthTexture>(type::TextureDimension::kCube),
+        ty.comparison_sampler(),  // sampler type
+ ty.vec3<f32>(), // coord type
+ ty.f32() // return type
+ );
+ EXPECT_INST("%11 = OpSampledImage %12 %t %s");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ TextureDepthCube_DepthComparison,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompare,
+ {{"depth", 1, kF32}},
+ "OpImageSampleDrefImplicitLod %float %11 %coords %depth",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompareLevel,
+ {{"depth_l0", 1, kF32}},
+ "OpImageSampleDrefExplicitLod %float %11 %coords %depth_l0 Lod %float_0",
+ }),
+ PrintCase);
+
+using TextureDepthCubeArray = TextureBuiltinTest;
+TEST_P(TextureDepthCubeArray, Emit) {
+ Run(ty.Get<type::DepthTexture>(type::TextureDimension::kCubeArray),
+ ty.sampler(), // sampler type
+ ty.vec3<f32>(), // coord type
+ ty.f32() // return type
+ );
+ EXPECT_INST("%11 = OpSampledImage %12 %t %s");
+ EXPECT_INST("%13 = OpConvertSToF %float %array_idx");
+ EXPECT_INST("%17 = OpCompositeConstruct %v4float %coords %13");
+ EXPECT_INST("%result = OpCompositeExtract %float");
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ TextureDepthCubeArray,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSample,
+ {{"array_idx", 1, kI32}},
+ "OpImageSampleImplicitLod %v4float %11 %17 None",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleLevel,
+ {{"array_idx", 1, kI32}, {"lod", 1, kI32}},
+ "OpImageSampleExplicitLod %v4float %11 %17 Lod %18",
+ }),
+ PrintCase);
+
+using TextureDepthCubeArray_DepthComparison = TextureBuiltinTest;
+TEST_P(TextureDepthCubeArray_DepthComparison, Emit) {
+ Run(ty.Get<type::DepthTexture>(type::TextureDimension::kCubeArray),
+ ty.comparison_sampler(), // sampler type
+ ty.vec3<f32>(), // coord type
+ ty.f32() // return type
+ );
+ EXPECT_INST("%11 = OpSampledImage %12 %t %s");
+ EXPECT_INST("%13 = OpConvertSToF %float %array_idx");
+ EXPECT_INST("%17 = OpCompositeConstruct %v4float %coords %13");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ TextureDepthCubeArray_DepthComparison,
+ testing::Values(
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompare,
+ {{"array_idx", 1, kI32}, {"depth", 1, kF32}},
+ "OpImageSampleDrefImplicitLod %float %11 %17 %depth",
+ },
+ TextureBuiltinTestCase{
+ builtin::Function::kTextureSampleCompareLevel,
+ {{"array_idx", 1, kI32}, {"depth_l0", 1, kF32}},
+ "OpImageSampleDrefExplicitLod %float %11 %17 %depth_l0 Lod %float_0",
+ }),
+ PrintCase);
+
+} // namespace
+} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
index 4284884..b59ad90 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
@@ -291,6 +291,18 @@
TextureCase{"%1 = OpTypeImage %float Cube 0 0 0 1 Unknown", Dim::kCube},
TextureCase{"%1 = OpTypeImage %float Cube 0 1 0 1 Unknown", Dim::kCubeArray}));
+TEST_F(SpvGeneratorImplTest, Type_DepthTexture_DedupWithSampledTexture) {
+ generator_.Type(ty.Get<type::SampledTexture>(Dim::k2d, ty.f32()));
+ generator_.Type(ty.Get<type::DepthTexture>(Dim::k2d));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeFloat 32
+%1 = OpTypeImage %2 2D 0 0 0 1 Unknown
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+)");
+}
+
TEST_F(SpvGeneratorImplTest, Type_DepthMultiSampledTexture) {
generator_.Type(ty.Get<type::DepthMultisampledTexture>(Dim::k2d));
@@ -298,6 +310,18 @@
EXPECT_INST("%1 = OpTypeImage %float 2D 0 0 1 1 Unknown");
}
+TEST_F(SpvGeneratorImplTest, Type_DepthMultisampledTexture_DedupWithMultisampledTexture) {
+ generator_.Type(ty.Get<type::MultisampledTexture>(Dim::k2d, ty.f32()));
+ generator_.Type(ty.Get<type::DepthMultisampledTexture>(Dim::k2d));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeFloat 32
+%1 = OpTypeImage %2 2D 0 0 1 1 Unknown
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+)");
+}
+
using Format = builtin::TexelFormat;
struct StorageTextureCase {
std::string result;
diff --git a/src/tint/writer/spirv/ir/test_helper_ir.h b/src/tint/writer/spirv/ir/test_helper_ir.h
index 3d86819..54af5ad 100644
--- a/src/tint/writer/spirv/ir/test_helper_ir.h
+++ b/src/tint/writer/spirv/ir/test_helper_ir.h
@@ -176,19 +176,20 @@
/// Helper to make a scalar value with the scalar type `type`.
/// @param type the element type
+ /// @param value the optional value to use
/// @returns the scalar value
- ir::Value* MakeScalarValue(TestElementType type) {
+ ir::Constant* MakeScalarValue(TestElementType type, uint32_t value = 1) {
switch (type) {
case kBool:
return b.Constant(true);
case kI32:
- return b.Constant(i32(1));
+ return b.Constant(i32(value));
case kU32:
- return b.Constant(u32(1));
+ return b.Constant(u32(value));
case kF32:
- return b.Constant(f32(1));
+ return b.Constant(f32(value));
case kF16:
- return b.Constant(f16(1));
+ return b.Constant(f16(value));
}
return nullptr;
}
@@ -196,7 +197,7 @@
/// Helper to make a vector value with an element type of `type`.
/// @param type the element type
/// @returns the vector value
- ir::Value* MakeVectorValue(TestElementType type) {
+ ir::Constant* MakeVectorValue(TestElementType type) {
switch (type) {
case kBool:
return b.Constant(mod.constant_values.Composite(